mathlex_eval/compiler/
compile.rs1use std::collections::HashMap;
2
3use mathlex::Expression;
4
5use crate::compiler::fold;
6use crate::compiler::ir::CompiledExpr;
7use crate::compiler::validate;
8use crate::error::CompileError;
9use crate::eval::numeric::NumericResult;
10
11pub fn compile(
23 ast: &Expression,
24 constants: &HashMap<&str, NumericResult>,
25) -> Result<CompiledExpr, CompileError> {
26 validate::validate(ast)?;
27 fold::fold(ast, constants)
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use mathlex::{BinaryOp, MathConstant, UnaryOp};

    use crate::compiler::ir::CompiledNode;

    /// Builds an integer-literal leaf.
    fn lit(value: i64) -> Expression {
        Expression::Integer(value)
    }

    /// Builds a variable leaf with the given name.
    fn sym(name: &str) -> Expression {
        Expression::Variable(name.into())
    }

    /// Builds a binary node `lhs <op> rhs`.
    fn binary(op: BinaryOp, lhs: Expression, rhs: Expression) -> Expression {
        Expression::Binary {
            op,
            left: Box::new(lhs),
            right: Box::new(rhs),
        }
    }

    /// Compiles `ast` with an empty constants table.
    fn compile_bare(ast: &Expression) -> Result<CompiledExpr, CompileError> {
        compile(ast, &HashMap::new())
    }

    /// Asserts that compiling `ast` fails with `UnsupportedExpression`.
    fn assert_unsupported(ast: &Expression) {
        let err = compile_bare(ast).unwrap_err();
        assert!(matches!(err, CompileError::UnsupportedExpression { .. }));
    }

    /// `x + 1` leaves `x` as the only argument and is not complex.
    #[test]
    fn compile_simple_expression() {
        let ast = binary(BinaryOp::Add, sym("x"), lit(1));
        let compiled = compile_bare(&ast).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
        assert!(!compiled.is_complex());
    }

    /// `a * x` with `a` bound as a constant leaves only `x` as an argument.
    #[test]
    fn compile_with_constants() {
        let ast = binary(BinaryOp::Mul, sym("a"), sym("x"));
        let constants: HashMap<&str, NumericResult> =
            std::iter::once(("a", NumericResult::Real(2.0))).collect();
        let compiled = compile(&ast, &constants).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
    }

    /// `2 * pi` has no free variables, so the root must be a single literal.
    #[test]
    fn compile_pure_constant_folds() {
        let ast = binary(
            BinaryOp::Mul,
            lit(2),
            Expression::Constant(MathConstant::Pi),
        );
        let compiled = compile_bare(&ast).unwrap();
        match compiled.root {
            CompiledNode::Literal(v) => {
                assert_abs_diff_eq!(v, 2.0 * std::f64::consts::PI, epsilon = 1e-15);
            }
            _ => panic!("expected folded literal"),
        }
    }

    /// Vector expressions are reported as unsupported.
    #[test]
    fn compile_rejects_vector() {
        assert_unsupported(&Expression::Vector(vec![lit(1)]));
    }

    /// Derivative expressions are reported as unsupported.
    #[test]
    fn compile_rejects_derivative() {
        let ast = Expression::Derivative {
            expr: Box::new(sym("x")),
            var: "x".into(),
            order: 1,
        };
        assert_unsupported(&ast);
    }

    /// The imaginary unit marks the compiled expression as complex.
    #[test]
    fn compile_complex_constant_sets_flag() {
        let imaginary_unit = Expression::Constant(MathConstant::I);
        let compiled = compile_bare(&imaginary_unit).unwrap();
        assert!(compiled.is_complex());
    }

    /// `5!` folds to the literal 120.
    #[test]
    fn compile_factorial() {
        let ast = Expression::Unary {
            op: UnaryOp::Factorial,
            operand: Box::new(lit(5)),
        };
        let compiled = compile_bare(&ast).unwrap();
        match compiled.root {
            CompiledNode::Literal(v) => {
                assert_abs_diff_eq!(v, 120.0, epsilon = 1e-10);
            }
            _ => panic!("expected folded literal"),
        }
    }

    /// The summation index `k` is bound by the sum, so it must not show up
    /// as an argument of the compiled expression.
    #[test]
    fn compile_sum() {
        let ast = Expression::Sum {
            index: "k".into(),
            lower: Box::new(lit(1)),
            upper: Box::new(lit(10)),
            body: Box::new(sym("k")),
        };
        let compiled = compile_bare(&ast).unwrap();
        assert!(compiled.argument_names().is_empty());
    }
}