arithmetic_eval/fns/assertions.rs
1//! Assertion functions.
2
3use core::fmt;
4
5use crate::{
6 alloc::{format, Vec},
7 error::AuxErrorInfo,
8 CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
9};
10use arithmetic_parser::CodeFragment;
11
/// Native function checking that a boolean condition holds.
///
/// The function accepts exactly one argument:
///
/// - If the argument is `true`, a void value is returned.
/// - If the argument is `false`, evaluation fails with a native call error raised
///   at the call site. The error message quotes the source code of the condition
///   when that code is available.
/// - If the argument is not a boolean, evaluation fails with a native call error,
///   and the offending argument is marked as invalid.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// (Bool) -> ()
/// ```
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, ErrorKind, VariableMap};
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     assert(1 + 2 != 5); // this assertion is fine
///     assert(3^2 > 10); // this one will fail
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = Environment::new()
///     .insert_native_fn("assert", fns::Assert)
///     .compile_module("test_assert", &program)?;
///
/// let err = module.run().unwrap_err();
/// assert_eq!(*err.source().main_span().code().fragment(), "assert(3^2 > 10)");
/// assert_matches!(
///     err.source().kind(),
///     ErrorKind::NativeCall(ref msg) if msg == "Assertion failed: 3^2 > 10"
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct Assert;
49
50impl<T> NativeFn<T> for Assert {
51 fn evaluate<'a>(
52 &self,
53 args: Vec<SpannedValue<'a, T>>,
54 ctx: &mut CallContext<'_, 'a, T>,
55 ) -> EvalResult<'a, T> {
56 ctx.check_args_count(&args, 1)?;
57 match args[0].extra {
58 Value::Bool(true) => Ok(Value::void()),
59
60 Value::Bool(false) => {
61 let err = if let CodeFragment::Str(code) = args[0].fragment() {
62 ErrorKind::native(format!("Assertion failed: {}", code))
63 } else {
64 ErrorKind::native("Assertion failed")
65 };
66 Err(ctx.call_site_error(err))
67 }
68
69 _ => {
70 let err = ErrorKind::native("`assert` requires a single boolean argument");
71 Err(ctx
72 .call_site_error(err)
73 .with_span(&args[0], AuxErrorInfo::InvalidArg))
74 }
75 }
76 }
77}
78
/// Native function checking that two values are equal.
///
/// The function accepts exactly two arguments and compares them using the arithmetic
/// of the call context:
///
/// - If the values are equal, a void value is returned.
/// - Otherwise, evaluation fails with a native call error raised at the call site.
///   The error message quotes the source code of both arguments when that code
///   is available, and the values of both arguments are attached to the error
///   as auxiliary information.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// ('T, 'T) -> ()
/// ```
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, ErrorKind, VariableMap};
/// # use assert_matches::assert_matches;
/// # fn main() -> anyhow::Result<()> {
/// let program = r#"
///     assert_eq(1 + 2, 3); // this assertion is fine
///     assert_eq(3^2, 10); // this one will fail
/// "#;
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = Environment::new()
///     .insert_native_fn("assert_eq", fns::AssertEq)
///     .compile_module("test_assert", &program)?;
///
/// let err = module.run().unwrap_err();
/// assert_eq!(*err.source().main_span().code().fragment(), "assert_eq(3^2, 10)");
/// assert_matches!(
///     err.source().kind(),
///     ErrorKind::NativeCall(ref msg) if msg == "Assertion failed: 3^2 == 10"
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct AssertEq;
116
117impl<T: fmt::Display> NativeFn<T> for AssertEq {
118 fn evaluate<'a>(
119 &self,
120 args: Vec<SpannedValue<'a, T>>,
121 ctx: &mut CallContext<'_, 'a, T>,
122 ) -> EvalResult<'a, T> {
123 ctx.check_args_count(&args, 2)?;
124
125 let is_equal = args[0]
126 .extra
127 .eq_by_arithmetic(&args[1].extra, ctx.arithmetic());
128
129 if is_equal {
130 Ok(Value::void())
131 } else {
132 let err = if let (CodeFragment::Str(lhs), CodeFragment::Str(rhs)) =
133 (args[0].fragment(), args[1].fragment())
134 {
135 ErrorKind::native(format!("Assertion failed: {} == {}", lhs, rhs))
136 } else {
137 ErrorKind::native("Equality assertion failed")
138 };
139 Err(ctx
140 .call_site_error(err)
141 .with_span(&args[0], AuxErrorInfo::arg_value(&args[0].extra))
142 .with_span(&args[1], AuxErrorInfo::arg_value(&args[1].extra)))
143 }
144 }
145}