use core::{cmp, fmt, marker::PhantomData};
use crate::{
alloc::{vec, String, Vec},
error::AuxErrorInfo,
CallContext, Error, ErrorKind, EvalResult, Function, NativeFn, Number, SpannedValue, Value,
ValueType,
};
pub const fn wrap<T, F>(function: F) -> FnWrapper<T, F> {
FnWrapper::new(function)
}
/// Wrapper turning a Rust function or closure into a native function.
///
/// The type parameter `T` records the function's return type followed by its argument
/// types, e.g. `(Ret, Arg0, Arg1)`. It is stored only as [`PhantomData`], so it has no
/// runtime footprint.
pub struct FnWrapper<T, F> {
    function: F,
    _arg_types: PhantomData<T>,
}

impl<T, F> FnWrapper<T, F> {
    /// Creates a wrapper around `function`.
    pub const fn new(function: F) -> Self {
        Self {
            _arg_types: PhantomData,
            function,
        }
    }
}

// Only the wrapped function is shown; the phantom type marker carries no data.
impl<T, F> fmt::Debug for FnWrapper<T, F>
where
    F: fmt::Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("FnWrapper");
        builder.field("function", &self.function);
        builder.finish()
    }
}

// `Clone` / `Copy` are implemented by hand rather than derived: deriving would also
// require `T: Clone` / `T: Copy`, although `T` only appears inside `PhantomData`.
impl<T, F: Clone> Clone for FnWrapper<T, F> {
    fn clone(&self) -> Self {
        Self::new(self.function.clone())
    }
}

impl<T, F: Copy> Copy for FnWrapper<T, F> {}
/// Error converting a [`Value`] into a native Rust type via [`TryFromValue`].
///
/// Besides the error kind, it records:
/// - the index of the function argument that failed to convert;
/// - the path of tuple / array indices inside that argument where the failure occurred.
#[derive(Debug, Clone)]
pub struct FromValueError {
    kind: FromValueErrorKind,
    arg_index: usize,
    location: Vec<FromValueErrorLocation>,
}

impl FromValueError {
    /// Creates an error signalling that `actual_value` does not have the `expected` type.
    /// The argument index and location are filled in while the error propagates.
    pub(crate) fn invalid_type<T>(expected: ValueType, actual_value: &Value<'_, T>) -> Self {
        let actual = actual_value.value_type();
        let kind = FromValueErrorKind::InvalidType { expected, actual };
        Self {
            kind,
            arg_index: 0,
            location: Vec::new(),
        }
    }

    /// Appends a location element. Nested conversions call this after the inner
    /// conversion has failed, so elements accumulate innermost-first.
    fn add_location(mut self, location: FromValueErrorLocation) -> Self {
        self.location.push(location);
        self
    }

    /// Sets the index of the argument that failed to convert. It also reverses the
    /// accumulated location so that it reads outermost-first.
    ///
    /// NOTE(review): the reversal is not idempotent. Calling this twice restores the
    /// innermost-first order, so it should be called exactly once per error.
    #[doc(hidden)]
    pub fn set_arg_index(&mut self, index: usize) {
        self.arg_index = index;
        self.location.reverse();
    }

    /// Returns the kind of this error.
    pub fn kind(&self) -> &FromValueErrorKind {
        &self.kind
    }

    /// Returns the zero-based index of the argument that failed to convert.
    pub fn arg_index(&self) -> usize {
        self.arg_index
    }

    /// Returns the path to the failing element inside the argument.
    pub fn location(&self) -> &[FromValueErrorLocation] {
        &self.location
    }
}

// Renders the location as a path such as `arg0[1].0`: `[i]` for an array element and
// `.i` for a tuple element.
impl fmt::Display for FromValueError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            formatter,
            "{}. Error location: arg{}",
            self.kind, self.arg_index
        )?;
        self.location.iter().try_for_each(|element| match element {
            FromValueErrorLocation::Tuple { index, .. } => write!(formatter, ".{}", index),
            FromValueErrorLocation::Array { index, .. } => write!(formatter, "[{}]", index),
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FromValueError {}
/// Kind of a [`FromValueError`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FromValueErrorKind {
    /// The value has a type other than the one required by the conversion.
    InvalidType {
        /// Type required by the conversion.
        expected: ValueType,
        /// Type of the value actually supplied.
        actual: ValueType,
    },
}

impl fmt::Display for FromValueErrorKind {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidType {
                ref expected,
                ref actual,
            } => write!(formatter, "Cannot convert {} to {}", actual, expected),
        }
    }
}

