arithmetic_eval/fns/
mod.rs

1//! Standard functions for the interpreter, and the tools to define new native functions.
2//!
3//! # Defining native functions
4//!
5//! There are several ways to define new native functions:
6//!
7//! - Implement [`NativeFn`] manually. This is the most versatile approach, but it can be overly
8//!   verbose.
9//! - Use [`FnWrapper`] or the [`wrap`] function. This allows specifying arguments / output
10//!   with custom types (such as `bool` or a [`Number`]), but does not work for non-`'static`
11//!   types.
12//! - Use [`wrap_fn`](crate::wrap_fn) or [`wrap_fn_with_context`](crate::wrap_fn_with_context)
//!   macros. These macros support the same eloquent interface as `wrap`, and, unlike it,
//!   do not impose a `'static` requirement on the arguments.
15//!   As a downside, debugging compile-time errors when using macros can be rather painful.
16//!
17//! ## Why multiple ways to do the same thing?
18//!
//! In an ideal world, `FnWrapper` would be used for all cases, since it does not involve
//! macro magic. Unfortunately, at the time of writing, stable Rust does not provide a means
//! to describe lifetime restrictions on args / return types of wrapped functions in the general
//! case (this requires [generic associated types][GAT]). As such, the (implicit) `'static`
//! requirement is a temporary measure, and macros fill the gaps in their usual clunky manner.
24//!
25//! [`Number`]: crate::Number
26//! [GAT]: https://github.com/rust-lang/rust/issues/44265
27
28use core::cmp::Ordering;
29
30use crate::{
31    alloc::Vec, error::AuxErrorInfo, CallContext, Error, ErrorKind, EvalResult, Function, NativeFn,
32    SpannedValue, Value,
33};
34
35mod array;
36mod assertions;
37mod flow;
38#[cfg(feature = "std")]
39mod std;
40mod wrapper;
41
42#[cfg(feature = "std")]
43pub use self::std::Dbg;
44pub use self::{
45    array::{Array, Filter, Fold, Len, Map, Merge, Push},
46    assertions::{Assert, AssertEq},
47    flow::{If, Loop, While},
48    wrapper::{
49        enforce_closure_type, wrap, Binary, ErrorOutput, FnWrapper, FromValueError,
50        FromValueErrorKind, FromValueErrorLocation, IntoEvalResult, Quaternary, Ternary,
51        TryFromValue, Unary,
52    },
53};
54
55fn extract_primitive<'a, T, A>(
56    ctx: &CallContext<'_, 'a, A>,
57    value: SpannedValue<'a, T>,
58    error_msg: &str,
59) -> Result<T, Error<'a>> {
60    match value.extra {
61        Value::Prim(value) => Ok(value),
62        _ => Err(ctx
63            .call_site_error(ErrorKind::native(error_msg))
64            .with_span(&value, AuxErrorInfo::InvalidArg)),
65    }
66}
67
68fn extract_array<'a, T, A>(
69    ctx: &CallContext<'_, 'a, A>,
70    value: SpannedValue<'a, T>,
71    error_msg: &str,
72) -> Result<Vec<Value<'a, T>>, Error<'a>> {
73    if let Value::Tuple(array) = value.extra {
74        Ok(array)
75    } else {
76        let err = ErrorKind::native(error_msg);
77        Err(ctx
78            .call_site_error(err)
79            .with_span(&value, AuxErrorInfo::InvalidArg))
80    }
81}
82
83fn extract_fn<'a, T, A>(
84    ctx: &CallContext<'_, 'a, A>,
85    value: SpannedValue<'a, T>,
86    error_msg: &str,
87) -> Result<Function<'a, T>, Error<'a>> {
88    if let Value::Function(function) = value.extra {
89        Ok(function)
90    } else {
91        let err = ErrorKind::native(error_msg);
92        Err(ctx
93            .call_site_error(err)
94            .with_span(&value, AuxErrorInfo::InvalidArg))
95    }
96}
97
/// Comparator functions on two primitive arguments. All functions use [`Arithmetic`] to determine
/// ordering between the args.
///
/// # Type
///
/// ```text
/// fn(Num, Num) -> Ordering // for `Compare::Raw` (void if the args are not comparable)
/// fn(Num, Num) -> Num // for `Compare::Min` and `Compare::Max`
/// ```
///
/// # Errors
///
/// All variants return an error if they are not called with 2 primitive arguments.
/// `Compare::Min` and `Compare::Max` additionally return an
/// [`ErrorKind::CannotCompare`](crate::ErrorKind::CannotCompare) error if the arguments
/// are not comparable (e.g., if one of the args is `NaN`).
///
/// [`Arithmetic`]: crate::arith::Arithmetic
///
/// # Examples
///
/// Using `min` function:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     // Finds a minimum number in an array.
///     extended_min = |...xs| xs.fold(INFINITY, min);
///     extended_min(2, -3, 7, 1, 3) == -3
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert("INFINITY", Value::Prim(f32::INFINITY))
///     .insert_native_fn("fold", fns::Fold)
///     .insert_native_fn("min", fns::Compare::Min)
///     .compile_module("test_min", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// Using `cmp` function with [`Comparisons`](crate::Comparisons).
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Comparisons, Environment, Value, VariableMap};
/// # use core::iter::FromIterator;
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     (1, -7, 0, 2).map(|x| cmp(x, 0)) == (GREATER, LESS, EQUAL, GREATER)
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::from_iter(Comparisons.iter())
///     .insert_native_fn("map", fns::Map)
///     .compile_module("test_cmp", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Compare {
    /// Returns an [`Ordering`] wrapped into an [`OpaqueRef`](crate::OpaqueRef),
    /// or [`Value::void()`] if the provided values are not comparable.
    Raw,
    /// Returns the minimum of the two values. If the values are equal, returns the first one.
    /// If the values are not comparable, returns a
    /// [`CannotCompare`](crate::ErrorKind::CannotCompare) error.
    Min,
    /// Returns the maximum of the two values. If the values are equal, returns the first one.
    /// If the values are not comparable, returns a
    /// [`CannotCompare`](crate::ErrorKind::CannotCompare) error.
    Max,
}
165
166impl Compare {
167    fn extract_primitives<'a, T>(
168        mut args: Vec<SpannedValue<'a, T>>,
169        ctx: &mut CallContext<'_, 'a, T>,
170    ) -> Result<(T, T), Error<'a>> {
171        ctx.check_args_count(&args, 2)?;
172        let y = args.pop().unwrap();
173        let x = args.pop().unwrap();
174        let x = extract_primitive(ctx, x, COMPARE_ERROR_MSG)?;
175        let y = extract_primitive(ctx, y, COMPARE_ERROR_MSG)?;
176        Ok((x, y))
177    }
178}
179
180const COMPARE_ERROR_MSG: &str = "Compare requires 2 primitive arguments";
181
182impl<T> NativeFn<T> for Compare {
183    fn evaluate<'a>(
184        &self,
185        args: Vec<SpannedValue<'a, T>>,
186        ctx: &mut CallContext<'_, 'a, T>,
187    ) -> EvalResult<'a, T> {
188        let (x, y) = Self::extract_primitives(args, ctx)?;
189        let maybe_ordering = ctx.arithmetic().partial_cmp(&x, &y);
190
191        if let Self::Raw = self {
192            Ok(maybe_ordering.map_or_else(Value::void, Value::opaque_ref))
193        } else {
194            let ordering =
195                maybe_ordering.ok_or_else(|| ctx.call_site_error(ErrorKind::CannotCompare))?;
196            let value = match (ordering, self) {
197                (Ordering::Equal, _)
198                | (Ordering::Less, Self::Min)
199                | (Ordering::Greater, Self::Max) => x,
200                _ => y,
201            };
202            Ok(Value::Prim(value))
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    // Tests parse small programs with `F32Grammar`, build executable modules importing
    // native functions from this module, and check the evaluation results / errors.
    use super::*;
    use crate::{Environment, ExecutableModule, WildcardId};

    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    // `if` with plain value branches: x = 1.0 satisfies `x < 2`, so `x + 5` (6.0) is returned.
    #[test]
    fn if_basic() {
        let block = r#"
            x = 1.0;
            if(x < 2, x + 5, 3 - x)
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("if", Value::native_fn(If))
            .build();
        assert_eq!(module.run().unwrap(), Value::Prim(6.0));
    }

    // `if` with closure branches: `if` returns one of the closures, and the trailing `()`
    // calls it. With x = 4.5, the second closure (`3 - x` = -1.5) is selected.
    #[test]
    fn if_with_closures() {
        let block = r#"
            x = 4.5;
            if(x < 2, || x + 5, || 3 - x)()
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("if", Value::native_fn(If))
            .build();
        assert_eq!(module.run().unwrap(), Value::Prim(-1.5));
    }

    // Comparison operators work without any imports. Comparing a number with a tuple fails
    // with `CannotCompare`, and the error's main span points at the tuple operand.
    #[test]
    fn cmp_sugar() {
        let program = "x = 1.0; x > 0 && x <= 3";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));

        let bogus_program = "x = 1.0; x > (1, 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program).unwrap();
        let bogus_module = ExecutableModule::builder(WildcardId, &bogus_block)
            .unwrap()
            .build();

        let err = bogus_module.run().unwrap_err();
        let err = err.source();
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
        assert_eq!(*err.main_span().code().fragment(), "(1, 2)");
    }

    // `loop(init, step)` is driven by a `(continue, next_state)` tuple returned from `step`;
    // presumably iteration stops once `continue` is false — TODO confirm against `Loop`.
    // NOTE(review): the in-program comment says "power of 2", but the expected results are
    // exponents, i.e. floor(log2(x)) (e.g. 1000 -> 9, since 2^9 = 512 <= 1000 < 1024).
    #[test]
    fn loop_basic() {
        let program = r#"
            // Finds the greatest power of 2 lesser or equal to the value.
            discrete_log2 = |x| {
                loop(0, |i| {
                    continue = 2^i <= x;
                    (continue, if(continue, i + 1, i - 1))
                })
            };

            (discrete_log2(1), discrete_log2(2),
                discrete_log2(4), discrete_log2(6.5), discrete_log2(1000))
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("loop", Value::native_fn(Loop))
            .with_import("if", Value::native_fn(If))
            .build();

        assert_eq!(
            module.run().unwrap(),
            Value::Tuple(vec![
                Value::Prim(0.0),
                Value::Prim(1.0),
                Value::Prim(2.0),
                Value::Prim(2.0),
                Value::Prim(9.0),
            ])
        );
    }

    // Variadic maximum built from `fold` + `if`, seeded with negative infinity.
    #[test]
    fn max_value_with_fold() {
        let program = r#"
            max_value = |...xs| {
                fold(xs, -Inf, |acc, x| if(x > acc, x, acc))
            };
            max_value(1, -2, 7, 2, 5) == 7 && max_value(3, -5, 9) == 9
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("Inf", Value::Prim(f32::INFINITY))
            .with_import("fold", Value::native_fn(Fold))
            .with_import("if", Value::native_fn(If))
            .build();

        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // Tuple reversal via `fold` + `merge`. The `reverse` closure defined by the first program
    // is pulled out of the environment and reused in a second module against `SAMPLES`
    // (including the empty and single-element edge cases).
    #[test]
    fn reverse_list_with_fold() {
        const SAMPLES: &[(&[f32], &[f32])] = &[
            (&[1.0, 2.0, 3.0], &[3.0, 2.0, 1.0]),
            (&[], &[]),
            (&[1.0], &[1.0]),
        ];

        let program = r#"
            reverse = |xs| {
                fold(xs, (), |acc, x| merge((x,), acc))
            };
            xs = (-4, 3, 0, 1);
            xs.reverse() == (1, 0, 3, -4)
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("merge", Value::native_fn(Merge))
            .with_import("fold", Value::native_fn(Fold))
            .build();

        // Running in an explicit environment keeps the program's variables accessible
        // afterwards, so `env["reverse"]` can be read below.
        let mut env = module.imports().into_iter().collect::<Environment<'_, _>>();
        assert_eq!(module.run_in_env(&mut env).unwrap(), Value::Bool(true));

        // `set_imports` fills the remaining imports (presumably only `xs` here) with void
        // so that `xs` can be overwritten for each sample via `set_import`.
        let test_block = Untyped::<F32Grammar>::parse_statements("xs.reverse()").unwrap();
        let mut test_module = ExecutableModule::builder("test", &test_block)
            .unwrap()
            .with_import("reverse", env["reverse"].clone())
            .set_imports(|_| Value::void());

        for &(input, expected) in SAMPLES {
            let input = input.iter().copied().map(Value::Prim).collect();
            let expected = expected.iter().copied().map(Value::Prim).collect();
            test_module.set_import("xs", Value::Tuple(input));
            assert_eq!(test_module.run().unwrap(), Value::Tuple(expected));
        }
    }

    // A non-primitive argument to `min` yields a native-call error with the
    // `COMPARE_ERROR_MSG` text, reported at the call site.
    #[test]
    fn error_with_min_function_args() {
        let program = "5 - min(1, (2, 3))";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("min", Value::native_fn(Compare::Min))
            .build();

        let err = module.run().unwrap_err();
        let err = err.source();
        assert_eq!(*err.main_span().code().fragment(), "min(1, (2, 3))");
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("requires 2 primitive arguments")
        );
    }

    // `min` on incomparable primitives (NaN) fails with `CannotCompare` at the call site,
    // rather than returning one of the arguments.
    #[test]
    fn error_with_min_function_incomparable_args() {
        let program = "5 - min(1, NAN)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("NAN", Value::Prim(f32::NAN))
            .with_import("min", Value::native_fn(Compare::Min))
            .build();

        let err = module.run().unwrap_err();
        let err = err.source();
        assert_eq!(*err.main_span().code().fragment(), "min(1, NAN)");
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
    }
}