arithmetic_eval/fns/wrapper/
mod.rs

1//! Wrapper for eloquent `NativeFn` definitions.
2
3use core::{fmt, marker::PhantomData};
4
5use crate::{
6    alloc::Vec, error::AuxErrorInfo, CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue,
7};
8
9mod traits;
10
11pub use self::traits::{
12    ErrorOutput, FromValueError, FromValueErrorKind, FromValueErrorLocation, IntoEvalResult,
13    TryFromValue,
14};
15
16/// Wraps a function enriching it with the information about its arguments.
17/// This is a slightly shorter way to create wrappers compared to calling [`FnWrapper::new()`].
18///
19/// See [`FnWrapper`] for more details on function requirements.
20pub const fn wrap<T, F>(function: F) -> FnWrapper<T, F> {
21    FnWrapper::new(function)
22}
23
/// Wrapper of a function containing information about its arguments.
///
/// Using `FnWrapper` allows defining [native functions](NativeFn) with minimum boilerplate
/// and with increased type safety. `FnWrapper`s can be constructed explicitly or indirectly
/// via [`Environment::insert_wrapped_fn()`], [`Value::wrapped_fn()`], or [`wrap()`].
///
/// Arguments of a wrapped function must implement [`TryFromValue`] trait for the applicable
/// grammar, and the output type must implement [`IntoEvalResult`]. If arguments and/or output
/// have non-`'static` lifetime, use the [`wrap_fn`] macro. If you need [`CallContext`] (e.g.,
/// to call functions provided as an argument), use the [`wrap_fn_with_context`] macro.
///
/// [`Environment::insert_wrapped_fn()`]: crate::Environment::insert_wrapped_fn()
/// [`wrap_fn`]: crate::wrap_fn
/// [`wrap_fn_with_context`]: crate::wrap_fn_with_context
/// [`Value::wrapped_fn()`]: crate::Value::wrapped_fn()
///
/// # Examples
///
/// ## Basic function
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// use arithmetic_eval::{fns, Environment, Value, VariableMap};
///
/// # fn main() -> anyhow::Result<()> {
/// let max = fns::wrap(|x: f32, y: f32| if x > y { x } else { y });
///
/// let program = "max(1, 3) == 3 && max(-1, -3) == -1";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = Environment::new()
///     .insert_native_fn("max", max)
///     .compile_module("test_max", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// ## Fallible function with complex args
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns::FnWrapper, Environment, Value, VariableMap};
/// fn zip_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<(f32, f32)>, String> {
///     if xs.len() == ys.len() {
///         Ok(xs.into_iter().zip(ys).map(|(x, y)| (x, y)).collect())
///     } else {
///         Err("Arrays must have the same size".to_owned())
///     }
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "(1, 2, 3).zip((4, 5, 6)) == ((1, 4), (2, 5), (3, 6))";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_wrapped_fn("zip", zip_arrays)
///     .compile_module("test_zip", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
pub struct FnWrapper<T, F> {
    /// The wrapped function or closure.
    function: F,
    /// Type-level record of the `(Ret, Args...)` tuple; no value of type `T` is stored.
    _arg_types: PhantomData<T>,
}

// Implemented manually: `_arg_types` carries no data, and `#[derive(Debug)]` would add
// an unnecessary `T: Debug` bound.
impl<T, F> fmt::Debug for FnWrapper<T, F>
where
    F: fmt::Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter
            .debug_struct("FnWrapper")
            .field("function", &self.function)
            .finish()
    }
}

// Implemented manually so that cloning only requires `F: Clone`; a derive would also
// require `T: Clone` even though no `T` value is stored.
impl<T, F: Clone> Clone for FnWrapper<T, F> {
    fn clone(&self) -> Self {
        Self {
            function: self.function.clone(),
            _arg_types: PhantomData,
        }
    }
}

// Same reasoning as for `Clone`: only the wrapped function needs to be `Copy`.
impl<T, F: Copy> Copy for FnWrapper<T, F> {}