/// Single element of the path to the place inside an argument where a conversion failed.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum FromValueErrorLocation {
    /// Element of a tuple converted into a fixed-size Rust tuple.
    Tuple {
        /// Number of elements in the tuple.
        size: usize,
        /// Zero-based index of the offending element.
        index: usize,
    },
    /// Element of a tuple converted into a `Vec`.
    Array {
        /// Number of elements in the array.
        size: usize,
        /// Zero-based index of the offending element.
        index: usize,
    },
}
/// Fallible conversion from a script [`Value`] into a native Rust type.
///
/// It is used to convert the arguments of wrapped native functions. Implementations in
/// this module cover:
/// - numbers and `bool`;
/// - [`Value`] itself and [`Function`];
/// - `Vec`s and tuples of convertible types.
pub trait TryFromValue<'a, T>: Sized {
    /// Attempts to convert `value` into `Self`.
    fn try_from_value(value: Value<'a, T>) -> Result<Self, FromValueError>;
}

// A numeric value converts into the number type itself.
impl<'a, T: Number> TryFromValue<'a, T> for T {
    fn try_from_value(value: Value<'a, T>) -> Result<Self, FromValueError> {
        match value {
            Value::Number(number) => Ok(number),
            other => Err(FromValueError::invalid_type(ValueType::Number, &other)),
        }
    }
}

impl<'a, T> TryFromValue<'a, T> for bool {
    fn try_from_value(value: Value<'a, T>) -> Result<Self, FromValueError> {
        match value {
            Value::Bool(flag) => Ok(flag),
            other => Err(FromValueError::invalid_type(ValueType::Bool, &other)),
        }
    }
}

// Identity conversion: any value is accepted as-is.
impl<'a, T> TryFromValue<'a, T> for Value<'a, T> {
    fn try_from_value(value: Value<'a, T>) -> Result<Self, FromValueError> {
        Result::Ok(value)
    }
}

