use core::cmp::Ordering;
use crate::{
alloc::Vec, error::AuxErrorInfo, CallContext, Error, ErrorKind, EvalResult, Function, NativeFn,
SpannedValue, Value,
};
mod array;
mod assertions;
mod flow;
#[cfg(feature = "std")]
mod std;
mod wrapper;
#[cfg(feature = "std")]
pub use self::std::Dbg;
pub use self::{
array::{Array, Filter, Fold, Len, Map, Merge, Push},
assertions::{Assert, AssertEq},
flow::{If, Loop, While},
wrapper::{
enforce_closure_type, wrap, Binary, ErrorOutput, FnWrapper, FromValueError,
FromValueErrorKind, FromValueErrorLocation, IntoEvalResult, Quaternary, Ternary,
TryFromValue, Unary,
},
};
fn extract_primitive<'a, T, A>(
ctx: &CallContext<'_, 'a, A>,
value: SpannedValue<'a, T>,
error_msg: &str,
) -> Result<T, Error<'a>> {
match value.extra {
Value::Prim(value) => Ok(value),
_ => Err(ctx
.call_site_error(ErrorKind::native(error_msg))
.with_span(&value, AuxErrorInfo::InvalidArg)),
}
}
fn extract_array<'a, T, A>(
ctx: &CallContext<'_, 'a, A>,
value: SpannedValue<'a, T>,
error_msg: &str,
) -> Result<Vec<Value<'a, T>>, Error<'a>> {
if let Value::Tuple(array) = value.extra {
Ok(array)
} else {
let err = ErrorKind::native(error_msg);
Err(ctx
.call_site_error(err)
.with_span(&value, AuxErrorInfo::InvalidArg))
}
}
fn extract_fn<'a, T, A>(
ctx: &CallContext<'_, 'a, A>,
value: SpannedValue<'a, T>,
error_msg: &str,
) -> Result<Function<'a, T>, Error<'a>> {
if let Value::Function(function) = value.extra {
Ok(function)
} else {
let err = ErrorKind::native(error_msg);
Err(ctx
.call_site_error(err)
.with_span(&value, AuxErrorInfo::InvalidArg))
}
}
/// Native comparison functions taking exactly 2 primitive arguments.
///
/// All variants compare the arguments using `partial_cmp` of the arithmetic provided
/// by the call context; they differ in what is returned (see the `NativeFn` implementation).
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Compare {
/// Returns the ordering of the two arguments wrapped with `Value::opaque_ref`,
/// or `Value::void()` if the arguments are not comparable.
Raw,
/// Returns the lesser of the two arguments; the first argument is returned if they
/// compare as equal. Fails with `ErrorKind::CannotCompare` if the arguments are not comparable.
Min,
/// Returns the greater of the two arguments; the first argument is returned if they
/// compare as equal. Fails with `ErrorKind::CannotCompare` if the arguments are not comparable.
Max,
}
impl Compare {
    /// Splits `args` into the pair of primitive operands to be compared.
    ///
    /// # Errors
    ///
    /// Fails if `args` does not contain exactly 2 values, or if either value is not
    /// a primitive. The first operand is checked before the second one.
    fn extract_primitives<'a, T>(
        mut args: Vec<SpannedValue<'a, T>>,
        ctx: &mut CallContext<'_, 'a, T>,
    ) -> Result<(T, T), Error<'a>> {
        ctx.check_args_count(&args, 2)?;

        // The arity check above guarantees both `pop`s succeed; arguments come off
        // the back of the vector, so the right-hand operand is popped first.
        let rhs = args.pop().unwrap();
        let lhs = args.pop().unwrap();

        Ok((
            extract_primitive(ctx, lhs, COMPARE_ERROR_MSG)?,
            extract_primitive(ctx, rhs, COMPARE_ERROR_MSG)?,
        ))
    }
}
/// Message of the native error raised by `Compare::extract_primitives` when an argument
/// is not a primitive value.
const COMPARE_ERROR_MSG: &str = "Compare requires 2 primitive arguments";
impl<T> NativeFn<T> for Compare {
    /// Compares two primitive arguments using the arithmetic from `ctx`.
    ///
    /// - `Compare::Raw` yields the ordering as an opaque reference, or a void value if
    ///   the arguments cannot be compared.
    /// - `Compare::Min` / `Compare::Max` yield the lesser / greater argument (the first one
    ///   on a tie), and fail with `ErrorKind::CannotCompare` if the arguments cannot be compared.
    fn evaluate<'a>(
        &self,
        args: Vec<SpannedValue<'a, T>>,
        ctx: &mut CallContext<'_, 'a, T>,
    ) -> EvalResult<'a, T> {
        let (lhs, rhs) = Self::extract_primitives(args, ctx)?;
        let comparison = ctx.arithmetic().partial_cmp(&lhs, &rhs);

        match self {
            Self::Raw => Ok(match comparison {
                Some(ordering) => Value::opaque_ref(ordering),
                None => Value::void(),
            }),

            Self::Min | Self::Max => {
                let ordering =
                    comparison.ok_or_else(|| ctx.call_site_error(ErrorKind::CannotCompare))?;
                // The second operand wins only when it is strictly "better" than the first;
                // ties and all remaining cases resolve to the first operand.
                let selected = match (self, ordering) {
                    (Self::Min, Ordering::Greater) | (Self::Max, Ordering::Less) => rhs,
                    _ => lhs,
                };
                Ok(Value::Prim(selected))
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Environment, ExecutableModule, WildcardId};
use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
use assert_matches::assert_matches;
#[test]
fn if_basic() {
let block = r#"
x = 1.0;
if(x < 2, x + 5, 3 - x)
"#;
let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("if", Value::native_fn(If))
.build();
assert_eq!(module.run().unwrap(), Value::Prim(6.0));
}
#[test]
fn if_with_closures() {
let block = r#"
x = 4.5;
if(x < 2, || x + 5, || 3 - x)()
"#;
let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("if", Value::native_fn(If))
.build();
assert_eq!(module.run().unwrap(), Value::Prim(-1.5));
}
#[test]
fn cmp_sugar() {
let program = "x = 1.0; x > 0 && x <= 3";
let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.build();
assert_eq!(module.run().unwrap(), Value::Bool(true));
let bogus_program = "x = 1.0; x > (1, 2)";
let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program).unwrap();
let bogus_module = ExecutableModule::builder(WildcardId, &bogus_block)
.unwrap()
.build();
let err = bogus_module.run().unwrap_err();
let err = err.source();
assert_matches!(err.kind(), ErrorKind::CannotCompare);
assert_eq!(*err.main_span().code().fragment(), "(1, 2)");
}
#[test]
fn loop_basic() {
let program = r#"
// Finds the greatest power of 2 lesser or equal to the value.
discrete_log2 = |x| {
loop(0, |i| {
continue = 2^i <= x;
(continue, if(continue, i + 1, i - 1))
})
};
(discrete_log2(1), discrete_log2(2),
discrete_log2(4), discrete_log2(6.5), discrete_log2(1000))
"#;
let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("loop", Value::native_fn(Loop))
.with_import("if", Value::native_fn(If))
.build();
assert_eq!(
module.run().unwrap(),
Value::Tuple(vec![
Value::Prim(0.0),
Value::Prim(1.0),
Value::Prim(2.0),
Value::Prim(2.0),
Value::Prim(9.0),
])
);
}
#[test]
fn max_value_with_fold() {
let program = r#"
max_value = |...xs| {
fold(xs, -Inf, |acc, x| if(x > acc, x, acc))
};
max_value(1, -2, 7, 2, 5) == 7 && max_value(3, -5, 9) == 9
"#;
let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("Inf", Value::Prim(f32::INFINITY))
.with_import("fold", Value::native_fn(Fold))
.with_import("if", Value::native_fn(If))
.build();
assert_eq!(module.run().unwrap(), Value::Bool(true));
}
#[test]
fn reverse_list_with_fold() {
const SAMPLES: &[(&[f32], &[f32])] = &[
(&[1.0, 2.0, 3.0], &[3.0, 2.0, 1.0]),
(&[], &[]),
(&[1.0], &[1.0]),
];
let program = r#"
reverse = |xs| {
fold(xs, (), |acc, x| merge((x,), acc))
};
xs = (-4, 3, 0, 1);
xs.reverse() == (1, 0, 3, -4)
"#;
let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("merge", Value::native_fn(Merge))
.with_import("fold", Value::native_fn(Fold))
.build();
let mut env = module.imports().into_iter().collect::<Environment<'_, _>>();
assert_eq!(module.run_in_env(&mut env).unwrap(), Value::Bool(true));
let test_block = Untyped::<F32Grammar>::parse_statements("xs.reverse()").unwrap();
let mut test_module = ExecutableModule::builder("test", &test_block)
.unwrap()
.with_import("reverse", env["reverse"].clone())
.set_imports(|_| Value::void());
for &(input, expected) in SAMPLES {
let input = input.iter().copied().map(Value::Prim).collect();
let expected = expected.iter().copied().map(Value::Prim).collect();
test_module.set_import("xs", Value::Tuple(input));
assert_eq!(test_module.run().unwrap(), Value::Tuple(expected));
}
}
#[test]
fn error_with_min_function_args() {
let program = "5 - min(1, (2, 3))";
let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("min", Value::native_fn(Compare::Min))
.build();
let err = module.run().unwrap_err();
let err = err.source();
assert_eq!(*err.main_span().code().fragment(), "min(1, (2, 3))");
assert_matches!(
err.kind(),
ErrorKind::NativeCall(ref msg) if msg.contains("requires 2 primitive arguments")
);
}
#[test]
fn error_with_min_function_incomparable_args() {
let program = "5 - min(1, NAN)";
let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
let module = ExecutableModule::builder(WildcardId, &block)
.unwrap()
.with_import("NAN", Value::Prim(f32::NAN))
.with_import("min", Value::native_fn(Compare::Min))
.build();
let err = module.run().unwrap_err();
let err = err.source();
assert_eq!(*err.main_span().code().fragment(), "min(1, NAN)");
assert_matches!(err.kind(), ErrorKind::CannotCompare);
}
}