// mathlex_eval/compiler/compile.rs

1use std::collections::HashMap;
2
3use mathlex::Expression;
4
5use crate::compiler::fold;
6use crate::compiler::ir::CompiledExpr;
7use crate::compiler::validate;
8use crate::error::CompileError;
9use crate::eval::numeric::NumericResult;
10
11/// Compile a mathlex AST into a [`CompiledExpr`] ready for evaluation.
12///
13/// Takes a reference to the AST and a map of constant names to values.
14/// Constants are substituted at compile time; remaining free variables
15/// become arguments that must be provided at eval time.
16///
17/// # Errors
18///
19/// Returns [`CompileError`] if the AST contains unsupported expression
20/// variants, unknown functions, arity mismatches, unresolvable bounds,
21/// or division by zero during constant folding.
22pub fn compile(
23    ast: &Expression,
24    constants: &HashMap<&str, NumericResult>,
25) -> Result<CompiledExpr, CompileError> {
26    validate::validate(ast)?;
27    fold::fold(ast, constants)
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use mathlex::{BinaryOp, MathConstant, UnaryOp};

    use crate::compiler::ir::CompiledNode;

    /// Integer literal node.
    fn int(v: i64) -> Expression {
        Expression::Integer(v)
    }

    /// Variable reference node.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.into())
    }

    /// Binary node `lhs <op> rhs`.
    fn binary(op: BinaryOp, lhs: Expression, rhs: Expression) -> Expression {
        Expression::Binary {
            op,
            left: Box::new(lhs),
            right: Box::new(rhs),
        }
    }

    /// Compile `ast` with an empty constant table.
    fn compile_bare(ast: &Expression) -> Result<CompiledExpr, CompileError> {
        compile(ast, &HashMap::new())
    }

    #[test]
    fn compile_simple_expression() {
        // x + 1
        let compiled = compile_bare(&binary(BinaryOp::Add, var("x"), int(1))).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
        assert!(!compiled.is_complex());
    }

    #[test]
    fn compile_with_constants() {
        // a * x with a = 2.0: only x is left as an argument
        let ast = binary(BinaryOp::Mul, var("a"), var("x"));
        let constants = HashMap::from([("a", NumericResult::Real(2.0))]);
        let compiled = compile(&ast, &constants).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
    }

    #[test]
    fn compile_pure_constant_folds() {
        // 2 * pi collapses to a single literal
        let ast = binary(
            BinaryOp::Mul,
            int(2),
            Expression::Constant(MathConstant::Pi),
        );
        match compile_bare(&ast).unwrap().root {
            CompiledNode::Literal(v) => {
                assert_abs_diff_eq!(v, 2.0 * std::f64::consts::PI, epsilon = 1e-15);
            }
            _ => panic!("expected folded literal"),
        }
    }

    #[test]
    fn compile_rejects_vector() {
        let err = compile_bare(&Expression::Vector(vec![int(1)])).unwrap_err();
        assert!(matches!(err, CompileError::UnsupportedExpression { .. }));
    }

    #[test]
    fn compile_rejects_derivative() {
        let ast = Expression::Derivative {
            expr: Box::new(var("x")),
            var: "x".into(),
            order: 1,
        };
        let err = compile_bare(&ast).unwrap_err();
        assert!(matches!(err, CompileError::UnsupportedExpression { .. }));
    }

    #[test]
    fn compile_complex_constant_sets_flag() {
        let compiled = compile_bare(&Expression::Constant(MathConstant::I)).unwrap();
        assert!(compiled.is_complex());
    }

    #[test]
    fn compile_factorial() {
        // 5! folds to 120
        let ast = Expression::Unary {
            op: UnaryOp::Factorial,
            operand: Box::new(int(5)),
        };
        match compile_bare(&ast).unwrap().root {
            CompiledNode::Literal(v) => assert_abs_diff_eq!(v, 120.0, epsilon = 1e-10),
            _ => panic!("expected folded literal"),
        }
    }

    #[test]
    fn compile_sum() {
        // sum over k = 1..=10 of k: k is the bound index, so no free arguments remain
        let ast = Expression::Sum {
            index: "k".into(),
            lower: Box::new(int(1)),
            upper: Box::new(int(10)),
            body: Box::new(var("k")),
        };
        let compiled = compile_bare(&ast).unwrap();
        assert!(compiled.argument_names().is_empty());
    }
}