impl<'a, T> TryFromValue<'a, T> for Function<'a, T> {
    fn try_from_value(value: Value<'a, T>) -> Result<Self, FromValueError> {
        match value {
            Value::Function(function) => Ok(function),
            other => Err(FromValueError::invalid_type(ValueType::Function, &other)),
        }
    }
}
/// Converts a tuple of any length into a `Vec`, converting each element with `U`.
///
/// Conversion stops at the first element that fails. Its position is recorded as a
/// [`FromValueErrorLocation::Array`].
impl<'a, U, T> TryFromValue<'a, T> for Vec<U>
where
    U: TryFromValue<'a, T>,
{
    fn try_from_value(value: Value<'a, T>) -> Result<Self, FromValueError> {
        let elements = match value {
            Value::Tuple(elements) => elements,
            other => return Err(FromValueError::invalid_type(ValueType::Array, &other)),
        };
        let size = elements.len();
        // Collecting into `Result` short-circuits on the first `Err`.
        elements
            .into_iter()
            .enumerate()
            .map(|(index, element)| {
                U::try_from_value(element).map_err(|err| {
                    err.add_location(FromValueErrorLocation::Array { size, index })
                })
            })
            .collect()
    }
}
// Implements `TryFromValue` for a Rust tuple with `$size` elements. Each `$var` is a
// local binding name for one element, and `$ty` is the corresponding element type.
macro_rules! try_from_value_for_tuple {
    ($size:expr => $($var:ident : $ty:ident),+) => {
        impl<'a, Num, $($ty,)+> TryFromValue<'a, Num> for ($($ty,)+)
        where
            $($ty: TryFromValue<'a, Num>,)+
        {
            // Each `$var` is first bound to the raw element, then shadowed by its
            // converted value.
            #[allow(clippy::shadow_unrelated)]
            fn try_from_value(value: Value<'a, Num>) -> Result<Self, FromValueError> {
                const EXPECTED_TYPE: ValueType = ValueType::Tuple($size);
                match value {
                    // Only tuples with exactly `$size` elements are accepted; anything
                    // else is reported as an `InvalidType` error.
                    Value::Tuple(values) if values.len() == $size => {
                        let mut values_iter = values.into_iter().enumerate();
                        $(
                            // Cannot fail: the match guard above checked the length.
                            let (index, $var) = values_iter.next().unwrap();
                            // Elements are converted in order; the first failure is
                            // returned with its tuple position appended to the location.
                            let $var = $ty::try_from_value($var).map_err(|err| {
                                err.add_location(FromValueErrorLocation::Tuple {
                                    size: $size,
                                    index,
                                })
                            })?;
                        )+
                        Ok(($($var,)+))
                    }
                    _ => Err(FromValueError::invalid_type(EXPECTED_TYPE, &value)),
                }
            }
        }
    };
}

// Tuples of 1 to 10 elements.
try_from_value_for_tuple!(1 => x0: T);
try_from_value_for_tuple!(2 => x0: T, x1: U);
try_from_value_for_tuple!(3 => x0: T, x1: U, x2: V);
try_from_value_for_tuple!(4 => x0: T, x1: U, x2: V, x3: W);
try_from_value_for_tuple!(5 => x0: T, x1: U, x2: V, x3: W, x4: X);
try_from_value_for_tuple!(6 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
try_from_value_for_tuple!(7 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
try_from_value_for_tuple!(8 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
try_from_value_for_tuple!(9 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
try_from_value_for_tuple!(10 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
/// Error returned by a wrapped native function before it is attributed to a code span.
#[derive(Debug)]
#[non_exhaustive]
pub enum ErrorOutput<'a> {
    /// Fully-formed error; passed through unchanged.
    Spanned(Error<'a>),
    /// Plain error message; attributed to the call site when converted.
    Message(String),
}

impl<'a> ErrorOutput<'a> {
    /// Converts this output into an [`Error`].
    ///
    /// A `Message` becomes a native error reported at the call site described by
    /// `context`. A `Spanned` error is returned as-is.
    #[doc(hidden)]
    pub fn into_spanned<A>(self, context: &CallContext<'_, 'a, A>) -> Error<'a> {
        let message = match self {
            Self::Spanned(err) => return err,
            Self::Message(message) => message,
        };
        context.call_site_error(ErrorKind::native(message))
    }
}
/// Conversion of a native function's return value into a script [`Value`].
///
/// `Result`s are supported in addition to plain values:
/// - `Err(String)` becomes [`ErrorOutput::Message`];
/// - `Err(Error)` becomes [`ErrorOutput::Spanned`].
pub trait IntoEvalResult<'a, T> {
    /// Performs the conversion.
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>>;
}

impl<'a, T, U> IntoEvalResult<'a, T> for Result<U, String>
where
    U: IntoEvalResult<'a, T>,
{
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        match self {
            Ok(inner) => U::into_eval_result(inner),
            Err(message) => Err(ErrorOutput::Message(message)),
        }
    }
}

impl<'a, T, U> IntoEvalResult<'a, T> for Result<U, Error<'a>>
where
    U: IntoEvalResult<'a, T>,
{
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        match self {
            Ok(inner) => U::into_eval_result(inner),
            Err(err) => Err(ErrorOutput::Spanned(err)),
        }
    }
}

impl<'a, T: Number> IntoEvalResult<'a, T> for T {
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        let number = Value::Number(self);
        Ok(number)
    }
}

// The unit type maps to the void value.
impl<'a, T> IntoEvalResult<'a, T> for () {
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        let void = Value::void();
        Ok(void)
    }
}

impl<'a, T> IntoEvalResult<'a, T> for bool {
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        let flag = Value::Bool(self);
        Ok(flag)
    }
}

// Orderings have no dedicated value variant and are returned as opaque references.
impl<'a, T> IntoEvalResult<'a, T> for cmp::Ordering {
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        let ordering = Value::opaque_ref(self);
        Ok(ordering)
    }
}

impl<'a, T> IntoEvalResult<'a, T> for Value<'a, T> {
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        Result::Ok(self)
    }
}

impl<'a, T> IntoEvalResult<'a, T> for Function<'a, T> {
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        let function = Value::Function(self);
        Ok(function)
    }
}

