arithmetic_eval/fns/flow.rs
1//! Flow control functions.
2
3use crate::{
4 alloc::{vec, Vec},
5 error::AuxErrorInfo,
6 fns::extract_fn,
7 CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
8};
9
/// Conditional function that picks one of two already-evaluated values.
///
/// All three arguments are evaluated before `if` is invoked, so both the "then" and the
/// "else" terms are computed regardless of the condition (i.e., evaluation is eager).
/// The function simply returns the second argument if the condition is `true`, and the
/// third one otherwise.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// (Bool, 'T, 'T) -> 'T
/// ```
///
/// # Errors
///
/// - The function must be called with exactly 3 arguments.
/// - The first argument must be a boolean.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = "x = 3; if(x == 2, -1, x + 1)";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("if", fns::If)
///     .compile_module("if_test", &program)?;
/// assert_eq!(module.run()?, Value::Prim(4.0));
/// # Ok(())
/// # }
/// ```
///
/// Lazy evaluation can be emulated by passing closures as the branches and calling
/// the selected one afterwards:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = "x = 3; if(x == 2, || -1, || x + 1)()";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("if", fns::If)
///     .compile_module("if_test", &program)?;
/// assert_eq!(module.run()?, Value::Prim(4.0));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct If;
56
57impl<T> NativeFn<T> for If {
58 fn evaluate<'a>(
59 &self,
60 mut args: Vec<SpannedValue<'a, T>>,
61 ctx: &mut CallContext<'_, 'a, T>,
62 ) -> EvalResult<'a, T> {
63 ctx.check_args_count(&args, 3)?;
64 let else_val = args.pop().unwrap().extra;
65 let then_val = args.pop().unwrap().extra;
66
67 if let Value::Bool(condition) = &args[0].extra {
68 Ok(if *condition { then_val } else { else_val })
69 } else {
70 let err = ErrorKind::native("`if` requires first arg to be boolean");
71 Err(ctx
72 .call_site_error(err)
73 .with_span(&args[0], AuxErrorInfo::InvalidArg))
74 }
75 }
76}
77
/// Looping function that repeatedly calls an iteration closure, at least once.
///
/// The first argument is the initial state; the second one is the iteration function.
/// On each step, the iteration function receives the current state and must return
/// a 2-element tuple `(flag, value)`:
///
/// - if `flag` is `false`, the loop terminates and `value` becomes the result;
/// - if `flag` is `true`, `value` becomes the state for the next iteration.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation with custom
/// notation for union types; they are not supported in the typing crate)
///
/// ```text
/// ('T, ('T) -> (false, 'R) | (true, 'T)) -> 'R
/// ```
///
/// # Errors
///
/// - The function must be called with exactly 2 arguments.
/// - The second argument must be a function.
/// - The iteration function must return a 2-element tuple whose first element is a boolean.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     factorial = |x| {
///         loop((x, 1), |(i, acc)| {
///             continue = i >= 1;
///             (continue, if(continue, (i - 1, acc * i), acc))
///         })
///     };
///     factorial(5) == 120 && factorial(10) == 3628800
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("if", fns::If)
///     .insert_native_fn("loop", fns::Loop)
///     .compile_module("test_loop", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct Loop;

impl Loop {
    /// Message reported when the iteration function returns anything other than
    /// a 2-element tuple starting with a boolean.
    const ITER_ERROR: &'static str = "iteration function should return a 2-element tuple with first bool value";
}
121
122impl<T: Clone> NativeFn<T> for Loop {
123 fn evaluate<'a>(
124 &self,
125 mut args: Vec<SpannedValue<'a, T>>,
126 ctx: &mut CallContext<'_, 'a, T>,
127 ) -> EvalResult<'a, T> {
128 ctx.check_args_count(&args, 2)?;
129 let iter = args.pop().unwrap();
130 let iter = if let Value::Function(iter) = iter.extra {
131 iter
132 } else {
133 let err = ErrorKind::native("Second argument of `loop` should be an iterator function");
134 return Err(ctx
135 .call_site_error(err)
136 .with_span(&iter, AuxErrorInfo::InvalidArg));
137 };
138
139 let mut arg = args.pop().unwrap();
140 loop {
141 if let Value::Tuple(mut tuple) = iter.evaluate(vec![arg], ctx)? {
142 let (ret_or_next_arg, flag) = if tuple.len() == 2 {
143 (tuple.pop().unwrap(), tuple.pop().unwrap())
144 } else {
145 let err = ErrorKind::native(Self::ITER_ERROR);
146 break Err(ctx.call_site_error(err));
147 };
148
149 match (flag, ret_or_next_arg) {
150 (Value::Bool(false), ret) => break Ok(ret),
151 (Value::Bool(true), next_arg) => {
152 arg = ctx.apply_call_span(next_arg);
153 }
154 _ => {
155 let err = ErrorKind::native(Self::ITER_ERROR);
156 break Err(ctx.call_site_error(err));
157 }
158 }
159 } else {
160 let err = ErrorKind::native(Self::ITER_ERROR);
161 break Err(ctx.call_site_error(err));
162 }
163 }
164 }
165}
166
/// Looping function that keeps applying a step closure while a condition holds.
///
/// Takes three arguments: the initial state, a condition function and a step function.
/// Before each step the condition is called with the current state; while it returns
/// `true`, the step function is called to produce the next state. Once the condition
/// returns `false`, the current state is returned.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// ('T, ('T) -> Bool, ('T) -> 'T) -> 'T
/// ```
///
/// # Errors
///
/// - The function must be called with exactly 3 arguments.
/// - The second and third arguments must be functions.
/// - The condition function must return booleans.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, Value, VariableMap};
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     factorial = |x| {
///         (_, acc) = (x, 1).while(
///             |(i, _)| i >= 1,
///             |(i, acc)| (i - 1, acc * i),
///         );
///         acc
///     };
///     factorial(5) == 120 && factorial(10) == 3628800
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
///
/// let module = Environment::new()
///     .insert_native_fn("while", fns::While)
///     .compile_module("test_while", &program)?;
/// assert_eq!(module.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct While;
205
206impl<T: Clone> NativeFn<T> for While {
207 fn evaluate<'a>(
208 &self,
209 mut args: Vec<SpannedValue<'a, T>>,
210 ctx: &mut CallContext<'_, 'a, T>,
211 ) -> EvalResult<'a, T> {
212 ctx.check_args_count(&args, 3)?;
213
214 let step_fn = extract_fn(
215 ctx,
216 args.pop().unwrap(),
217 "`while` requires third arg to be a step function",
218 )?;
219 let condition_fn = extract_fn(
220 ctx,
221 args.pop().unwrap(),
222 "`while` requires second arg to be a condition function",
223 )?;
224 let mut state = args.pop().unwrap();
225 let state_span = state.copy_with_extra(());
226
227 loop {
228 let condition_value = condition_fn.evaluate(vec![state.clone()], ctx)?;
229 match condition_value {
230 Value::Bool(true) => {
231 let new_state = step_fn.evaluate(vec![state], ctx)?;
232 state = state_span.copy_with_extra(new_state);
233 }
234 Value::Bool(false) => break Ok(state.extra),
235 _ => {
236 let err =
237 ErrorKind::native("`while` requires condition function to return booleans");
238 return Err(ctx.call_site_error(err));
239 }
240 }
241 }
242 }
243}