arithmetic_eval/fns/
assertions.rs

1//! Assertion functions.
2
3use core::fmt;
4
5use crate::{
6    alloc::{format, Vec},
7    error::AuxErrorInfo,
8    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
9};
10use arithmetic_parser::CodeFragment;
11
12/// Assertion function.
13///
14/// # Type
15///
16/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
17///
18/// ```text
19/// (Bool) -> ()
20/// ```
21///
22/// # Examples
23///
24/// ```
25/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
26/// # use arithmetic_eval::{fns, Environment, ErrorKind, VariableMap};
27/// # use assert_matches::assert_matches;
28/// # fn main() -> anyhow::Result<()> {
29/// let program = r#"
30///     assert(1 + 2 != 5); // this assertion is fine
31///     assert(3^2 > 10); // this one will fail
32/// "#;
33/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
34/// let module = Environment::new()
35///     .insert_native_fn("assert", fns::Assert)
36///     .compile_module("test_assert", &program)?;
37///
38/// let err = module.run().unwrap_err();
39/// assert_eq!(*err.source().main_span().code().fragment(), "assert(3^2 > 10)");
40/// assert_matches!(
41///     err.source().kind(),
42///     ErrorKind::NativeCall(ref msg) if msg == "Assertion failed: 3^2 > 10"
43/// );
44/// # Ok(())
45/// # }
46/// ```
47#[derive(Debug, Clone, Copy, Default)]
48pub struct Assert;
49
50impl<T> NativeFn<T> for Assert {
51    fn evaluate<'a>(
52        &self,
53        args: Vec<SpannedValue<'a, T>>,
54        ctx: &mut CallContext<'_, 'a, T>,
55    ) -> EvalResult<'a, T> {
56        ctx.check_args_count(&args, 1)?;
57        match args[0].extra {
58            Value::Bool(true) => Ok(Value::void()),
59
60            Value::Bool(false) => {
61                let err = if let CodeFragment::Str(code) = args[0].fragment() {
62                    ErrorKind::native(format!("Assertion failed: {}", code))
63                } else {
64                    ErrorKind::native("Assertion failed")
65                };
66                Err(ctx.call_site_error(err))
67            }
68
69            _ => {
70                let err = ErrorKind::native("`assert` requires a single boolean argument");
71                Err(ctx
72                    .call_site_error(err)
73                    .with_span(&args[0], AuxErrorInfo::InvalidArg))
74            }
75        }
76    }
77}
78
79/// Equality assertion function.
80///
81/// # Type
82///
83/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
84///
85/// ```text
86/// ('T, 'T) -> ()
87/// ```
88///
89/// # Examples
90///
91/// ```
92/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
93/// # use arithmetic_eval::{fns, Environment, ErrorKind, VariableMap};
94/// # use assert_matches::assert_matches;
95/// # fn main() -> anyhow::Result<()> {
96/// let program = r#"
97///     assert_eq(1 + 2, 3); // this assertion is fine
98///     assert_eq(3^2, 10); // this one will fail
99/// "#;
100/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
101/// let module = Environment::new()
102///     .insert_native_fn("assert_eq", fns::AssertEq)
103///     .compile_module("test_assert", &program)?;
104///
105/// let err = module.run().unwrap_err();
106/// assert_eq!(*err.source().main_span().code().fragment(), "assert_eq(3^2, 10)");
107/// assert_matches!(
108///     err.source().kind(),
109///     ErrorKind::NativeCall(ref msg) if msg == "Assertion failed: 3^2 == 10"
110/// );
111/// # Ok(())
112/// # }
113/// ```
114#[derive(Debug, Clone, Copy, Default)]
115pub struct AssertEq;
116
117impl<T: fmt::Display> NativeFn<T> for AssertEq {
118    fn evaluate<'a>(
119        &self,
120        args: Vec<SpannedValue<'a, T>>,
121        ctx: &mut CallContext<'_, 'a, T>,
122    ) -> EvalResult<'a, T> {
123        ctx.check_args_count(&args, 2)?;
124
125        let is_equal = args[0]
126            .extra
127            .eq_by_arithmetic(&args[1].extra, ctx.arithmetic());
128
129        if is_equal {
130            Ok(Value::void())
131        } else {
132            let err = if let (CodeFragment::Str(lhs), CodeFragment::Str(rhs)) =
133                (args[0].fragment(), args[1].fragment())
134            {
135                ErrorKind::native(format!("Assertion failed: {} == {}", lhs, rhs))
136            } else {
137                ErrorKind::native("Equality assertion failed")
138            };
139            Err(ctx
140                .call_site_error(err)
141                .with_span(&args[0], AuxErrorInfo::arg_value(&args[0].extra))
142                .with_span(&args[1], AuxErrorInfo::arg_value(&args[1].extra)))
143        }
144    }
145}