// A `Vec` becomes a tuple. Elements are converted in order, and the first error is
// returned.
impl<'a, U, T> IntoEvalResult<'a, T> for Vec<U>
where
    U: IntoEvalResult<'a, T>,
{
    fn into_eval_result(self) -> Result<Value<'a, T>, ErrorOutput<'a>> {
        let mut elements = Vec::with_capacity(self.len());
        for item in self {
            elements.push(U::into_eval_result(item)?);
        }
        Ok(Value::Tuple(elements))
    }
}
// Implements `IntoEvalResult` for a Rust tuple. `$i` are the tuple field indices, and
// `$ty` are the corresponding element types. The result is a script tuple with the
// converted elements.
macro_rules! into_value_for_tuple {
    ($($i:tt : $ty:ident),+) => {
        impl<'a, Num, $($ty,)+> IntoEvalResult<'a, Num> for ($($ty,)+)
        where
            $($ty: IntoEvalResult<'a, Num>,)+
        {
            fn into_eval_result(self) -> Result<Value<'a, Num>, ErrorOutput<'a>> {
                // Elements are converted in order; `?` returns the first error.
                Ok(Value::Tuple(vec![$(self.$i.into_eval_result()?,)+]))
            }
        }
    };
}

// Tuples of 1 to 10 elements.
into_value_for_tuple!(0: T);
into_value_for_tuple!(0: T, 1: U);
into_value_for_tuple!(0: T, 1: U, 2: V);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W, 4: X);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W, 4: X, 5: Y);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W, 4: X, 5: Y, 6: Z);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W, 4: X, 5: Y, 6: Z, 7: A);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W, 4: X, 5: Y, 6: Z, 7: A, 8: B);
into_value_for_tuple!(0: T, 1: U, 2: V, 3: W, 4: X, 5: Y, 6: Z, 7: A, 8: B, 9: C);
// Implements `NativeFn` for `FnWrapper`s around functions with `$arity` arguments.
// The wrapper's type parameter is `(Ret, Arg0, Arg1, ...)`: the return type followed by
// the argument types. Each argument type must implement `TryFromValue`, and the return
// type must implement `IntoEvalResult`, for every value lifetime.
macro_rules! arity_fn {
    ($arity:tt => $($arg_name:ident : $t:ident),*) => {
        impl<Num, F, Ret, $($t,)*> NativeFn<Num> for FnWrapper<(Ret, $($t,)*), F>
        where
            F: Fn($($t,)*) -> Ret,
            $($t: for<'val> TryFromValue<'val, Num>,)*
            Ret: for<'val> IntoEvalResult<'val, Num>,
        {
            // `$arg_name` is rebound from the spanned argument to its converted value.
            #[allow(clippy::shadow_unrelated)]
            // In the zero-arity expansion, `args_iter` is never used.
            #[allow(unused_variables, unused_mut)]
            fn evaluate<'a>(
                &self,
                args: Vec<SpannedValue<'a, Num>>,
                context: &mut CallContext<'_, 'a, Num>,
            ) -> EvalResult<'a, Num> {
                context.check_args_count(&args, $arity)?;
                let mut args_iter = args.into_iter().enumerate();
                $(
                    // Relies on `check_args_count` above having rejected calls with
                    // too few arguments.
                    let (index, $arg_name) = args_iter.next().unwrap();
                    // Capture the argument's span before its value (`extra`) is consumed;
                    // presumably `with_no_extra` keeps only the span info.
                    let span = $arg_name.with_no_extra();
                    // On failure, record the argument index and report the error at the
                    // call site, with the offending argument's span marked as invalid.
                    let $arg_name = $t::try_from_value($arg_name.extra).map_err(|mut err| {
                        err.set_arg_index(index);
                        context
                            .call_site_error(ErrorKind::Wrapper(err))
                            .with_span(&span, AuxErrorInfo::InvalidArg)
                    })?;
                )*
                let output = (self.function)($($arg_name,)*);
                output.into_eval_result().map_err(|err| err.into_spanned(context))
            }
        }
    };
}

