1use core::cmp::Ordering;
29
30use crate::{
31 alloc::Vec, error::AuxErrorInfo, CallContext, Error, ErrorKind, EvalResult, Function, NativeFn,
32 SpannedValue, Value,
33};
34
35mod array;
36mod assertions;
37mod flow;
38#[cfg(feature = "std")]
39mod std;
40mod wrapper;
41
42#[cfg(feature = "std")]
43pub use self::std::Dbg;
44pub use self::{
45 array::{Array, Filter, Fold, Len, Map, Merge, Push},
46 assertions::{Assert, AssertEq},
47 flow::{If, Loop, While},
48 wrapper::{
49 enforce_closure_type, wrap, Binary, ErrorOutput, FnWrapper, FromValueError,
50 FromValueErrorKind, FromValueErrorLocation, IntoEvalResult, Quaternary, Ternary,
51 TryFromValue, Unary,
52 },
53};
54
55fn extract_primitive<'a, T, A>(
56 ctx: &CallContext<'_, 'a, A>,
57 value: SpannedValue<'a, T>,
58 error_msg: &str,
59) -> Result<T, Error<'a>> {
60 match value.extra {
61 Value::Prim(value) => Ok(value),
62 _ => Err(ctx
63 .call_site_error(ErrorKind::native(error_msg))
64 .with_span(&value, AuxErrorInfo::InvalidArg)),
65 }
66}
67
68fn extract_array<'a, T, A>(
69 ctx: &CallContext<'_, 'a, A>,
70 value: SpannedValue<'a, T>,
71 error_msg: &str,
72) -> Result<Vec<Value<'a, T>>, Error<'a>> {
73 if let Value::Tuple(array) = value.extra {
74 Ok(array)
75 } else {
76 let err = ErrorKind::native(error_msg);
77 Err(ctx
78 .call_site_error(err)
79 .with_span(&value, AuxErrorInfo::InvalidArg))
80 }
81}
82
83fn extract_fn<'a, T, A>(
84 ctx: &CallContext<'_, 'a, A>,
85 value: SpannedValue<'a, T>,
86 error_msg: &str,
87) -> Result<Function<'a, T>, Error<'a>> {
88 if let Value::Function(function) = value.extra {
89 Ok(function)
90 } else {
91 let err = ErrorKind::native(error_msg);
92 Err(ctx
93 .call_site_error(err)
94 .with_span(&value, AuxErrorInfo::InvalidArg))
95 }
96}
97
/// Native function comparing two primitive values using the arithmetic of the calling context.
///
/// Every variant expects two primitive arguments; non-primitive arguments lead to
/// a native call error ("Compare requires 2 primitive arguments").
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Compare {
    /// Returns the [`Ordering`] of the first argument relative to the second one, wrapped
    /// into an opaque reference, or `void` if the arguments are not comparable.
    Raw,
    /// Returns the lesser of the two arguments; if they compare equal, returns the first one.
    /// Fails with a `CannotCompare` error if the arguments are not comparable.
    Min,
    /// Returns the greater of the two arguments; if they compare equal, returns the first one.
    /// Fails with a `CannotCompare` error if the arguments are not comparable.
    Max,
}
165
166impl Compare {
167 fn extract_primitives<'a, T>(
168 mut args: Vec<SpannedValue<'a, T>>,
169 ctx: &mut CallContext<'_, 'a, T>,
170 ) -> Result<(T, T), Error<'a>> {
171 ctx.check_args_count(&args, 2)?;
172 let y = args.pop().unwrap();
173 let x = args.pop().unwrap();
174 let x = extract_primitive(ctx, x, COMPARE_ERROR_MSG)?;
175 let y = extract_primitive(ctx, y, COMPARE_ERROR_MSG)?;
176 Ok((x, y))
177 }
178}
179
180const COMPARE_ERROR_MSG: &str = "Compare requires 2 primitive arguments";
181
182impl<T> NativeFn<T> for Compare {
183 fn evaluate<'a>(
184 &self,
185 args: Vec<SpannedValue<'a, T>>,
186 ctx: &mut CallContext<'_, 'a, T>,
187 ) -> EvalResult<'a, T> {
188 let (x, y) = Self::extract_primitives(args, ctx)?;
189 let maybe_ordering = ctx.arithmetic().partial_cmp(&x, &y);
190
191 if let Self::Raw = self {
192 Ok(maybe_ordering.map_or_else(Value::void, Value::opaque_ref))
193 } else {
194 let ordering =
195 maybe_ordering.ok_or_else(|| ctx.call_site_error(ErrorKind::CannotCompare))?;
196 let value = match (ordering, self) {
197 (Ordering::Equal, _)
198 | (Ordering::Less, Self::Min)
199 | (Ordering::Greater, Self::Max) => x,
200 _ => y,
201 };
202 Ok(Value::Prim(value))
203 }
204 }
205}
206
// Tests for the native functions of this module, exercised by running scripts parsed
// with `Untyped::<F32Grammar>` (so primitive values are `f32`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Environment, ExecutableModule, WildcardId};

    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    // `if` returns its second argument when the condition holds: 1 < 2, so 1 + 5 = 6.
    #[test]
    fn if_basic() {
        let block = r#"
            x = 1.0;
            if(x < 2, x + 5, 3 - x)
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("if", Value::native_fn(If))
            .build();
        assert_eq!(module.run().unwrap(), Value::Prim(6.0));
    }

    // `if` can choose between closures, and the chosen closure is then called:
    // 4.5 is not < 2, so `|| 3 - x` runs and yields -1.5.
    #[test]
    fn if_with_closures() {
        let block = r#"
            x = 4.5;
            if(x < 2, || x + 5, || 3 - x)()
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("if", Value::native_fn(If))
            .build();
        assert_eq!(module.run().unwrap(), Value::Prim(-1.5));
    }

    // Comparison operators evaluate without any imports. Comparing a primitive with
    // a tuple fails with `CannotCompare`, and the main span points at the tuple operand.
    #[test]
    fn cmp_sugar() {
        let program = "x = 1.0; x > 0 && x <= 3";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));

        let bogus_program = "x = 1.0; x > (1, 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program).unwrap();
        let bogus_module = ExecutableModule::builder(WildcardId, &bogus_block)
            .unwrap()
            .build();

        let err = bogus_module.run().unwrap_err();
        let err = err.source();
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
        assert_eq!(*err.main_span().code().fragment(), "(1, 2)");
    }

    // `loop` with a `(continue, next_state)`-returning closure. As exercised here, the loop
    // stops once `continue` is false and yields the last state, so `discrete_log2(x)`
    // evaluates to floor(log2(x)) for the sampled inputs.
    #[test]
    fn loop_basic() {
        let program = r#"
            // Finds the greatest power of 2 lesser or equal to the value.
            discrete_log2 = |x| {
                loop(0, |i| {
                    continue = 2^i <= x;
                    (continue, if(continue, i + 1, i - 1))
                })
            };

            (discrete_log2(1), discrete_log2(2),
                discrete_log2(4), discrete_log2(6.5), discrete_log2(1000))
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("loop", Value::native_fn(Loop))
            .with_import("if", Value::native_fn(If))
            .build();

        assert_eq!(
            module.run().unwrap(),
            Value::Tuple(vec![
                Value::Prim(0.0),
                Value::Prim(1.0),
                Value::Prim(2.0),
                Value::Prim(2.0),
                Value::Prim(9.0),
            ])
        );
    }

    // `fold` over variadic arguments, starting from -Inf, computes the maximum value.
    #[test]
    fn max_value_with_fold() {
        let program = r#"
            max_value = |...xs| {
                fold(xs, -Inf, |acc, x| if(x > acc, x, acc))
            };
            max_value(1, -2, 7, 2, 5) == 7 && max_value(3, -5, 9) == 9
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("Inf", Value::Prim(f32::INFINITY))
            .with_import("fold", Value::native_fn(Fold))
            .with_import("if", Value::native_fn(If))
            .build();

        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // `reverse` is built from `fold` + `merge`. The closure defined by the first module is
    // then reused in a second module, with empty and single-element tuples among the inputs.
    #[test]
    fn reverse_list_with_fold() {
        // Pairs of (input tuple, expected reversed tuple).
        const SAMPLES: &[(&[f32], &[f32])] = &[
            (&[1.0, 2.0, 3.0], &[3.0, 2.0, 1.0]),
            (&[], &[]),
            (&[1.0], &[1.0]),
        ];

        let program = r#"
            reverse = |xs| {
                fold(xs, (), |acc, x| merge((x,), acc))
            };
            xs = (-4, 3, 0, 1);
            xs.reverse() == (1, 0, 3, -4)
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("merge", Value::native_fn(Merge))
            .with_import("fold", Value::native_fn(Fold))
            .build();

        // Running in an explicit environment keeps the script's variables (incl. `reverse`).
        let mut env = module.imports().into_iter().collect::<Environment<'_, _>>();
        assert_eq!(module.run_in_env(&mut env).unwrap(), Value::Bool(true));

        let test_block = Untyped::<F32Grammar>::parse_statements("xs.reverse()").unwrap();
        // NOTE(review): `set_imports(|_| Value::void())` presumably fills the remaining
        // `xs` import with a placeholder that is overwritten for each sample below.
        let mut test_module = ExecutableModule::builder("test", &test_block)
            .unwrap()
            .with_import("reverse", env["reverse"].clone())
            .set_imports(|_| Value::void());

        for &(input, expected) in SAMPLES {
            let input = input.iter().copied().map(Value::Prim).collect();
            let expected = expected.iter().copied().map(Value::Prim).collect();
            test_module.set_import("xs", Value::Tuple(input));
            assert_eq!(test_module.run().unwrap(), Value::Tuple(expected));
        }
    }

    // Passing a tuple to `min` yields a `NativeCall` error whose main span covers the call.
    #[test]
    fn error_with_min_function_args() {
        let program = "5 - min(1, (2, 3))";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("min", Value::native_fn(Compare::Min))
            .build();

        let err = module.run().unwrap_err();
        let err = err.source();
        assert_eq!(*err.main_span().code().fragment(), "min(1, (2, 3))");
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("requires 2 primitive arguments")
        );
    }

    // NaN is incomparable, so `min` fails with `CannotCompare` reported at the call site.
    #[test]
    fn error_with_min_function_incomparable_args() {
        let program = "5 - min(1, NAN)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("NAN", Value::Prim(f32::NAN))
            .with_import("min", Value::native_fn(Compare::Min))
            .build();

        let err = module.run().unwrap_err();
        let err = err.source();
        assert_eq!(*err.main_span().code().fragment(), "min(1, NAN)");
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
    }
}