mathlex_eval/compiler/
compile.rs1use std::collections::HashMap;
2
3use mathlex::Expression;
4
5use crate::compiler::fold;
6use crate::compiler::ir::CompiledExpr;
7use crate::compiler::validate;
8use crate::error::CompileError;
9use crate::eval::numeric::NumericResult;
10
11pub fn compile(
23 ast: &Expression,
24 constants: &HashMap<&str, NumericResult>,
25) -> Result<CompiledExpr, CompileError> {
26 validate::validate(ast)?;
27 fold::fold(ast, constants)
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use mathlex::{BinaryOp, ExprKind, MathConstant, UnaryOp};

    use crate::compiler::ir::CompiledNode;

    /// Integer literal node.
    fn int(v: i64) -> Expression {
        Expression::integer(v)
    }

    /// Named variable node.
    fn var(name: &str) -> Expression {
        Expression::variable(name)
    }

    /// Builds the binary node `lhs <op> rhs`.
    fn binary(op: BinaryOp, lhs: Expression, rhs: Expression) -> Expression {
        ExprKind::Binary {
            op,
            left: Box::new(lhs),
            right: Box::new(rhs),
        }
        .into()
    }

    /// Compiles `ast` with an empty constant table, panicking on any error.
    fn compile_plain(ast: &Expression) -> CompiledExpr {
        compile(ast, &HashMap::new()).unwrap()
    }

    /// Returns the value of the root node, which must have been folded to a literal.
    fn root_literal(compiled: &CompiledExpr) -> f64 {
        match compiled.root {
            CompiledNode::Literal(value) => value,
            _ => panic!("expected folded literal"),
        }
    }

    /// Asserts that compiling `ast` fails with `UnsupportedExpression`.
    fn assert_unsupported(ast: &Expression) {
        let failure = compile(ast, &HashMap::new()).unwrap_err();
        assert!(matches!(failure, CompileError::UnsupportedExpression { .. }));
    }

    #[test]
    fn compile_simple_expression() {
        let compiled = compile_plain(&binary(BinaryOp::Add, var("x"), int(1)));
        assert_eq!(compiled.argument_names(), &["x"]);
        assert!(!compiled.is_complex());
    }

    #[test]
    fn compile_with_constants() {
        let mut table: HashMap<&str, NumericResult> = HashMap::new();
        table.insert("a", NumericResult::Real(2.0));

        let product = binary(BinaryOp::Mul, var("a"), var("x"));
        let compiled = compile(&product, &table).unwrap();
        assert_eq!(compiled.argument_names(), &["x"]);
    }

    #[test]
    fn compile_pure_constant_folds() {
        let two_pi = binary(
            BinaryOp::Mul,
            int(2),
            Expression::constant(MathConstant::Pi),
        );
        let folded = root_literal(&compile_plain(&two_pi));
        assert_abs_diff_eq!(folded, 2.0 * std::f64::consts::PI, epsilon = 1e-15);
    }

    #[test]
    fn compile_rejects_vector() {
        assert_unsupported(&Expression::vector(vec![int(1)]));
    }

    #[test]
    fn compile_rejects_derivative() {
        let derivative: Expression = ExprKind::Derivative {
            expr: Box::new(var("x")),
            var: "x".into(),
            order: 1,
        }
        .into();
        assert_unsupported(&derivative);
    }

    #[test]
    fn compile_complex_constant_sets_flag() {
        let compiled = compile_plain(&Expression::constant(MathConstant::I));
        assert!(compiled.is_complex());
    }

    #[test]
    fn compile_factorial() {
        let five_factorial: Expression = ExprKind::Unary {
            op: UnaryOp::Factorial,
            operand: Box::new(int(5)),
        }
        .into();
        let folded = root_literal(&compile_plain(&five_factorial));
        assert_abs_diff_eq!(folded, 120.0, epsilon = 1e-10);
    }

    #[test]
    fn compile_sum() {
        let series: Expression = ExprKind::Sum {
            index: "k".into(),
            lower: Box::new(int(1)),
            upper: Box::new(int(10)),
            body: Box::new(var("k")),
        }
        .into();
        // The summation index is bound by the sum itself, so no free arguments remain.
        assert!(compile_plain(&series).argument_names().is_empty());
    }
}