1use core::cmp::Ordering;
4
5use num_traits::{FromPrimitive, One, Zero};
6
7use crate::{
8 alloc::{vec, Vec},
9 error::AuxErrorInfo,
10 fns::{extract_array, extract_fn, extract_primitive},
11 CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
12};
13
14#[derive(Debug, Clone, Copy, Default)]
41pub struct Array;
42
43impl<T> NativeFn<T> for Array
44where
45 T: Clone + Zero + One,
46{
47 fn evaluate<'a>(
48 &self,
49 mut args: Vec<SpannedValue<'a, T>>,
50 ctx: &mut CallContext<'_, 'a, T>,
51 ) -> EvalResult<'a, T> {
52 ctx.check_args_count(&args, 2)?;
53 let generation_fn = extract_fn(
54 ctx,
55 args.pop().unwrap(),
56 "`array` requires second arg to be a generation function",
57 )?;
58 let len = extract_primitive(
59 ctx,
60 args.pop().unwrap(),
61 "`array` requires first arg to be a number",
62 )?;
63
64 let mut index = T::zero();
65 let mut array = vec![];
66 loop {
67 let next_index = ctx
68 .arithmetic()
69 .add(index.clone(), T::one())
70 .map_err(|err| ctx.call_site_error(ErrorKind::Arithmetic(err)))?;
71
72 let cmp = ctx.arithmetic().partial_cmp(&next_index, &len);
73 if matches!(cmp, Some(Ordering::Less) | Some(Ordering::Equal)) {
74 let spanned = ctx.apply_call_span(Value::Prim(index));
75 array.push(generation_fn.evaluate(vec![spanned], ctx)?);
76 index = next_index;
77 } else {
78 break;
79 }
80 }
81 Ok(Value::Tuple(array))
82 }
83}
84
85#[derive(Debug, Clone, Copy, Default)]
112pub struct Len;
113
114impl<T: FromPrimitive> NativeFn<T> for Len {
115 fn evaluate<'a>(
116 &self,
117 mut args: Vec<SpannedValue<'a, T>>,
118 ctx: &mut CallContext<'_, 'a, T>,
119 ) -> EvalResult<'a, T> {
120 ctx.check_args_count(&args, 1)?;
121 let arg = args.pop().unwrap();
122
123 let len = match arg.extra {
124 Value::Tuple(array) => array.len(),
125 Value::Object(object) => object.len(),
126 _ => {
127 let err = ErrorKind::native("`len` requires object or tuple arg");
128 return Err(ctx
129 .call_site_error(err)
130 .with_span(&arg, AuxErrorInfo::InvalidArg));
131 }
132 };
133 let len = T::from_usize(len).ok_or_else(|| {
134 let err = ErrorKind::native("Cannot convert length to number");
135 ctx.call_site_error(err)
136 })?;
137 Ok(Value::Prim(len))
138 }
139}
140
141#[derive(Debug, Clone, Copy, Default)]
172pub struct Map;
173
174impl<T: Clone> NativeFn<T> for Map {
175 fn evaluate<'a>(
176 &self,
177 mut args: Vec<SpannedValue<'a, T>>,
178 ctx: &mut CallContext<'_, 'a, T>,
179 ) -> EvalResult<'a, T> {
180 ctx.check_args_count(&args, 2)?;
181 let map_fn = extract_fn(
182 ctx,
183 args.pop().unwrap(),
184 "`map` requires second arg to be a mapping function",
185 )?;
186 let array = extract_array(
187 ctx,
188 args.pop().unwrap(),
189 "`map` requires first arg to be a tuple",
190 )?;
191
192 let mapped: Result<Vec<_>, _> = array
193 .into_iter()
194 .map(|value| {
195 let spanned = ctx.apply_call_span(value);
196 map_fn.evaluate(vec![spanned], ctx)
197 })
198 .collect();
199 mapped.map(Value::Tuple)
200 }
201}
202
203#[derive(Debug, Clone, Copy, Default)]
234pub struct Filter;
235
236impl<T: Clone> NativeFn<T> for Filter {
237 fn evaluate<'a>(
238 &self,
239 mut args: Vec<SpannedValue<'a, T>>,
240 ctx: &mut CallContext<'_, 'a, T>,
241 ) -> EvalResult<'a, T> {
242 ctx.check_args_count(&args, 2)?;
243 let filter_fn = extract_fn(
244 ctx,
245 args.pop().unwrap(),
246 "`filter` requires second arg to be a filter function",
247 )?;
248 let array = extract_array(
249 ctx,
250 args.pop().unwrap(),
251 "`filter` requires first arg to be a tuple",
252 )?;
253
254 let mut filtered = vec![];
255 for value in array {
256 let spanned = ctx.apply_call_span(value.clone());
257 match filter_fn.evaluate(vec![spanned], ctx)? {
258 Value::Bool(true) => filtered.push(value),
259 Value::Bool(false) => { }
260 _ => {
261 let err = ErrorKind::native(
262 "`filter` requires filtering function to return booleans",
263 );
264 return Err(ctx.call_site_error(err));
265 }
266 }
267 }
268 Ok(Value::Tuple(filtered))
269 }
270}
271
272#[derive(Debug, Clone, Copy, Default)]
302pub struct Fold;
303
304impl<T: Clone> NativeFn<T> for Fold {
305 fn evaluate<'a>(
306 &self,
307 mut args: Vec<SpannedValue<'a, T>>,
308 ctx: &mut CallContext<'_, 'a, T>,
309 ) -> EvalResult<'a, T> {
310 ctx.check_args_count(&args, 3)?;
311 let fold_fn = extract_fn(
312 ctx,
313 args.pop().unwrap(),
314 "`fold` requires third arg to be a folding function",
315 )?;
316 let acc = args.pop().unwrap().extra;
317 let array = extract_array(
318 ctx,
319 args.pop().unwrap(),
320 "`fold` requires first arg to be a tuple",
321 )?;
322
323 array.into_iter().try_fold(acc, |acc, value| {
324 let spanned_args = vec![ctx.apply_call_span(acc), ctx.apply_call_span(value)];
325 fold_fn.evaluate(spanned_args, ctx)
326 })
327 }
328}
329
330#[derive(Debug, Clone, Copy, Default)]
368pub struct Push;
369
370impl<T> NativeFn<T> for Push {
371 fn evaluate<'a>(
372 &self,
373 mut args: Vec<SpannedValue<'a, T>>,
374 ctx: &mut CallContext<'_, 'a, T>,
375 ) -> EvalResult<'a, T> {
376 ctx.check_args_count(&args, 2)?;
377 let elem = args.pop().unwrap().extra;
378 let mut array = extract_array(
379 ctx,
380 args.pop().unwrap(),
381 "`fold` requires first arg to be a tuple",
382 )?;
383
384 array.push(elem);
385 Ok(Value::Tuple(array))
386 }
387}
388
389#[derive(Debug, Clone, Copy, Default)]
421pub struct Merge;
422
423impl<T: Clone> NativeFn<T> for Merge {
424 fn evaluate<'a>(
425 &self,
426 mut args: Vec<SpannedValue<'a, T>>,
427 ctx: &mut CallContext<'_, 'a, T>,
428 ) -> EvalResult<'a, T> {
429 ctx.check_args_count(&args, 2)?;
430 let second = extract_array(
431 ctx,
432 args.pop().unwrap(),
433 "`merge` requires second arg to be a tuple",
434 )?;
435 let mut first = extract_array(
436 ctx,
437 args.pop().unwrap(),
438 "`merge` requires first arg to be a tuple",
439 )?;
440
441 first.extend_from_slice(&second);
442 Ok(Value::Tuple(first))
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        arith::{OrdArithmetic, StdArithmetic, WrappingArithmetic},
        Environment, VariableMap,
    };

    use arithmetic_parser::grammars::{NumGrammar, NumLiteral, Parse, Untyped};
    use assert_matches::assert_matches;

    /// Evaluates `len` on tuples and objects of several sizes with the given arithmetic
    /// and checks that every comparison in the script holds.
    fn test_len_function<T: NumLiteral>(arithmetic: &dyn OrdArithmetic<T>)
    where
        Len: NativeFn<T>,
    {
        let script = r#"
            (1, 2, 3).len() == 3 && ().len() == 0 &&
            #{}.len() == 0 && #{ x: 1 }.len() == 1 && #{ x: 1, y: 2 }.len() == 2
        "#;
        let statements = Untyped::<NumGrammar<T>>::parse_statements(script).unwrap();
        let mut env = Environment::new();
        let module = env
            .insert("len", Value::native_fn(Len))
            .compile_module("len", &statements)
            .unwrap();

        let result = module.with_arithmetic(arithmetic).run().unwrap();
        assert_matches!(result, Value::Bool(true));
    }

    /// Runs `script` with the `array` function in scope using the given arithmetic
    /// and checks that it evaluates to `true`.
    fn assert_array_script<T: NumLiteral>(script: &str, arithmetic: &dyn OrdArithmetic<T>)
    where
        Array: NativeFn<T>,
    {
        let statements = Untyped::<NumGrammar<T>>::parse_statements(script).unwrap();
        let mut env = Environment::new();
        let module = env
            .insert("array", Value::native_fn(Array))
            .compile_module("array", &statements)
            .unwrap();

        let result = module.with_arithmetic(arithmetic).run().unwrap();
        assert_matches!(result, Value::Bool(true));
    }

    #[test]
    fn len_function_in_floating_point_arithmetic() {
        test_len_function::<f32>(&StdArithmetic);
        test_len_function::<f64>(&StdArithmetic);
    }

    #[test]
    fn len_function_in_int_arithmetic() {
        test_len_function::<u8>(&WrappingArithmetic);
        test_len_function::<i8>(&WrappingArithmetic);
        test_len_function::<u64>(&WrappingArithmetic);
        test_len_function::<i64>(&WrappingArithmetic);
    }

    #[test]
    fn len_function_with_number_overflow() {
        // 128 elements cannot be represented as an `i8`.
        let statements = Untyped::<NumGrammar<i8>>::parse_statements("xs.len()").unwrap();
        let xs = Value::Tuple(vec![Value::Bool(true); 128]);
        let mut env = Environment::new();
        let module = env
            .insert("xs", xs)
            .insert("len", Value::native_fn(Len))
            .compile_module("len", &statements)
            .unwrap();

        let err = module
            .with_arithmetic(&WrappingArithmetic)
            .run()
            .unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(msg) if msg.contains("length to number")
        );
    }

    #[test]
    fn array_function_in_floating_point_arithmetic() {
        let script = r#"
            array(0, |_| 1) == () && array(-1, |_| 1) == () &&
            array(0.1, |_| 1) == () && array(0.999, |_| 1) == () &&
            array(1, |_| 1) == (1,) && array(1.5, |_| 1) == (1,) &&
            array(2, |_| 1) == (1, 1) && array(3, |i| i) == (0, 1, 2)
        "#;
        assert_array_script::<f32>(script, &StdArithmetic);
    }

    #[test]
    fn array_function_in_unsigned_int_arithmetic() {
        let script = r#"
            array(0, |_| 1) == () && array(1, |_| 1) == (1,) && array(3, |i| i) == (0, 1, 2)
        "#;
        assert_array_script::<u32>(script, &WrappingArithmetic);
    }
}