arithmetic_eval/fns/
array.rs

1//! Functions on arrays.
2
3use core::cmp::Ordering;
4
5use num_traits::{FromPrimitive, One, Zero};
6
7use crate::{
8    alloc::{vec, Vec},
9    error::AuxErrorInfo,
10    fns::{extract_array, extract_fn, extract_primitive},
11    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
12};
13
14/// Function generating an array by mapping its indexes.
15///
16/// # Type
17///
18/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
19///
20/// ```text
21/// (Num, (Num) -> 'T) -> ['T]
22/// ```
23///
24/// # Examples
25///
26/// ```
27/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
28/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
29/// # fn main() -> anyhow::Result<()> {
30/// let program = r#"array(3, |i| 2 * i + 1) == (1, 3, 5)"#;
31/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
32///
33/// let module = Environment::new()
34///     .insert_native_fn("array", fns::Array)
35///     .compile_module("test_array", &program)?;
36/// assert_eq!(module.run()?, Value::Bool(true));
37/// # Ok(())
38/// # }
39/// ```
40#[derive(Debug, Clone, Copy, Default)]
41pub struct Array;
42
43impl<T> NativeFn<T> for Array
44where
45    T: Clone + Zero + One,
46{
47    fn evaluate<'a>(
48        &self,
49        mut args: Vec<SpannedValue<'a, T>>,
50        ctx: &mut CallContext<'_, 'a, T>,
51    ) -> EvalResult<'a, T> {
52        ctx.check_args_count(&args, 2)?;
53        let generation_fn = extract_fn(
54            ctx,
55            args.pop().unwrap(),
56            "`array` requires second arg to be a generation function",
57        )?;
58        let len = extract_primitive(
59            ctx,
60            args.pop().unwrap(),
61            "`array` requires first arg to be a number",
62        )?;
63
64        let mut index = T::zero();
65        let mut array = vec![];
66        loop {
67            let next_index = ctx
68                .arithmetic()
69                .add(index.clone(), T::one())
70                .map_err(|err| ctx.call_site_error(ErrorKind::Arithmetic(err)))?;
71
72            let cmp = ctx.arithmetic().partial_cmp(&next_index, &len);
73            if matches!(cmp, Some(Ordering::Less) | Some(Ordering::Equal)) {
74                let spanned = ctx.apply_call_span(Value::Prim(index));
75                array.push(generation_fn.evaluate(vec![spanned], ctx)?);
76                index = next_index;
77            } else {
78                break;
79            }
80        }
81        Ok(Value::Tuple(array))
82    }
83}
84
85/// Function returning array / object length.
86///
87/// # Type
88///
89/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
90///
91/// ```text
92/// ([T]) -> Num
93/// ```
94///
95/// # Examples
96///
97/// ```
98/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
99/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
100/// # fn main() -> anyhow::Result<()> {
101/// let program = r#"().len() == 0 && (1, 2, 3).len() == 3"#;
102/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
103///
104/// let module = Environment::new()
105///     .insert_native_fn("len", fns::Len)
106///     .compile_module("test_len", &program)?;
107/// assert_eq!(module.run()?, Value::Bool(true));
108/// # Ok(())
109/// # }
110/// ```
111#[derive(Debug, Clone, Copy, Default)]
112pub struct Len;
113
114impl<T: FromPrimitive> NativeFn<T> for Len {
115    fn evaluate<'a>(
116        &self,
117        mut args: Vec<SpannedValue<'a, T>>,
118        ctx: &mut CallContext<'_, 'a, T>,
119    ) -> EvalResult<'a, T> {
120        ctx.check_args_count(&args, 1)?;
121        let arg = args.pop().unwrap();
122
123        let len = match arg.extra {
124            Value::Tuple(array) => array.len(),
125            Value::Object(object) => object.len(),
126            _ => {
127                let err = ErrorKind::native("`len` requires object or tuple arg");
128                return Err(ctx
129                    .call_site_error(err)
130                    .with_span(&arg, AuxErrorInfo::InvalidArg));
131            }
132        };
133        let len = T::from_usize(len).ok_or_else(|| {
134            let err = ErrorKind::native("Cannot convert length to number");
135            ctx.call_site_error(err)
136        })?;
137        Ok(Value::Prim(len))
138    }
139}
140
141/// Map function that evaluates the provided function on each item of the tuple.
142///
143/// # Type
144///
145/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
146///
147/// ```text
148/// (['T; N], ('T) -> 'U) -> ['U; N]
149/// ```
150///
151/// # Examples
152///
153/// ```
154/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
155/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
156/// # fn main() -> anyhow::Result<()> {
157/// let program = r#"
158///     xs = (1, -2, 3, -0.3);
159///     map(xs, |x| if(x > 0, x, 0)) == (1, 0, 3, 0)
160/// "#;
161/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
162///
163/// let module = Environment::new()
164///     .insert_native_fn("if", fns::If)
165///     .insert_native_fn("map", fns::Map)
166///     .compile_module("test_map", &program)?;
167/// assert_eq!(module.run()?, Value::Bool(true));
168/// # Ok(())
169/// # }
170/// ```
171#[derive(Debug, Clone, Copy, Default)]
172pub struct Map;
173
174impl<T: Clone> NativeFn<T> for Map {
175    fn evaluate<'a>(
176        &self,
177        mut args: Vec<SpannedValue<'a, T>>,
178        ctx: &mut CallContext<'_, 'a, T>,
179    ) -> EvalResult<'a, T> {
180        ctx.check_args_count(&args, 2)?;
181        let map_fn = extract_fn(
182            ctx,
183            args.pop().unwrap(),
184            "`map` requires second arg to be a mapping function",
185        )?;
186        let array = extract_array(
187            ctx,
188            args.pop().unwrap(),
189            "`map` requires first arg to be a tuple",
190        )?;
191
192        let mapped: Result<Vec<_>, _> = array
193            .into_iter()
194            .map(|value| {
195                let spanned = ctx.apply_call_span(value);
196                map_fn.evaluate(vec![spanned], ctx)
197            })
198            .collect();
199        mapped.map(Value::Tuple)
200    }
201}
202
203/// Filter function that evaluates the provided function on each item of the tuple and retains
204/// only elements for which the function returned `true`.
205///
206/// # Type
207///
208/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
209///
210/// ```text
211/// (['T; N], ('T) -> Bool) -> ['T]
212/// ```
213///
214/// # Examples
215///
216/// ```
217/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
218/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
219/// # fn main() -> anyhow::Result<()> {
220/// let program = r#"
221///     xs = (1, -2, 3, -7, -0.3);
222///     filter(xs, |x| x > -1) == (1, 3, -0.3)
223/// "#;
224/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
225///
226/// let module = Environment::new()
227///     .insert_native_fn("filter", fns::Filter)
228///     .compile_module("test_filter", &program)?;
229/// assert_eq!(module.run()?, Value::Bool(true));
230/// # Ok(())
231/// # }
232/// ```
233#[derive(Debug, Clone, Copy, Default)]
234pub struct Filter;
235
236impl<T: Clone> NativeFn<T> for Filter {
237    fn evaluate<'a>(
238        &self,
239        mut args: Vec<SpannedValue<'a, T>>,
240        ctx: &mut CallContext<'_, 'a, T>,
241    ) -> EvalResult<'a, T> {
242        ctx.check_args_count(&args, 2)?;
243        let filter_fn = extract_fn(
244            ctx,
245            args.pop().unwrap(),
246            "`filter` requires second arg to be a filter function",
247        )?;
248        let array = extract_array(
249            ctx,
250            args.pop().unwrap(),
251            "`filter` requires first arg to be a tuple",
252        )?;
253
254        let mut filtered = vec![];
255        for value in array {
256            let spanned = ctx.apply_call_span(value.clone());
257            match filter_fn.evaluate(vec![spanned], ctx)? {
258                Value::Bool(true) => filtered.push(value),
259                Value::Bool(false) => { /* do nothing */ }
260                _ => {
261                    let err = ErrorKind::native(
262                        "`filter` requires filtering function to return booleans",
263                    );
264                    return Err(ctx.call_site_error(err));
265                }
266            }
267        }
268        Ok(Value::Tuple(filtered))
269    }
270}
271
272/// Reduce (aka fold) function that reduces the provided tuple to a single value.
273///
274/// # Type
275///
276/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
277///
278/// ```text
279/// (['T], 'Acc, ('Acc, 'T) -> 'Acc) -> 'Acc
280/// ```
281///
282/// # Examples
283///
284/// ```
285/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
286/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
287/// # fn main() -> anyhow::Result<()> {
288/// let program = r#"
289///     xs = (1, -2, 3, -7);
290///     fold(xs, 1, |acc, x| acc * x) == 42
291/// "#;
292/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
293///
294/// let module = Environment::new()
295///     .insert_native_fn("fold", fns::Fold)
296///     .compile_module("test_fold", &program)?;
297/// assert_eq!(module.run()?, Value::Bool(true));
298/// # Ok(())
299/// # }
300/// ```
301#[derive(Debug, Clone, Copy, Default)]
302pub struct Fold;
303
304impl<T: Clone> NativeFn<T> for Fold {
305    fn evaluate<'a>(
306        &self,
307        mut args: Vec<SpannedValue<'a, T>>,
308        ctx: &mut CallContext<'_, 'a, T>,
309    ) -> EvalResult<'a, T> {
310        ctx.check_args_count(&args, 3)?;
311        let fold_fn = extract_fn(
312            ctx,
313            args.pop().unwrap(),
314            "`fold` requires third arg to be a folding function",
315        )?;
316        let acc = args.pop().unwrap().extra;
317        let array = extract_array(
318            ctx,
319            args.pop().unwrap(),
320            "`fold` requires first arg to be a tuple",
321        )?;
322
323        array.into_iter().try_fold(acc, |acc, value| {
324            let spanned_args = vec![ctx.apply_call_span(acc), ctx.apply_call_span(value)];
325            fold_fn.evaluate(spanned_args, ctx)
326        })
327    }
328}
329
330/// Function that appends a value onto a tuple.
331///
332/// # Type
333///
334/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
335///
336/// ```text
337/// (['T; N], 'T) -> ['T; N + 1]
338/// ```
339///
340/// # Examples
341///
342/// ```
343/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
344/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
345/// # fn main() -> anyhow::Result<()> {
346/// let program = r#"
347///     repeat = |x, times| {
348///         (_, acc) = (0, ()).while(
349///             |(i, _)| i < times,
350///             |(i, acc)| (i + 1, push(acc, x)),
351///         );
352///         acc
353///     };
354///     repeat(-2, 3) == (-2, -2, -2) &&
355///         repeat((7,), 4) == ((7,), (7,), (7,), (7,))
356/// "#;
357/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
358///
359/// let module = Environment::new()
360///     .insert_native_fn("while", fns::While)
361///     .insert_native_fn("push", fns::Push)
362///     .compile_module("test_push", &program)?;
363/// assert_eq!(module.run()?, Value::Bool(true));
364/// # Ok(())
365/// # }
366/// ```
367#[derive(Debug, Clone, Copy, Default)]
368pub struct Push;
369
370impl<T> NativeFn<T> for Push {
371    fn evaluate<'a>(
372        &self,
373        mut args: Vec<SpannedValue<'a, T>>,
374        ctx: &mut CallContext<'_, 'a, T>,
375    ) -> EvalResult<'a, T> {
376        ctx.check_args_count(&args, 2)?;
377        let elem = args.pop().unwrap().extra;
378        let mut array = extract_array(
379            ctx,
380            args.pop().unwrap(),
381            "`fold` requires first arg to be a tuple",
382        )?;
383
384        array.push(elem);
385        Ok(Value::Tuple(array))
386    }
387}
388
389/// Function that merges two tuples.
390///
391/// # Type
392///
393/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
394///
395/// ```text
396/// (['T], ['T]) -> ['T]
397/// ```
398///
399/// # Examples
400///
401/// ```
402/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
403/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
404/// # fn main() -> anyhow::Result<()> {
405/// let program = r#"
406///     // Merges all arguments (which should be tuples) into a single tuple.
407///     super_merge = |...xs| fold(xs, (), merge);
408///     super_merge((1, 2), (3,), (), (4, 5, 6)) == (1, 2, 3, 4, 5, 6)
409/// "#;
410/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
411///
412/// let module = Environment::new()
413///     .insert_native_fn("fold", fns::Fold)
414///     .insert_native_fn("merge", fns::Merge)
415///     .compile_module("test_merge", &program)?;
416/// assert_eq!(module.run()?, Value::Bool(true));
417/// # Ok(())
418/// # }
419/// ```
420#[derive(Debug, Clone, Copy, Default)]
421pub struct Merge;
422
423impl<T: Clone> NativeFn<T> for Merge {
424    fn evaluate<'a>(
425        &self,
426        mut args: Vec<SpannedValue<'a, T>>,
427        ctx: &mut CallContext<'_, 'a, T>,
428    ) -> EvalResult<'a, T> {
429        ctx.check_args_count(&args, 2)?;
430        let second = extract_array(
431            ctx,
432            args.pop().unwrap(),
433            "`merge` requires second arg to be a tuple",
434        )?;
435        let mut first = extract_array(
436            ctx,
437            args.pop().unwrap(),
438            "`merge` requires first arg to be a tuple",
439        )?;
440
441        first.extend_from_slice(&second);
442        Ok(Value::Tuple(first))
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        arith::{OrdArithmetic, StdArithmetic, WrappingArithmetic},
        Environment, VariableMap,
    };

    use arithmetic_parser::grammars::{NumGrammar, NumLiteral, Parse, Untyped};
    use assert_matches::assert_matches;

    /// Checks `len` on tuples and objects of several sizes, evaluated with `arithmetic`.
    fn test_len_function<T: NumLiteral>(arithmetic: &dyn OrdArithmetic<T>)
    where
        Len: NativeFn<T>,
    {
        let program = r#"
            (1, 2, 3).len() == 3 && ().len() == 0 &&
            #{}.len() == 0 && #{ x: 1 }.len() == 1 && #{ x: 1, y: 2 }.len() == 2
        "#;
        let statements = Untyped::<NumGrammar<T>>::parse_statements(program).unwrap();
        let mut environment = Environment::new();
        let compiled = environment
            .insert("len", Value::native_fn(Len))
            .compile_module("len", &statements)
            .unwrap();

        let result = compiled.with_arithmetic(arithmetic).run().unwrap();
        assert_matches!(result, Value::Bool(true));
    }

    #[test]
    fn len_function_in_floating_point_arithmetic() {
        test_len_function::<f32>(&StdArithmetic);
        test_len_function::<f64>(&StdArithmetic);
    }

    #[test]
    fn len_function_in_int_arithmetic() {
        test_len_function::<u8>(&WrappingArithmetic);
        test_len_function::<i8>(&WrappingArithmetic);
        test_len_function::<u64>(&WrappingArithmetic);
        test_len_function::<i64>(&WrappingArithmetic);
    }

    /// A tuple with 128 elements has a length that is not representable as `i8`.
    #[test]
    fn len_function_with_number_overflow() {
        let program = "xs.len()";
        let statements = Untyped::<NumGrammar<i8>>::parse_statements(program).unwrap();
        let mut environment = Environment::new();
        let compiled = environment
            .insert("xs", Value::Tuple(vec![Value::Bool(true); 128]))
            .insert("len", Value::native_fn(Len))
            .compile_module("len", &statements)
            .unwrap();

        let error = compiled
            .with_arithmetic(&WrappingArithmetic)
            .run()
            .unwrap_err();
        assert_matches!(
            error.source().kind(),
            ErrorKind::NativeCall(msg) if msg.contains("length to number")
        );
    }

    #[test]
    fn array_function_in_floating_point_arithmetic() {
        let program = r#"
            array(0, |_| 1) == () && array(-1, |_| 1) == () &&
            array(0.1, |_| 1) == () && array(0.999, |_| 1) == () &&
            array(1, |_| 1) == (1,) && array(1.5, |_| 1) == (1,) &&
            array(2, |_| 1) == (1, 1) && array(3, |i| i) == (0, 1, 2)
        "#;
        let statements = Untyped::<NumGrammar<f32>>::parse_statements(program).unwrap();
        let mut environment = Environment::new();
        let compiled = environment
            .insert("array", Value::native_fn(Array))
            .compile_module("array", &statements)
            .unwrap();

        let result = compiled.with_arithmetic(&StdArithmetic).run().unwrap();
        assert_matches!(result, Value::Bool(true));
    }

    #[test]
    fn array_function_in_unsigned_int_arithmetic() {
        let program = r#"
            array(0, |_| 1) == () && array(1, |_| 1) == (1,) && array(3, |i| i) == (0, 1, 2)
        "#;
        let statements = Untyped::<NumGrammar<u32>>::parse_statements(program).unwrap();
        let mut environment = Environment::new();
        let compiled = environment
            .insert("array", Value::native_fn(Array))
            .compile_module("array", &statements)
            .unwrap();

        let result = compiled.with_arithmetic(&WrappingArithmetic).run().unwrap();
        assert_matches!(result, Value::Bool(true));
    }
}