Skip to main content

mathlex_eval/compiler/
compile.rs

1use std::collections::HashMap;
2
3use mathlex::Expression;
4
5use crate::compiler::fold;
6use crate::compiler::ir::CompiledExpr;
7use crate::compiler::validate;
8use crate::error::CompileError;
9use crate::eval::numeric::NumericResult;
10
11/// Compile a mathlex AST into a [`CompiledExpr`] ready for evaluation.
12///
13/// Takes a reference to the AST and a map of constant names to values.
14/// Constants are substituted at compile time; remaining free variables
15/// become arguments that must be provided at eval time.
16///
17/// # Errors
18///
19/// Returns [`CompileError`] if the AST contains unsupported expression
20/// variants, unknown functions, arity mismatches, unresolvable bounds,
21/// or division by zero during constant folding.
22pub fn compile(
23    ast: &Expression,
24    constants: &HashMap<&str, NumericResult>,
25) -> Result<CompiledExpr, CompileError> {
26    validate::validate(ast)?;
27    fold::fold(ast, constants)
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use mathlex::{BinaryOp, ExprKind, MathConstant, UnaryOp};

    use crate::compiler::ir::CompiledNode;

    /// Shorthand for an integer literal node.
    fn int(value: i64) -> Expression {
        Expression::integer(value)
    }

    /// Shorthand for a free-variable node.
    fn var(name: &str) -> Expression {
        Expression::variable(name)
    }

    /// Build the binary expression `left <op> right`.
    fn binary(op: BinaryOp, left: Expression, right: Expression) -> Expression {
        ExprKind::Binary {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
        .into()
    }

    /// Compile `ast` without supplying any constants.
    fn compile_bare(ast: &Expression) -> Result<CompiledExpr, CompileError> {
        compile(ast, &HashMap::new())
    }

    #[test]
    fn compile_simple_expression() {
        // x + 1
        let ast = binary(BinaryOp::Add, var("x"), int(1));
        let compiled = compile_bare(&ast).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
        assert!(!compiled.is_complex());
    }

    #[test]
    fn compile_with_constants() {
        // a * x, with a bound to 2.0 at compile time
        let ast = binary(BinaryOp::Mul, var("a"), var("x"));
        let constants = HashMap::from([("a", NumericResult::Real(2.0))]);
        let compiled = compile(&ast, &constants).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
    }

    #[test]
    fn compile_pure_constant_folds() {
        // 2 * pi collapses to a single literal
        let ast = binary(
            BinaryOp::Mul,
            int(2),
            Expression::constant(MathConstant::Pi),
        );
        match compile_bare(&ast).unwrap().root {
            CompiledNode::Literal(v) => {
                assert_abs_diff_eq!(v, 2.0 * std::f64::consts::PI, epsilon = 1e-15);
            }
            _ => panic!("expected folded literal"),
        }
    }

    #[test]
    fn compile_rejects_vector() {
        let ast = Expression::vector(vec![int(1)]);
        let err = compile_bare(&ast).unwrap_err();
        assert!(matches!(err, CompileError::UnsupportedExpression { .. }));
    }

    #[test]
    fn compile_rejects_derivative() {
        // d/dx x is not a supported construct
        let ast: Expression = ExprKind::Derivative {
            expr: Box::new(var("x")),
            var: "x".into(),
            order: 1,
        }
        .into();
        let err = compile_bare(&ast).unwrap_err();
        assert!(matches!(err, CompileError::UnsupportedExpression { .. }));
    }

    #[test]
    fn compile_complex_constant_sets_flag() {
        let compiled = compile_bare(&Expression::constant(MathConstant::I)).unwrap();
        assert!(compiled.is_complex());
    }

    #[test]
    fn compile_factorial() {
        // 5! is expected to fold to 120
        let ast: Expression = ExprKind::Unary {
            op: UnaryOp::Factorial,
            operand: Box::new(int(5)),
        }
        .into();
        match compile_bare(&ast).unwrap().root {
            CompiledNode::Literal(v) => {
                assert_abs_diff_eq!(v, 120.0, epsilon = 1e-10);
            }
            _ => panic!("expected folded literal"),
        }
    }

    #[test]
    fn compile_sum() {
        // sum over k = 1..=10 of k; the index k must not surface as an argument
        let ast: Expression = ExprKind::Sum {
            index: "k".into(),
            lower: Box::new(int(1)),
            upper: Box::new(int(10)),
            body: Box::new(var("k")),
        }
        .into();
        let compiled = compile_bare(&ast).unwrap();
        assert!(compiled.argument_names().is_empty());
    }
}