// Functions with 0 to 10 arguments.
arity_fn!(0 =>);
arity_fn!(1 => x0: T);
arity_fn!(2 => x0: T, x1: U);
arity_fn!(3 => x0: T, x1: U, x2: V);
arity_fn!(4 => x0: T, x1: U, x2: V, x3: W);
arity_fn!(5 => x0: T, x1: U, x2: V, x3: W, x4: X);
arity_fn!(6 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
arity_fn!(7 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
arity_fn!(8 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
arity_fn!(9 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
arity_fn!(10 => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
/// Wrapper for a function pointer `fn(T) -> T` taking one argument.
pub type Unary<T> = FnWrapper<(T, T), fn(T) -> T>;
/// Wrapper for a function pointer `fn(T, T) -> T` taking two arguments.
pub type Binary<T> = FnWrapper<(T, T, T), fn(T, T) -> T>;
/// Wrapper for a function pointer `fn(T, T, T) -> T` taking three arguments.
pub type Ternary<T> = FnWrapper<(T, T, T, T), fn(T, T, T) -> T>;
/// Wrapper for a function pointer `fn(T, T, T, T) -> T` taking four arguments.
pub type Quaternary<T> = FnWrapper<(T, T, T, T, T), fn(T, T, T, T) -> T>;
/// Wraps a function or closure into a native function, converting its arguments and
/// return value.
///
/// The macro takes two arguments:
/// - the number of arguments the function takes (0 to 10);
/// - the function itself.
///
/// Each script argument is converted with `fns::TryFromValue`, and the return value with
/// `fns::IntoEvalResult`. If an argument cannot be converted, the call fails with an
/// `ErrorKind::Wrapper` error at the call site. The span of the offending argument is
/// marked with `AuxErrorInfo::InvalidArg`. Zero-argument functions are supported.
///
/// See `wrap_fn_with_context!` for functions that also need the call context.
#[macro_export]
macro_rules! wrap_fn {
    (0, $function:expr) => { $crate::wrap_fn!(@arg 0 =>; $function) };
    (1, $function:expr) => { $crate::wrap_fn!(@arg 1 => x0; $function) };
    (2, $function:expr) => { $crate::wrap_fn!(@arg 2 => x0, x1; $function) };
    (3, $function:expr) => { $crate::wrap_fn!(@arg 3 => x0, x1, x2; $function) };
    (4, $function:expr) => { $crate::wrap_fn!(@arg 4 => x0, x1, x2, x3; $function) };
    (5, $function:expr) => { $crate::wrap_fn!(@arg 5 => x0, x1, x2, x3, x4; $function) };
    (6, $function:expr) => { $crate::wrap_fn!(@arg 6 => x0, x1, x2, x3, x4, x5; $function) };
    (7, $function:expr) => { $crate::wrap_fn!(@arg 7 => x0, x1, x2, x3, x4, x5, x6; $function) };
    (8, $function:expr) => {
        $crate::wrap_fn!(@arg 8 => x0, x1, x2, x3, x4, x5, x6, x7; $function)
    };
    (9, $function:expr) => {
        $crate::wrap_fn!(@arg 9 => x0, x1, x2, x3, x4, x5, x6, x7, x8; $function)
    };
    (10, $function:expr) => {
        $crate::wrap_fn!(@arg 10 => x0, x1, x2, x3, x4, x5, x6, x7, x8, x9; $function)
    };

    // Internal rule. If `$ctx` is present, the call context is passed as the first
    // argument of `$function`.
    //
    // The argument repetitions are transcribed with `*` to match the `*` in the matcher.
    // A `+` transcription fails to expand for the zero-arity case ("this must repeat at
    // least once").
    ($($ctx:ident,)? @arg $arity:expr => $($arg_name:ident),*; $function:expr) => {{
        let function = $function;
        $crate::fns::enforce_closure_type(move |args, context| {
            context.check_args_count(&args, $arity)?;
            // Unused (and needlessly `mut`) when the function takes no arguments.
            #[allow(unused_mut, unused_variables)]
            let mut args_iter = args.into_iter().enumerate();
            $(
                // Relies on `check_args_count` above having rejected calls with too
                // few arguments.
                let (index, $arg_name) = args_iter.next().unwrap();
                let span = $arg_name.with_no_extra();
                let $arg_name = $crate::fns::TryFromValue::try_from_value($arg_name.extra)
                    .map_err(|mut err| {
                        err.set_arg_index(index);
                        context
                            .call_site_error($crate::error::ErrorKind::Wrapper(err))
                            .with_span(&span, $crate::error::AuxErrorInfo::InvalidArg)
                    })?;
            )*
            // `{ let $ctx = (); context }` only exists to reference `$ctx`, so that this
            // optional group expands exactly when a context marker was supplied.
            let output = function($({ let $ctx = (); context },)? $($arg_name,)*);
            $crate::fns::IntoEvalResult::into_eval_result(output)
                .map_err(|err| err.into_spanned(context))
        })
    }}
}
/// Analogue of `wrap_fn!` for functions that also take the call context.
///
/// The wrapped function receives the call context (`&mut CallContext`) as its first
/// argument, followed by the converted script arguments. The first macro argument is
/// the number of script arguments (0 to 10), not counting the context.
#[macro_export]
macro_rules! wrap_fn_with_context {
    // `_ctx` is only a marker telling `wrap_fn!` to prepend the context argument.
    (0, $function:expr) => { $crate::wrap_fn!(_ctx, @arg 0 =>; $function) };
    (1, $function:expr) => { $crate::wrap_fn!(_ctx, @arg 1 => x0; $function) };
    (2, $function:expr) => { $crate::wrap_fn!(_ctx, @arg 2 => x0, x1; $function) };
    (3, $function:expr) => { $crate::wrap_fn!(_ctx, @arg 3 => x0, x1, x2; $function) };
    (4, $function:expr) => { $crate::wrap_fn!(_ctx, @arg 4 => x0, x1, x2, x3; $function) };
    (5, $function:expr) => { $crate::wrap_fn!(_ctx, @arg 5 => x0, x1, x2, x3, x4; $function) };
    (6, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 6 => x0, x1, x2, x3, x4, x5; $function)
    };
    (7, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 7 => x0, x1, x2, x3, x4, x5, x6; $function)
    };
    (8, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 8 => x0, x1, x2, x3, x4, x5, x6, x7; $function)
    };
    (9, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 9 => x0, x1, x2, x3, x4, x5, x6, x7, x8; $function)
    };
    (10, $function:expr) => {
        $crate::wrap_fn!(_ctx, @arg 10 => x0, x1, x2, x3, x4, x5, x6, x7, x8, x9; $function)
    };
}
/// Returns `function` unchanged.
///
/// `wrap_fn!` passes its closure through this function so that the closure gets the
/// higher-ranked signature required here. This lets the compiler infer the types of the
/// closure's `args` and `context` parameters.
#[doc(hidden)]
pub fn enforce_closure_type<T, A, F>(function: F) -> F
where
    F: for<'a> Fn(Vec<SpannedValue<'a, T>>, &mut CallContext<'_, 'a, A>) -> EvalResult<'a, T>,
{
    function
}
// End-to-end tests for wrapped native functions. Each test parses a small script with
// `F32Grammar`, imports the wrapped functions into a module, and runs it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{format, ToOwned},
        Environment, ExecutableModule, Prelude, WildcardId,
    };
    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;
    use core::f32;

    // `Unary` / `Binary` / `Ternary` over `f32`, created both from closures and from a
    // function item (`f32::min`).
    #[test]
    fn functions_with_primitive_args() {
        let unary_fn = Unary::new(|x: f32| x + 3.0);
        let binary_fn = Binary::new(f32::min);
        let ternary_fn = Ternary::new(|x: f32, y, z| if x > 0.0 { y } else { z });
        let program = r#"
unary_fn(2) == 5 && binary_fn(1, -3) == -3 &&
ternary_fn(1, 2, 3) == 2 && ternary_fn(-1, 2, 3) == 3
"#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("unary_fn", Value::native_fn(unary_fn))
            .with_import("binary_fn", Value::native_fn(binary_fn))
            .with_import("ternary_fn", Value::native_fn(ternary_fn))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // Returns `(min, max)` of `values`; an empty vector yields `(INFINITY, NEG_INFINITY)`.
    fn array_min_max(values: Vec<f32>) -> (f32, f32) {
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        for value in values {
            if value < min {
                min = value;
            }
            if value > max {
                max = value;
            }
        }
        (min, max)
    }

    // Sums every number in nested composite arguments (a `Vec` of pairs, and a tuple of
    // a `Vec` and a number).
    fn overly_convoluted_fn(xs: Vec<(f32, f32)>, ys: (Vec<f32>, f32)) -> f32 {
        xs.into_iter().map(|(a, b)| a + b).sum::<f32>() + ys.0.into_iter().sum::<f32>() + ys.1
    }

    // `Vec` / tuple conversions in arguments, and a tuple return value.
    #[test]
    fn functions_with_composite_args() {
        let program = r#"
(1, 5, -3, 2, 1).array_min_max() == (-3, 5) &&
total_sum(((1, 2), (3, 4)), ((5, 6, 7), 8)) == 36
"#;
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("array_min_max", Value::wrapped_fn(array_min_max))
            .with_import("total_sum", Value::wrapped_fn(overly_convoluted_fn))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // Element-wise sum; returns an error message when the lengths differ.
    fn sum_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<f32>, String> {
        if xs.len() == ys.len() {
            Ok(xs.into_iter().zip(ys).map(|(x, y)| x + y).collect())
        } else {
            Err("Summed arrays must have the same size".to_owned())
        }
    }

    // `Ok` values of a `Result<_, String>` return type are unwrapped into script values.
    #[test]
    fn fallible_function() {
        let program = "(1, 2, 3).sum_arrays((4, 5, 6)) == (5, 7, 9)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("sum_arrays", Value::wrapped_fn(sum_arrays))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // An `Err(String)` from the native function surfaces in the evaluation error message.
    #[test]
    fn fallible_function_with_bogus_program() {
        let program = "(1, 2, 3).sum_arrays((4, 5))";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let err = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("sum_arrays", Value::wrapped_fn(sum_arrays))
            .build()
            .run()
            .unwrap_err();
        assert!(err
            .source()
            .kind()
            .to_short_string()
            .contains("Summed arrays must have the same size"));
    }

    // `bool` return value, with a destructured tuple argument via `wrap`.
    #[test]
    fn function_with_bool_return_value() {
        let contains = wrap(|(a, b): (f32, f32), x: f32| (a..=b).contains(&x));
        let program = "(-1, 2).contains(0) && !(1, 3).contains(0)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_import("contains", Value::native_fn(contains))
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // `Ok(())` maps to the void value, and `Err(String)` to an `ErrorKind::NativeCall`.
    #[test]
    fn function_with_void_return_value() {
        let mut env = Environment::new();
        env.insert_wrapped_fn("assert_eq", |expected: f32, actual: f32| {
            if (expected - actual).abs() < f32::EPSILON {
                Ok(())
            } else {
                Err(format!(
                    "Assertion failed: expected {}, got {}",
                    expected, actual
                ))
            }
        });
        let program = "assert_eq(3, 1 + 2)";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&env)
            .build();
        assert!(module.run().unwrap().is_void());
        let bogus_program = "assert_eq(3, 1 - 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program).unwrap();
        let err = ExecutableModule::builder(WildcardId, &bogus_block)
            .unwrap()
            .with_imports_from(&env)
            .build()
            .run()
            .unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("Assertion failed")
        );
    }

    // `bool` arguments are converted from script booleans.
    #[test]
    fn function_with_bool_argument() {
        let program = "flip_sign(-1, true) == 1 && flip_sign(-1, false) == -1";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&Prelude)
            .with_import(
                "flip_sign",
                Value::wrapped_fn(|val: f32, flag: bool| if flag { -val } else { val }),
            )
            .build();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    // A conversion failure deep inside an argument reports the full path to the failing
    // element: `arg0[1].0` is element 0 of the tuple at index 1 of argument 0.
    #[test]
    fn error_reporting_with_destructuring() {
        let program = "((true, 1), (2, 3)).destructure()";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let err = ExecutableModule::builder(WildcardId, &block)
            .unwrap()
            .with_imports_from(&Prelude)
            .with_import(
                "destructure",
                Value::wrapped_fn(|values: Vec<(bool, f32)>| {
                    values
                        .into_iter()
                        .map(|(flag, x)| if flag { x } else { 0.0 })
                        .sum::<f32>()
                }),
            )
            .build()
            .run()
            .unwrap_err();
        let err_message = err.source().kind().to_short_string();
        assert!(err_message.contains("Cannot convert number to bool"));
        assert!(err_message.contains("location: arg0[1].0"));
    }
}