1use core::{fmt, marker::PhantomData};
4
5use crate::{
6 alloc::Vec, error::AuxErrorInfo, CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue,
7};
8
9mod traits;
10
11pub use self::traits::{
12 ErrorOutput, FromValueError, FromValueErrorKind, FromValueErrorLocation, IntoEvalResult,
13 TryFromValue,
14};
15
16pub const fn wrap<T, F>(function: F) -> FnWrapper<T, F> {
21 FnWrapper::new(function)
22}
23
/// Wrapper of a function together with a compile-time description of its signature.
///
/// `T` is a tuple whose first element is the return type and whose remaining elements are the
/// argument types, e.g. `(Ret, Arg0, Arg1)` for a binary function. Only a marker of `T` is
/// stored, so `T` does not need to be constructible.
pub struct FnWrapper<T, F> {
    // The wrapped function itself.
    function: F,
    // Type marker for the `(Ret, Args...)` tuple; carries no runtime data.
    _arg_types: PhantomData<T>,
}
89
90impl<T, F> fmt::Debug for FnWrapper<T, F>
91where
92 F: fmt::Debug,
93{
94 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
95 formatter
96 .debug_struct("FnWrapper")
97 .field("function", &self.function)
98 .finish()
99 }
100}
101
102impl<T, F: Clone> Clone for FnWrapper<T, F> {
103 fn clone(&self) -> Self {
104 Self {
105 function: self.function.clone(),
106 _arg_types: PhantomData,
107 }
108 }
109}
110
111impl<T, F: Copy> Copy for FnWrapper<T, F> {}
112
113impl<T, F> FnWrapper<T, F> {
116 pub const fn new(function: F) -> Self {
125 Self {
126 function,
127 _arg_types: PhantomData,
128 }
129 }
130}
131
/// Implements `NativeFn` for `FnWrapper`s around functions of a fixed arity.
///
/// The macro input is the arity followed by `binding: TypeParam` pairs, one per argument.
/// The generated `evaluate` does the following:
/// - checks the argument count;
/// - converts each argument with `TryFromValue`, reporting failures as `ErrorKind::Wrapper`
///   errors spanning the offending argument;
/// - calls the function and converts its output with `IntoEvalResult`.
macro_rules! arity_fn {
    ($arity:tt => $($arg_name:ident : $t:ident),*) => {
        impl<Num, F, Ret, $($t,)*> NativeFn<Num> for FnWrapper<(Ret, $($t,)*), F>
        where
            F: Fn($($t,)*) -> Ret,
            $($t: for<'val> TryFromValue<'val, Num>,)*
            Ret: for<'val> IntoEvalResult<'val, Num>,
        {
            // Each raw argument binding is shadowed by its converted value.
            #[allow(clippy::shadow_unrelated)]
            // For the zero-arity impl the iterator is never advanced.
            #[allow(unused_variables, unused_mut)]
            fn evaluate<'a>(
                &self,
                args: Vec<SpannedValue<'a, Num>>,
                context: &mut CallContext<'_, 'a, Num>,
            ) -> EvalResult<'a, Num> {
                context.check_args_count(&args, $arity)?;
                let mut indexed_args = args.into_iter().enumerate();

                $(
                    let (index, $arg_name) = indexed_args.next().unwrap();
                    let arg_span = $arg_name.with_no_extra();
                    let $arg_name = $t::try_from_value($arg_name.extra).map_err(|mut err| {
                        err.set_arg_index(index);
                        context
                            .call_site_error(ErrorKind::Wrapper(err))
                            .with_span(&arg_span, AuxErrorInfo::InvalidArg)
                    })?;
                )*

                (self.function)($($arg_name,)*)
                    .into_eval_result()
                    .map_err(|err| err.into_spanned(context))
            }
        }
    };
}
167
168arity_fn!(0 =>);
169arity_fn!(1 => x0: T);
170arity_fn!(2 => x0: T, x1: U);
171arity_fn!(3 => x0: T, x1: U, x2: V);
172arity_fn!(4 => x0: T, x1: U, x2: V, x3: W);
173arity_fn!(5 => x0: T, x1: U, x2: V, x3: W, x4: X);
174arity_fn!(6 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
175arity_fn!(7 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
176arity_fn!(8 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
177arity_fn!(9 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
178arity_fn!(10 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
179
/// Wrapper of a unary function pointer `fn(T) -> T` whose argument and return types coincide.
pub type Unary<T> = FnWrapper<(T, T), fn(T) -> T>;

/// Wrapper of a binary function pointer `fn(T, T) -> T` whose argument and return types coincide.
pub type Binary<T> = FnWrapper<(T, T, T), fn(T, T) -> T>;

/// Wrapper of a ternary function pointer `fn(T, T, T) -> T` whose argument and return types
/// coincide.
pub type Ternary<T> = FnWrapper<(T, T, T, T), fn(T, T, T) -> T>;

/// Wrapper of a quaternary function pointer `fn(T, T, T, T) -> T` whose argument and return
/// types coincide.
pub type Quaternary<T> = FnWrapper<(T, T, T, T, T), fn(T, T, T, T) -> T>;
191
/// Wraps a function of the given arity into a closure that evaluates it on script values.
///
/// The first macro argument is the number of function arguments (0 to 10). The second
/// is the function itself. The expansion does the following:
/// - checks the argument count;
/// - converts each argument with `TryFromValue`, reporting failures as `ErrorKind::Wrapper`
///   errors spanning the offending argument;
/// - converts the function output with `IntoEvalResult`.
///
/// The closure signature is pinned down by `enforce_closure_type`.
#[macro_export]
macro_rules! wrap_fn {
    (0, $function:expr) => { $crate::wrap_fn!(@arg 0 =>; $function) };
    (1, $function:expr) => { $crate::wrap_fn!(@arg 1 => x0; $function) };
    (2, $function:expr) => { $crate::wrap_fn!(@arg 2 => x0, x1; $function) };
    (3, $function:expr) => { $crate::wrap_fn!(@arg 3 => x0, x1, x2; $function) };
    (4, $function:expr) => { $crate::wrap_fn!(@arg 4 => x0, x1, x2, x3; $function) };
    (5, $function:expr) => { $crate::wrap_fn!(@arg 5 => x0, x1, x2, x3, x4; $function) };
    (6, $function:expr) => { $crate::wrap_fn!(@arg 6 => x0, x1, x2, x3, x4, x5; $function) };
    (7, $function:expr) => { $crate::wrap_fn!(@arg 7 => x0, x1, x2, x3, x4, x5, x6; $function) };
    (8, $function:expr) => {
        $crate::wrap_fn!(@arg 8 => x0, x1, x2, x3, x4, x5, x6, x7; $function)
    };
    (9, $function:expr) => {
        $crate::wrap_fn!(@arg 9 => x0, x1, x2, x3, x4, x5, x6, x7, x8; $function)
    };
    (10, $function:expr) => {
        $crate::wrap_fn!(@arg 10 => x0, x1, x2, x3, x4, x5, x6, x7, x8, x9; $function)
    };

    // Internal arm. The optional leading `$ctx` ident (used by `wrap_fn_with_context!`)
    // makes the call context be passed as the first function argument.
    //
    // Repetitions over `$arg_name` use `*` rather than `+`: the zero-arity arms pass an empty
    // argument list, and expanding a `+` repetition zero times is a compile error.
    ($($ctx:ident,)? @arg $arity:expr => $($arg_name:ident),*; $function:expr) => {{
        let function = $function;
        $crate::fns::enforce_closure_type(move |args, context| {
            context.check_args_count(&args, $arity)?;
            // Never advanced (hence unused / needlessly `mut`) for zero-arity functions.
            #[allow(unused_mut, unused_variables)]
            let mut args_iter = args.into_iter().enumerate();

            $(
                let (index, $arg_name) = args_iter
                    .next()
                    .expect("argument count was verified by `check_args_count`");
                let span = $arg_name.with_no_extra();
                let $arg_name = $crate::fns::TryFromValue::try_from_value($arg_name.extra)
                    .map_err(|mut err| {
                        err.set_arg_index(index);
                        context
                            .call_site_error($crate::error::ErrorKind::Wrapper(err))
                            .with_span(&span, $crate::error::AuxErrorInfo::InvalidArg)
                    })?;
            )*

            // `$ctx` only serves to make this repetition conditional; its value is unused.
            let output = function($({ let $ctx = (); context },)? $($arg_name,)*);
            $crate::fns::IntoEvalResult::into_eval_result(output)
                .map_err(|err| err.into_spanned(context))
        })
    }}
}
299
/// Variant of `wrap_fn!` whose wrapped function receives the call context as its first
/// argument, followed by the converted script arguments.
///
/// The first macro argument is the number of script arguments (0 to 10), which does not
/// count the context. The second is the function itself.
#[macro_export]
macro_rules! wrap_fn_with_context {
    (0, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 0 =>; $func)
    };
    (1, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 1 => x0; $func)
    };
    (2, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 2 => x0, x1; $func)
    };
    (3, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 3 => x0, x1, x2; $func)
    };
    (4, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 4 => x0, x1, x2, x3; $func)
    };
    (5, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 5 => x0, x1, x2, x3, x4; $func)
    };
    (6, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 6 => x0, x1, x2, x3, x4, x5; $func)
    };
    (7, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 7 => x0, x1, x2, x3, x4, x5, x6; $func)
    };
    (8, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 8 => x0, x1, x2, x3, x4, x5, x6, x7; $func)
    };
    (9, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 9 => x0, x1, x2, x3, x4, x5, x6, x7, x8; $func)
    };
    (10, $func:expr) => {
        $crate::wrap_fn!(_ctx, @arg 10 => x0, x1, x2, x3, x4, x5, x6, x7, x8, x9; $func)
    };
}
362
363#[doc(hidden)] pub fn enforce_closure_type<T, A, F>(function: F) -> F
365where
366 F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, A>) -> EvalResult<'a, T>,
367{
368 function
369}
370
// Integration-style tests: each wraps Rust functions, imports them into a parsed script and
// checks either the script result or the reported error.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{format, ToOwned},
        Environment, ExecutableModule, Prelude, Value, WildcardId,
    };

    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    // Closures and function pointers with `f32` args, wrapped via the `Unary` / `Binary` /
    // `Ternary` aliases, are callable from a script.
    #[test]
    fn functions_with_primitive_args() {
        let unary_fn = Unary::new(|x: f32| x + 3.0);
        let binary_fn = Binary::new(f32::min);
        let ternary_fn = Ternary::new(|x: f32, y, z| if x > 0.0 { y } else { z });

        let program = r#"
            unary_fn(2) == 5 && binary_fn(1, -3) == -3 &&
            ternary_fn(1, 2, 3) == 2 && ternary_fn(-1, 2, 3) == 3
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("unary_fn", Value::native_fn(unary_fn))
            .with_import("binary_fn", Value::native_fn(binary_fn))
            .with_import("ternary_fn", Value::native_fn(ternary_fn))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // Returns `(min, max)` of `values`; yields `(INFINITY, NEG_INFINITY)` for an empty input.
    fn array_min_max(values: Vec<f32>) -> (f32, f32) {
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;

        for value in values {
            if value < min {
                min = value;
            }
            if value > max {
                max = value;
            }
        }
        (min, max)
    }

    // Exercises nested conversions: an array of pairs plus a (array, scalar) tuple.
    fn overly_convoluted_fn(xs: Vec<(f32, f32)>, ys: (Vec<f32>, f32)) -> f32 {
        xs.into_iter().map(|(a, b)| a + b).sum::<f32>() + ys.0.into_iter().sum::<f32>() + ys.1
    }

    // Tuples in the script map to `Vec` / Rust tuples in arguments and return values.
    #[test]
    fn functions_with_composite_args() {
        // 36 = (1 + 2) + (3 + 4) + (5 + 6 + 7) + 8
        let program = r#"
            (1, 5, -3, 2, 1).array_min_max() == (-3, 5) &&
            total_sum(((1, 2), (3, 4)), ((5, 6, 7), 8)) == 36
        "#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("array_min_max", Value::wrapped_fn(array_min_max))
            .with_import("total_sum", Value::wrapped_fn(overly_convoluted_fn))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // Element-wise sum; returns `Err` with a message when lengths differ.
    fn sum_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<f32>, String> {
        if xs.len() == ys.len() {
            Ok(xs.into_iter().zip(ys).map(|(x, y)| x + y).collect())
        } else {
            Err("Summed arrays must have the same size".to_owned())
        }
    }

    // `Ok` output of a fallible function is unwrapped into a script value.
    #[test]
    fn fallible_function() {
        let program = "(1, 2, 3).sum_arrays((4, 5, 6)) == (5, 7, 9)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("sum_arrays", Value::wrapped_fn(sum_arrays))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // An `Err(String)` from a wrapped function surfaces in the script error's message.
    #[test]
    fn fallible_function_with_bogus_program() {
        let program = "(1, 2, 3).sum_arrays((4, 5))";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let err = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("sum_arrays", Value::wrapped_fn(sum_arrays))
            .build()
            .run()
            .unwrap_err();
        assert!(err
            .source()
            .kind()
            .to_short_string()
            .contains("Summed arrays must have the same size"));
    }

    // A `bool` return value maps to `Value::Bool`; the first arg destructures a tuple.
    #[test]
    fn function_with_bool_return_value() {
        let contains = wrap(|(a, b): (f32, f32), x: f32| (a..=b).contains(&x));

        let program = "(-1, 2).contains(0) && !(1, 3).contains(0)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("contains", Value::native_fn(contains))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // `Ok(())` maps to a void value; `Err(String)` becomes `ErrorKind::NativeCall`.
    #[test]
    fn function_with_void_return_value() {
        let mut env = Environment::new();
        env.insert_wrapped_fn("assert_eq", |expected: f32, actual: f32| {
            if (expected - actual).abs() < f32::EPSILON {
                Ok(())
            } else {
                Err(format!(
                    "Assertion failed: expected {}, got {}",
                    expected, actual
                ))
            }
        });

        let program = "assert_eq(3, 1 + 2)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&env)
            .build();
        assert!(module.run().unwrap().is_void());

        let bogus_program = "assert_eq(3, 1 - 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program).unwrap();
        let err = ExecutableModule::builder(WildcardId, &bogus_block)
            .unwrap()
            .with_imports_from(&env)
            .build()
            .run()
            .unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("Assertion failed")
        );
    }

    // A script `bool` converts into a Rust `bool` argument.
    #[test]
    fn function_with_bool_argument() {
        let program = "flip_sign(-1, true) == 1 && flip_sign(-1, false) == -1";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        // NOTE(review): `Prelude` is presumably imported to provide `true` / `false`; confirm.
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&Prelude)
            .with_import(
                "flip_sign",
                Value::wrapped_fn(|val: f32, flag: bool| if flag { -val } else { val }),
            )
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // Conversion errors point at the nested location: element `[1]` of `arg0`, tuple field
    // `.0` (the number `2`, which cannot become a `bool`).
    #[test]
    fn error_reporting_with_destructuring() {
        let program = "((true, 1), (2, 3)).destructure()";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let err = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&Prelude)
            .with_import(
                "destructure",
                Value::wrapped_fn(|values: Vec<(bool, f32)>| {
                    values
                        .into_iter()
                        .map(|(flag, x)| if flag { x } else { 0.0 })
                        .sum::<f32>()
                }),
            )
            .build()
            .run()
            .unwrap_err();

        let err_message = err.source().kind().to_short_string();
        assert!(err_message.contains("Cannot convert primitive value to bool"));
        assert!(err_message.contains("location: arg0[1].0"));
    }
}