// Ideally, we would want to constrain `T` and `F`, but this would make it impossible to declare
// the constructor as `const fn`; see https://github.com/rust-lang/rust/issues/57563.
impl<T, F> FnWrapper<T, F> {
    /// Creates a new wrapper.
    ///
    /// Note that the created wrapper is not guaranteed to be usable as [`NativeFn`]. For this
    /// to be the case, `function` needs to be a function or an [`Fn`] closure,
    /// and the `T` type argument needs to be a tuple with the function return type
    /// and the argument types (in this order).
    ///
    /// [`NativeFn`]: crate::NativeFn
    pub const fn new(function: F) -> Self {
        Self {
            function,
            _arg_types: PhantomData,
        }
    }
}
131
132macro_rules! arity_fn {
133    ($arity:tt => $($arg_name:ident : $t:ident),*) => {
134        impl<Num, F, Ret, $($t,)*> NativeFn<Num> for FnWrapper<(Ret, $($t,)*), F>
135        where
136            F: Fn($($t,)*) -> Ret,
137            $($t: for<'val> TryFromValue<'val, Num>,)*
138            Ret: for<'val> IntoEvalResult<'val, Num>,
139        {
140            #[allow(clippy::shadow_unrelated)] // makes it easier to write macro
141            #[allow(unused_variables, unused_mut)] // `args_iter` is unused for 0-ary functions
142            fn evaluate<'a>(
143                &self,
144                args: Vec<SpannedValue<'a, Num>>,
145                context: &mut CallContext<'_, 'a, Num>,
146            ) -> EvalResult<'a, Num> {
147                context.check_args_count(&args, $arity)?;
148                let mut args_iter = args.into_iter().enumerate();
149
150                $(
151                    let (index, $arg_name) = args_iter.next().unwrap();
152                    let span = $arg_name.with_no_extra();
153                    let $arg_name = $t::try_from_value($arg_name.extra).map_err(|mut err| {
154                        err.set_arg_index(index);
155                        context
156                            .call_site_error(ErrorKind::Wrapper(err))
157                            .with_span(&span, AuxErrorInfo::InvalidArg)
158                    })?;
159                )*
160
161                let output = (self.function)($($arg_name,)*);
162                output.into_eval_result().map_err(|err| err.into_spanned(context))
163            }
164        }
165    };
166}
167
168arity_fn!(0 =>);
169arity_fn!(1 => x0: T);
170arity_fn!(2 => x0: T, x1: U);
171arity_fn!(3 => x0: T, x1: U, x2: V);
172arity_fn!(4 => x0: T, x1: U, x2: V, x3: W);
173arity_fn!(5 => x0: T, x1: U, x2: V, x3: W, x4: X);
174arity_fn!(6 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
175arity_fn!(7 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
176arity_fn!(8 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
177arity_fn!(9 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
178arity_fn!(10 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
179
180/// Unary function wrapper.
181pub type Unary<T> = FnWrapper<(T, T), fn(T) -> T>;
182
183/// Binary function wrapper.
184pub type Binary<T> = FnWrapper<(T, T, T), fn(T, T) -> T>;
185
186/// Ternary function wrapper.
187pub type Ternary<T> = FnWrapper<(T, T, T, T), fn(T, T, T) -> T>;
188
189/// Quaternary function wrapper.
190pub type Quaternary<T> = FnWrapper<(T, T, T, T, T), fn(T, T, T, T) -> T>;
191
/// An alternative to the [`wrap`] function which works for arguments / return results with
/// non-`'static` lifetime.
///
/// The macro must be called with 2 arguments (in this order):
///
/// - Function arity (from 0 to 10 inclusive)
/// - Function or closure with the specified number of arguments. Using a function is recommended;
///   using a closure may lead to hard-to-debug type inference errors.
///
/// As with `wrap`, all function arguments must implement [`TryFromValue`] and the return result
/// must implement [`IntoEvalResult`]. Unlike `wrap`, the arguments / return result do not
/// need to have a `'static` lifetime; examples include [`Value`]s, [`Function`]s
/// and [`EvalResult`]s. Lifetimes of all arguments and the return result must match.
///
/// [`wrap`]: crate::fns::wrap
/// [`TryFromValue`]: crate::fns::TryFromValue
/// [`IntoEvalResult`]: crate::fns::IntoEvalResult
/// [`EvalResult`]: crate::EvalResult
/// [`Value`]: crate::Value
/// [`Function`]: crate::Function
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{wrap_fn, Function, Environment, Value, VariableMap};
/// fn is_function<T>(value: Value<'_, T>) -> bool {
///     value.is_function()
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "is_function(is_function) && !is_function(1)";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("is_function", wrap_fn!(1, is_function))
///     .compile_module("test", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// Usage of lifetimes:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{
/// #     wrap_fn, CallContext, Function, Environment, Prelude, Value, VariableMap,
/// # };
/// # use core::iter::FromIterator;
/// // Note that both `Value`s have the same lifetime due to elision.
/// fn take_if<T>(value: Value<'_, T>, condition: bool) -> Value<'_, T> {
///     if condition { value } else { Value::void() }
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "(1, 2).take_if(true) == (1, 2) && (3, 4).take_if(false) != (3, 4)";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::from_iter(Prelude.iter())
///     .insert_native_fn("take_if", wrap_fn!(2, take_if))
///     .compile_module("test_take_if", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[macro_export]
macro_rules! wrap_fn {
    (0, $function:expr) => { $crate::wrap_fn!(@arg 0 =>; $function) };
    (1, $function:expr) => { $crate::wrap_fn!(@arg 1 => x0; $function) };
    (2, $function:expr) => { $crate::wrap_fn!(@arg 2 => x0, x1; $function) };
    (3, $function:expr) => { $crate::wrap_fn!(@arg 3 => x0, x1, x2; $function) };
    (4, $function:expr) => { $crate::wrap_fn!(@arg 4 => x0, x1, x2, x3; $function) };
    (5, $function:expr) => { $crate::wrap_fn!(@arg 5 => x0, x1, x2, x3, x4; $function) };
    (6, $function:expr) => { $crate::wrap_fn!(@arg 6 => x0, x1, x2, x3, x4, x5; $function) };
    (7, $function:expr) => { $crate::wrap_fn!(@arg 7 => x0, x1, x2, x3, x4, x5, x6; $function) };
    (8, $function:expr) => {
        $crate::wrap_fn!(@arg 8 => x0, x1, x2, x3, x4, x5, x6, x7; $function)
    };
    (9, $function:expr) => {
        $crate::wrap_fn!(@arg 9 => x0, x1, x2, x3, x4, x5, x6, x7, x8; $function)
    };
    (10, $function:expr) => {
        $crate::wrap_fn!(@arg 10 => x0, x1, x2, x3, x4, x5, x6, x7, x8, x9; $function)
    };

    // Internal arm. `$ctx` (if present) marks that the context is passed as the first argument.
    ($($ctx:ident,)? @arg $arity:expr => $($arg_name:ident),*; $function:expr) => {{
        let function = $function;
        $crate::fns::enforce_closure_type(move |args, context| {
            context.check_args_count(&args, $arity)?;
            // For 0-ary functions, `args_iter` is never advanced.
            #[allow(unused_mut, unused_variables)]
            let mut args_iter = args.into_iter().enumerate();

            // `*` (not `+`) repetitions: the argument list is empty for 0-ary functions, and
            // a `+` repetition matching nothing fails to expand ("must repeat at least once").
            $(
                let (index, $arg_name) = args_iter.next().unwrap();
                let span = $arg_name.with_no_extra();
                let $arg_name = $crate::fns::TryFromValue::try_from_value($arg_name.extra)
                    .map_err(|mut err| {
                        err.set_arg_index(index);
                        context
                            .call_site_error($crate::error::ErrorKind::Wrapper(err))
                            .with_span(&span, $crate::error::AuxErrorInfo::InvalidArg)
                    })?;
            )*

            // We need `$ctx` just as a marker that the function receives a context.
            let output = function($({ let $ctx = (); context },)? $($arg_name,)*);
            $crate::fns::IntoEvalResult::into_eval_result(output)
                .map_err(|err| err.into_spanned(context))
        })
    }}
}
299
/// Variant of the [`wrap_fn`](crate::wrap_fn) macro that passes the [`CallContext`]
/// to the wrapped function as its first argument, so the implementation can call
/// other functions.
///
/// Like `wrap_fn`, it takes 2 args: the function arity (**not** counting the
/// `CallContext` parameter), followed by the function / closure itself.
///
/// [`CallContext`]: crate::CallContext
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{
/// #     wrap_fn_with_context, CallContext, Function, Environment, Value, Error, VariableMap,
/// # };
/// fn map_array<'a>(
///     context: &mut CallContext<'_, 'a, f32>,
///     array: Vec<Value<'a, f32>>,
///     map_fn: Function<'a, f32>,
/// ) -> Result<Vec<Value<'a, f32>>, Error<'a>> {
///     array
///         .into_iter()
///         .map(|value| {
///             let arg = context.apply_call_span(value);
///             map_fn.evaluate(vec![arg], context)
///         })
///         .collect()
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "(1, 2, 3).map(|x| x + 3) == (4, 5, 6)";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("map", wrap_fn_with_context!(2, map_array))
///     .compile_module("test_map", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[macro_export]
macro_rules! wrap_fn_with_context {
    (0, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 0 =>; $function)
    };
    (1, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 1 => x0; $function)
    };
    (2, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 2 => x0, x1; $function)
    };
    (3, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 3 => x0, x1, x2; $function)
    };
    (4, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 4 => x0, x1, x2, x3; $function)
    };
    (5, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 5 => x0, x1, x2, x3, x4; $function)
    };
    (6, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 6 => x0, x1, x2, x3, x4, x5; $function)
    };
    (7, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 7 => x0, x1, x2, x3, x4, x5, x6; $function)
    };
    (8, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 8 => x0, x1, x2, x3, x4, x5, x6, x7; $function)
    };
    (9, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 9 => x0, x1, x2, x3, x4, x5, x6, x7, x8; $function)
    };
    (10, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 10 => x0, x1, x2, x3, x4, x5, x6, x7, x8, x9; $function)
    };
}
362
363#[doc(hidden)] // necessary for `wrap_fn` macro
364pub fn enforce_closure_type<T, A, F>(function: F) -> F
365where
366    F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, A>) -> EvalResult<'a, T>,
367{
368    function
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{format, ToOwned},
        Environment, ExecutableModule, Prelude, Value, WildcardId,
    };

    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    /// Returns `(min, max)` of `values`, or `(INFINITY, NEG_INFINITY)` for an empty input.
    fn array_min_max(values: Vec<f32>) -> (f32, f32) {
        let initial = (f32::INFINITY, f32::NEG_INFINITY);
        values.into_iter().fold(initial, |(lowest, highest), value| {
            let lowest = if value < lowest { value } else { lowest };
            let highest = if value > highest { value } else { highest };
            (lowest, highest)
        })
    }

    /// Sums both components of every pair in `xs`, then every element of `ys.0`, then `ys.1`.
    fn overly_convoluted_fn(xs: Vec<(f32, f32)>, ys: (Vec<f32>, f32)) -> f32 {
        let (tail, extra) = ys;
        let pairs_total = xs.into_iter().map(|(a, b)| a + b).sum::<f32>();
        let tail_total = tail.into_iter().sum::<f32>();
        pairs_total + tail_total + extra
    }

    /// Adds two arrays element-wise; fails if their lengths differ.
    fn sum_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<f32>, String> {
        if xs.len() != ys.len() {
            return Err("Summed arrays must have the same size".to_owned());
        }
        Ok(xs.iter().zip(&ys).map(|(x, y)| x + y).collect())
    }

    #[test]
    fn functions_with_primitive_args() {
        let add_three = Unary::new(|x: f32| x + 3.0);
        let smaller = Binary::new(f32::min);
        let select = Ternary::new(|cond: f32, if_positive, if_not| {
            if cond > 0.0 {
                if_positive
            } else {
                if_not
            }
        });

        let program = r#"
            unary_fn(2) == 5 && binary_fn(1, -3) == -3 &&
                ternary_fn(1, 2, 3) == 2 && ternary_fn(-1, 2, 3) == 3
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("unary_fn", Value::native_fn(add_three))
            .with_import("binary_fn", Value::native_fn(smaller))
            .with_import("ternary_fn", Value::native_fn(select))
            .build();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    #[test]
    fn functions_with_composite_args() {
        let program = r#"
            (1, 5, -3, 2, 1).array_min_max() == (-3, 5) &&
                total_sum(((1, 2), (3, 4)), ((5, 6, 7), 8)) == 36
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("array_min_max", Value::wrapped_fn(array_min_max))
            .with_import("total_sum", Value::wrapped_fn(overly_convoluted_fn))
            .build();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    #[test]
    fn fallible_function() {
        let program = "(1, 2, 3).sum_arrays((4, 5, 6)) == (5, 7, 9)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("sum_arrays", Value::wrapped_fn(sum_arrays))
            .build();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    #[test]
    fn fallible_function_with_bogus_program() {
        let program = "(1, 2, 3).sum_arrays((4, 5))";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let err = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("sum_arrays", Value::wrapped_fn(sum_arrays))
            .build()
            .run()
            .unwrap_err();
        let message = err.source().kind().to_short_string();
        assert!(message.contains("Summed arrays must have the same size"));
    }

    #[test]
    fn function_with_bool_return_value() {
        let in_range = wrap(|(lo, hi): (f32, f32), x: f32| (lo..=hi).contains(&x));

        let program = "(-1, 2).contains(0) && !(1, 3).contains(0)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("contains", Value::native_fn(in_range))
            .build();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    #[test]
    fn function_with_void_return_value() {
        let mut env = Environment::new();
        env.insert_wrapped_fn("assert_eq", |expected: f32, actual: f32| {
            let close_enough = (expected - actual).abs() < f32::EPSILON;
            if close_enough {
                Ok(())
            } else {
                Err(format!(
                    "Assertion failed: expected {}, got {}",
                    expected, actual
                ))
            }
        });

        let program = "assert_eq(3, 1 + 2)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&env)
            .build();
        let output = module.run().unwrap();
        assert!(output.is_void());

        let bogus_program = "assert_eq(3, 1 - 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program).unwrap();
        let err = ExecutableModule::builder(WildcardId, &bogus_block)
            .unwrap()
            .with_imports_from(&env)
            .build()
            .run()
            .unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("Assertion failed")
        );
    }

    #[test]
    fn function_with_bool_argument() {
        let program = "flip_sign(-1, true) == 1 && flip_sign(-1, false) == -1";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let flip_sign = |value: f32, negate: bool| if negate { -value } else { value };

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&Prelude)
            .with_import("flip_sign", Value::wrapped_fn(flip_sign))
            .build();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    #[test]
    fn error_reporting_with_destructuring() {
        let program = "((true, 1), (2, 3)).destructure()";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let destructure = |values: Vec<(bool, f32)>| {
            values
                .into_iter()
                .map(|(enabled, value)| if enabled { value } else { 0.0 })
                .sum::<f32>()
        };

        let err = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&Prelude)
            .with_import("destructure", Value::wrapped_fn(destructure))
            .build()
            .run()
            .unwrap_err();

        let err_message = err.source().kind().to_short_string();
        assert!(err_message.contains("Cannot convert primitive value to bool"));
        assert!(err_message.contains("location: arg0[1].0"));
    }
}