arithmetic_eval/fns/
flow.rs

1//! Flow control functions.
2
3use crate::{
4    alloc::{vec, Vec},
5    error::AuxErrorInfo,
6    fns::extract_fn,
7    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
8};
9
/// `if` function that eagerly evaluates "if" / "else" terms.
///
/// The function takes 3 arguments: a boolean condition, the value to return if the condition
/// is `true`, and the value to return if it is `false`. Both values are passed in as ordinary
/// arguments; the function only selects which one of them to return.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// (Bool, 'T, 'T) -> 'T
/// ```
///
/// # Errors
///
/// Evaluation fails if the argument count check for 3 arguments does not pass,
/// or if the first argument is not a boolean value.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = "x = 3; if(x == 2, -1, x + 1)";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("if", fns::If)
///     .compile_module("if_test", &program)?;
/// assert_eq!(module.run()?, Value::Prim(4.0));
/// # Ok(())
/// # }
/// ```
///
/// Lazy evaluation can be emulated by passing functions as the "if" / "else" terms
/// and calling the function returned by `if` afterwards:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = "x = 3; if(x == 2, || -1, || x + 1)()";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("if", fns::If)
///     .compile_module("if_test", &program)?;
/// assert_eq!(module.run()?, Value::Prim(4.0));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct If;
56
57impl<T> NativeFn<T> for If {
58    fn evaluate<'a>(
59        &self,
60        mut args: Vec<SpannedValue<'a, T>>,
61        ctx: &mut CallContext<'_, 'a, T>,
62    ) -> EvalResult<'a, T> {
63        ctx.check_args_count(&args, 3)?;
64        let else_val = args.pop().unwrap().extra;
65        let then_val = args.pop().unwrap().extra;
66
67        if let Value::Bool(condition) = &args[0].extra {
68            Ok(if *condition { then_val } else { else_val })
69        } else {
70            let err = ErrorKind::native("`if` requires first arg to be boolean");
71            Err(ctx
72                .call_site_error(err)
73                .with_span(&args[0], AuxErrorInfo::InvalidArg))
74        }
75    }
76}
77
/// Loop function that evaluates the provided closure one or more times.
///
/// The function takes 2 arguments: the initial loop state and an iteration function.
/// The iteration function is called with the current state and must return a 2-element tuple
/// `(flag, value)`:
///
/// - If `flag` is `true`, `value` is used as the state for the next iteration.
/// - If `flag` is `false`, the loop stops and `value` is returned as the result.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation with custom
/// notation for union types; they are not supported in the typing crate)
///
/// ```text
/// ('T, ('T) -> (false, 'R) | (true, 'T)) -> 'R
/// ```
///
/// # Errors
///
/// Evaluation fails if:
///
/// - The argument count check for 2 arguments does not pass.
/// - The second argument is not a function.
/// - The iteration function returns an error (it is propagated as is).
/// - The iteration function returns anything other than a 2-element tuple with a boolean
///   first element.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     factorial = |x| {
///         loop((x, 1), |(i, acc)| {
///             continue = i >= 1;
///             (continue, if(continue, (i - 1, acc * i), acc))
///         })
///     };
///     factorial(5) == 120 && factorial(10) == 3628800
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("if", fns::If)
///     .insert_native_fn("loop", fns::Loop)
///     .compile_module("test_loop", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Loop;
116
impl Loop {
    /// Error message reported when the value returned by the iteration function is not
    /// a 2-element tuple with a boolean first element.
    const ITER_ERROR: &'static str =
        "iteration function should return a 2-element tuple with first bool value";
}
121
122impl<T: Clone> NativeFn<T> for Loop {
123    fn evaluate<'a>(
124        &self,
125        mut args: Vec<SpannedValue<'a, T>>,
126        ctx: &mut CallContext<'_, 'a, T>,
127    ) -> EvalResult<'a, T> {
128        ctx.check_args_count(&args, 2)?;
129        let iter = args.pop().unwrap();
130        let iter = if let Value::Function(iter) = iter.extra {
131            iter
132        } else {
133            let err = ErrorKind::native("Second argument of `loop` should be an iterator function");
134            return Err(ctx
135                .call_site_error(err)
136                .with_span(&iter, AuxErrorInfo::InvalidArg));
137        };
138
139        let mut arg = args.pop().unwrap();
140        loop {
141            if let Value::Tuple(mut tuple) = iter.evaluate(vec![arg], ctx)? {
142                let (ret_or_next_arg, flag) = if tuple.len() == 2 {
143                    (tuple.pop().unwrap(), tuple.pop().unwrap())
144                } else {
145                    let err = ErrorKind::native(Self::ITER_ERROR);
146                    break Err(ctx.call_site_error(err));
147                };
148
149                match (flag, ret_or_next_arg) {
150                    (Value::Bool(false), ret) => break Ok(ret),
151                    (Value::Bool(true), next_arg) => {
152                        arg = ctx.apply_call_span(next_arg);
153                    }
154                    _ => {
155                        let err = ErrorKind::native(Self::ITER_ERROR);
156                        break Err(ctx.call_site_error(err));
157                    }
158                }
159            } else {
160                let err = ErrorKind::native(Self::ITER_ERROR);
161                break Err(ctx.call_site_error(err));
162            }
163        }
164    }
165}
166
/// Loop function that evaluates the provided closure while a certain condition is true.
/// Returns the loop state afterwards.
///
/// The function takes 3 arguments: the initial loop state, a condition function and
/// a step function. Before each step, the condition function is called with a copy
/// of the current state. While it returns `true`, the step function is called with
/// the state, and its return value becomes the new state. Once the condition function
/// returns `false`, the current state is returned.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// ('T, ('T) -> Bool, ('T) -> 'T) -> 'T
/// ```
///
/// # Errors
///
/// Evaluation fails if:
///
/// - The argument count check for 3 arguments does not pass.
/// - The second or third argument cannot be extracted as a function.
/// - The condition or step function returns an error (it is propagated as is).
/// - The condition function returns a non-boolean value.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     factorial = |x| {
///         (_, acc) = (x, 1).while(
///             |(i, _)| i >= 1,
///             |(i, acc)| (i - 1, acc * i),
///         );
///         acc
///     };
///     factorial(5) == 120 && factorial(10) == 3628800
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("while", fns::While)
///     .compile_module("test_while", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct While;
205
206impl<T: Clone> NativeFn<T> for While {
207    fn evaluate<'a>(
208        &self,
209        mut args: Vec<SpannedValue<'a, T>>,
210        ctx: &mut CallContext<'_, 'a, T>,
211    ) -> EvalResult<'a, T> {
212        ctx.check_args_count(&args, 3)?;
213
214        let step_fn = extract_fn(
215            ctx,
216            args.pop().unwrap(),
217            "`while` requires third arg to be a step function",
218        )?;
219        let condition_fn = extract_fn(
220            ctx,
221            args.pop().unwrap(),
222            "`while` requires second arg to be a condition function",
223        )?;
224        let mut state = args.pop().unwrap();
225        let state_span = state.copy_with_extra(());
226
227        loop {
228            let condition_value = condition_fn.evaluate(vec![state.clone()], ctx)?;
229            match condition_value {
230                Value::Bool(true) => {
231                    let new_state = step_fn.evaluate(vec![state], ctx)?;
232                    state = state_span.copy_with_extra(new_state);
233                }
234                Value::Bool(false) => break Ok(state.extra),
235                _ => {
236                    let err =
237                        ErrorKind::native("`while` requires condition function to return booleans");
238                    return Err(ctx.call_site_error(err));
239                }
240            }
241        }
242    }
243}