Skip to main content

minijinja_contrib/
pycompat.rs

1use minijinja::value::{from_args, ValueKind};
2use minijinja::{format_filter, Error, ErrorKind, FormatStyle, State, Value};
3
4/// An unknown method callback implementing python methods on primitives.
5///
6/// This implements a lot of Python methods on basic types so that the
7/// compatibility with Jinja2 templates improves.
8///
9/// ```
10/// use minijinja::Environment;
11/// use minijinja_contrib::pycompat::unknown_method_callback;
12///
13/// let mut env = Environment::new();
14/// env.set_unknown_method_callback(unknown_method_callback);
15/// ```
16///
17/// Today the following methods are implemented:
18///
19/// * `dict.get`
20/// * `dict.items`
21/// * `dict.keys`
22/// * `dict.values`
23/// * `list.count`
24/// * `str.capitalize`
25/// * `str.count`
26/// * `str.endswith`
27/// * `str.find`
28/// * `str.format`
29/// * `str.isalnum`
30/// * `str.isalpha`
31/// * `str.isascii`
32/// * `str.isdigit`
33/// * `str.islower`
34/// * `str.isnumeric`
35/// * `str.isupper`
36/// * `str.join`
37/// * `str.lower`
38/// * `str.lstrip`
39/// * `str.replace`
40/// * `str.rfind`
41/// * `str.rstrip`
42/// * `str.split`
43/// * `str.splitlines`
44/// * `str.startswith`
45/// * `str.strip`
46/// * `str.title`
47/// * `str.upper`
48#[cfg_attr(docsrs, doc(cfg(feature = "pycompat")))]
49pub fn unknown_method_callback(
50    _state: &State,
51    value: &Value,
52    method: &str,
53    args: &[Value],
54) -> Result<Value, Error> {
55    match value.kind() {
56        ValueKind::String => string_methods(value, method, args),
57        ValueKind::Map => map_methods(value, method, args),
58        ValueKind::Seq => seq_methods(value, method, args),
59        _ => Err(Error::from(ErrorKind::UnknownMethod)),
60    }
61}
62
63fn string_methods(value: &Value, method: &str, args: &[Value]) -> Result<Value, Error> {
64    let Some(s) = value.as_str() else {
65        return Err(Error::from(ErrorKind::UnknownMethod));
66    };
67
68    match method {
69        "upper" => {
70            let () = from_args(args)?;
71            Ok(Value::from(s.to_uppercase()))
72        }
73        "lower" => {
74            let () = from_args(args)?;
75            Ok(Value::from(s.to_lowercase()))
76        }
77        "islower" => {
78            let () = from_args(args)?;
79            Ok(Value::from(s.chars().all(|x| x.is_lowercase())))
80        }
81        "isupper" => {
82            let () = from_args(args)?;
83            Ok(Value::from(s.chars().all(|x| x.is_uppercase())))
84        }
85        "isspace" => {
86            let () = from_args(args)?;
87            Ok(Value::from(s.chars().all(|x| x.is_whitespace())))
88        }
89        "isdigit" | "isnumeric" => {
90            // this is not a perfect mapping to what Python does, but
91            // close enough for most uses in templates.
92            let () = from_args(args)?;
93            Ok(Value::from(s.chars().all(|x| x.is_numeric())))
94        }
95        "isalnum" => {
96            let () = from_args(args)?;
97            Ok(Value::from(s.chars().all(|x| x.is_alphanumeric())))
98        }
99        "isalpha" => {
100            let () = from_args(args)?;
101            Ok(Value::from(s.chars().all(|x| x.is_alphabetic())))
102        }
103        "isascii" => {
104            let () = from_args(args)?;
105            Ok(Value::from(s.is_ascii()))
106        }
107        "strip" => {
108            let (chars,): (Option<&str>,) = from_args(args)?;
109            Ok(Value::from(if let Some(chars) = chars {
110                s.trim_matches(&chars.chars().collect::<Vec<_>>()[..])
111            } else {
112                s.trim()
113            }))
114        }
115        "lstrip" => {
116            let (chars,): (Option<&str>,) = from_args(args)?;
117            Ok(Value::from(if let Some(chars) = chars {
118                s.trim_start_matches(&chars.chars().collect::<Vec<_>>()[..])
119            } else {
120                s.trim_start()
121            }))
122        }
123        "rstrip" => {
124            let (chars,): (Option<&str>,) = from_args(args)?;
125            Ok(Value::from(if let Some(chars) = chars {
126                s.trim_end_matches(&chars.chars().collect::<Vec<_>>()[..])
127            } else {
128                s.trim_end()
129            }))
130        }
131        "replace" => {
132            let (old, new, count): (&str, &str, Option<i32>) = from_args(args)?;
133            let count = count.unwrap_or(-1);
134            Ok(Value::from(if count < 0 {
135                s.replace(old, new)
136            } else {
137                s.replacen(old, new, count as usize)
138            }))
139        }
140        "title" => {
141            let () = from_args(args)?;
142            // one shall not call into these filters.  However we consider ourselves
143            // privileged.
144            Ok(Value::from(minijinja::filters::title(s.into())))
145        }
146        "split" => {
147            let (sep, maxsplits) = from_args(args)?;
148            // one shall not call into these filters.  However we consider ourselves
149            // privileged.
150            Ok(minijinja::filters::split(s.into(), sep, maxsplits)
151                .try_iter()?
152                .collect::<Value>())
153        }
154        "splitlines" => {
155            let (keepends,): (Option<bool>,) = from_args(args)?;
156            if !keepends.unwrap_or(false) {
157                Ok(s.lines().map(Value::from).collect())
158            } else {
159                let mut rv = Vec::new();
160                let mut rest = s;
161                while let Some(offset) = rest.find('\n') {
162                    rv.push(Value::from(&rest[..offset + 1]));
163                    rest = &rest[offset + 1..];
164                }
165                if !rest.is_empty() {
166                    rv.push(Value::from(rest));
167                }
168                Ok(Value::from(rv))
169            }
170        }
171        "capitalize" => {
172            let () = from_args(args)?;
173            // one shall not call into these filters.  However we consider ourselves
174            // privileged.
175            Ok(Value::from(minijinja::filters::capitalize(s.into())))
176        }
177        "count" => {
178            let (what,): (&str,) = from_args(args)?;
179            let mut c = 0;
180            let mut rest = s;
181            while let Some(offset) = rest.find(what) {
182                c += 1;
183                rest = &rest[offset + what.len()..];
184            }
185            Ok(Value::from(c))
186        }
187        "find" => {
188            let (what,): (&str,) = from_args(args)?;
189            Ok(Value::from(match s.find(what) {
190                Some(x) => x as i64,
191                None => -1,
192            }))
193        }
194        "format" => format_filter(FormatStyle::StrFormat, s, args).map(Value::from),
195        "rfind" => {
196            let (what,): (&str,) = from_args(args)?;
197            Ok(Value::from(match s.rfind(what) {
198                Some(x) => x as i64,
199                None => -1,
200            }))
201        }
202        "startswith" => {
203            let (prefix,): (&Value,) = from_args(args)?;
204            if let Some(prefix) = prefix.as_str() {
205                Ok(Value::from(s.starts_with(prefix)))
206            } else if matches!(prefix.kind(), ValueKind::Iterable | ValueKind::Seq) {
207                for prefix in prefix.try_iter()? {
208                    if s.starts_with(prefix.as_str().ok_or_else(|| {
209                        Error::new(
210                            ErrorKind::InvalidOperation,
211                            format!(
212                                "tuple for startswith must contain only strings, not {}",
213                                prefix.kind()
214                            ),
215                        )
216                    })?) {
217                        return Ok(Value::from(true));
218                    }
219                }
220                Ok(Value::from(false))
221            } else {
222                Err(Error::new(
223                    ErrorKind::InvalidOperation,
224                    format!(
225                        "startswith argument must be string or a tuple of strings, not {}",
226                        prefix.kind()
227                    ),
228                ))
229            }
230        }
231        "endswith" => {
232            let (suffix,): (&Value,) = from_args(args)?;
233            if let Some(suffix) = suffix.as_str() {
234                Ok(Value::from(s.ends_with(suffix)))
235            } else if matches!(suffix.kind(), ValueKind::Iterable | ValueKind::Seq) {
236                for suffix in suffix.try_iter()? {
237                    if s.ends_with(suffix.as_str().ok_or_else(|| {
238                        Error::new(
239                            ErrorKind::InvalidOperation,
240                            format!(
241                                "tuple for endswith must contain only strings, not {}",
242                                suffix.kind()
243                            ),
244                        )
245                    })?) {
246                        return Ok(Value::from(true));
247                    }
248                }
249                Ok(Value::from(false))
250            } else {
251                Err(Error::new(
252                    ErrorKind::InvalidOperation,
253                    format!(
254                        "endswith argument must be string or a tuple of strings, not {}",
255                        suffix.kind()
256                    ),
257                ))
258            }
259        }
260        "join" => {
261            use std::fmt::Write;
262            let (values,): (&Value,) = from_args(args)?;
263            let mut rv = String::new();
264            for (idx, value) in values.try_iter()?.enumerate() {
265                if idx > 0 {
266                    rv.push_str(s);
267                }
268                write!(rv, "{value}").ok();
269            }
270            Ok(Value::from(rv))
271        }
272        _ => Err(Error::from(ErrorKind::UnknownMethod)),
273    }
274}
275
276fn map_methods(value: &Value, method: &str, args: &[Value]) -> Result<Value, Error> {
277    let Some(obj) = value.as_object() else {
278        return Err(Error::from(ErrorKind::UnknownMethod));
279    };
280
281    match method {
282        "keys" => {
283            let () = from_args(args)?;
284            Ok(Value::make_object_iterable(obj.clone(), |obj| {
285                match obj.try_iter() {
286                    Some(iter) => iter,
287                    None => Box::new(None.into_iter()),
288                }
289            }))
290        }
291        "values" => {
292            let () = from_args(args)?;
293            Ok(Value::make_object_iterable(obj.clone(), |obj| {
294                match obj.try_iter_pairs() {
295                    Some(iter) => Box::new(iter.map(|(_, v)| v)),
296                    None => Box::new(None.into_iter()),
297                }
298            }))
299        }
300        "items" => {
301            let () = from_args(args)?;
302            Ok(Value::make_object_iterable(obj.clone(), |obj| {
303                match obj.try_iter_pairs() {
304                    Some(iter) => Box::new(iter.map(|(k, v)| Value::from(vec![k, v]))),
305                    None => Box::new(None.into_iter()),
306                }
307            }))
308        }
309        "get" => {
310            let (key, default): (&Value, Option<Value>) = from_args(args)?;
311            Ok(match obj.get_value(key) {
312                Some(value) => value,
313                None => default.unwrap_or_else(|| Value::from(())),
314            })
315        }
316        _ => Err(Error::from(ErrorKind::UnknownMethod)),
317    }
318}
319
320fn seq_methods(value: &Value, method: &str, args: &[Value]) -> Result<Value, Error> {
321    let Some(obj) = value.as_object() else {
322        return Err(Error::from(ErrorKind::UnknownMethod));
323    };
324
325    match method {
326        "count" => {
327            let (what,): (&Value,) = from_args(args)?;
328            Ok(Value::from(if let Some(iter) = obj.try_iter() {
329                iter.filter(|x| x == what).count()
330            } else {
331                0
332            }))
333        }
334        _ => Err(Error::from(ErrorKind::UnknownMethod)),
335    }
336}