// Source file: mathlex_eval/compiler/ir.rs

1/// Compiled expression ready for repeated evaluation.
2///
3/// Opaque to callers — only inspectable via [`argument_names`](Self::argument_names)
4/// and [`is_complex`](Self::is_complex).
5#[derive(Debug, Clone)]
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7pub struct CompiledExpr {
8    pub(crate) root: CompiledNode,
9    pub(crate) argument_names: Vec<String>,
10    pub(crate) is_complex: bool,
11}
12
13impl CompiledExpr {
14    /// Names of free variables (arguments) in declaration order.
15    pub fn argument_names(&self) -> &[String] {
16        &self.argument_names
17    }
18
19    /// Whether the expression contains complex literals or imaginary unit.
20    pub fn is_complex(&self) -> bool {
21        self.is_complex
22    }
23}
24
25/// Internal IR node — not exposed in public API.
26#[derive(Debug, Clone)]
27#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
28pub(crate) enum CompiledNode {
29    /// Real literal (constant-folded)
30    Literal(f64),
31    /// Complex literal (constant-folded)
32    ComplexLiteral { re: f64, im: f64 },
33    /// Free variable, index into argument values by declaration order
34    Argument(usize),
35    /// Bound variable from sum/product loop, separate index space from arguments
36    Index(usize),
37
38    /// Binary arithmetic operation
39    Binary {
40        op: BinaryOp,
41        left: Box<CompiledNode>,
42        right: Box<CompiledNode>,
43    },
44    /// Unary operation
45    Unary {
46        op: UnaryOp,
47        operand: Box<CompiledNode>,
48    },
49
50    /// Built-in math function call
51    Function {
52        kind: BuiltinFn,
53        args: Vec<CompiledNode>,
54    },
55
56    /// Finite summation: Σ_{index=lower}^{upper} body
57    Sum {
58        index: usize,
59        lower: i64,
60        upper: i64,
61        body: Box<CompiledNode>,
62    },
63    /// Finite product: Π_{index=lower}^{upper} body
64    Product {
65        index: usize,
66        lower: i64,
67        upper: i64,
68        body: Box<CompiledNode>,
69    },
70}
71
/// Binary arithmetic operators supported by the evaluator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Exponentiation.
    Pow,
    /// Modulo. NOTE(review): truncated-remainder vs. Euclidean semantics are
    /// decided by the evaluator — confirm there.
    Mod,
}
83
/// Unary operators supported by the evaluator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Factorial.
    Factorial,
}
91
/// Built-in math functions recognized by the compiler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum BuiltinFn {
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Inverse sine.
    Asin,
    /// Inverse cosine.
    Acos,
    /// Inverse tangent.
    Atan,
    /// Two-argument inverse tangent. NOTE(review): argument order is fixed by
    /// the compiler/evaluator — confirm there.
    Atan2,
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential function.
    Exp,
    /// Natural logarithm.
    Ln,
    /// Base-2 logarithm.
    Log2,
    /// Base-10 logarithm.
    Log10,
    /// log(base, value)
    Log,
    /// Square root.
    Sqrt,
    /// Cube root.
    Cbrt,
    /// Absolute value.
    Abs,
    /// Floor (round toward negative infinity).
    Floor,
    /// Ceiling (round toward positive infinity).
    Ceil,
    /// Round to an integer. NOTE(review): tie-breaking rule is decided by the
    /// evaluator — confirm there.
    Round,
    /// Minimum of the arguments.
    Min,
    /// Maximum of the arguments.
    Max,
}
121
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Fails if any two entries of `items` compare equal, i.e. if a variant was
    /// listed twice (which could mask a missing one).
    fn assert_all_distinct<T: PartialEq + std::fmt::Debug>(items: &[T]) {
        for (i, a) in items.iter().enumerate() {
            for b in &items[i + 1..] {
                assert_ne!(a, b, "variant listed more than once");
            }
        }
    }

    /// Names the variant of `node`. The match deliberately has no wildcard arm,
    /// so adding a `CompiledNode` variant breaks compilation of these tests
    /// until the new variant is covered.
    fn node_variant(node: &CompiledNode) -> &'static str {
        match node {
            CompiledNode::Literal(_) => "Literal",
            CompiledNode::ComplexLiteral { .. } => "ComplexLiteral",
            CompiledNode::Argument(_) => "Argument",
            CompiledNode::Index(_) => "Index",
            CompiledNode::Binary { .. } => "Binary",
            CompiledNode::Unary { .. } => "Unary",
            CompiledNode::Function { .. } => "Function",
            CompiledNode::Sum { .. } => "Sum",
            CompiledNode::Product { .. } => "Product",
        }
    }

    #[test]
    fn compiled_expr_argument_names() {
        let expr = CompiledExpr {
            root: CompiledNode::Literal(1.0),
            argument_names: vec!["x".into(), "y".into()],
            is_complex: false,
        };
        assert_eq!(expr.argument_names(), &["x", "y"]);
        assert!(!expr.is_complex());
    }

    #[test]
    fn compiled_expr_is_complex() {
        let expr = CompiledExpr {
            root: CompiledNode::ComplexLiteral { re: 0.0, im: 1.0 },
            argument_names: vec![],
            is_complex: true,
        };
        assert!(expr.is_complex());
        assert!(expr.argument_names().is_empty());
    }

    #[test]
    fn compiled_node_variants_constructible() {
        let nodes = [
            CompiledNode::Literal(1.0),
            CompiledNode::ComplexLiteral { re: 1.0, im: 2.0 },
            CompiledNode::Argument(0),
            CompiledNode::Index(0),
            CompiledNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(CompiledNode::Literal(1.0)),
                right: Box::new(CompiledNode::Literal(2.0)),
            },
            CompiledNode::Unary {
                op: UnaryOp::Neg,
                operand: Box::new(CompiledNode::Literal(1.0)),
            },
            CompiledNode::Function {
                kind: BuiltinFn::Sin,
                args: vec![CompiledNode::Literal(0.0)],
            },
            CompiledNode::Sum {
                index: 0,
                lower: 1,
                upper: 10,
                body: Box::new(CompiledNode::Index(0)),
            },
            CompiledNode::Product {
                index: 0,
                lower: 1,
                upper: 5,
                body: Box::new(CompiledNode::Index(0)),
            },
        ];
        let seen: HashSet<&str> = nodes.iter().map(node_variant).collect();
        // Each constructed node is a different variant, and all nine are present.
        assert_eq!(seen.len(), nodes.len());
        assert_eq!(seen.len(), 9);
    }

    #[test]
    fn binary_op_all_variants() {
        let ops = [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Pow,
            BinaryOp::Mod,
        ];
        for op in &ops {
            // No wildcard arm: a new BinaryOp variant fails to compile here,
            // prompting it to be listed in `ops` as well.
            match op {
                BinaryOp::Add
                | BinaryOp::Sub
                | BinaryOp::Mul
                | BinaryOp::Div
                | BinaryOp::Pow
                | BinaryOp::Mod => {}
            }
        }
        assert_all_distinct(&ops);
        assert_eq!(ops.len(), 6);
    }

    #[test]
    fn unary_op_all_variants() {
        let ops = [UnaryOp::Neg, UnaryOp::Factorial];
        for op in &ops {
            // No wildcard arm: a new UnaryOp variant fails to compile here.
            match op {
                UnaryOp::Neg | UnaryOp::Factorial => {}
            }
        }
        assert_all_distinct(&ops);
        assert_eq!(ops.len(), 2);
    }

    #[test]
    fn builtin_fn_all_variants() {
        let fns = [
            BuiltinFn::Sin,
            BuiltinFn::Cos,
            BuiltinFn::Tan,
            BuiltinFn::Asin,
            BuiltinFn::Acos,
            BuiltinFn::Atan,
            BuiltinFn::Atan2,
            BuiltinFn::Sinh,
            BuiltinFn::Cosh,
            BuiltinFn::Tanh,
            BuiltinFn::Exp,
            BuiltinFn::Ln,
            BuiltinFn::Log2,
            BuiltinFn::Log10,
            BuiltinFn::Log,
            BuiltinFn::Sqrt,
            BuiltinFn::Cbrt,
            BuiltinFn::Abs,
            BuiltinFn::Floor,
            BuiltinFn::Ceil,
            BuiltinFn::Round,
            BuiltinFn::Min,
            BuiltinFn::Max,
        ];
        for f in &fns {
            // No wildcard arm: a new BuiltinFn variant fails to compile here.
            match f {
                BuiltinFn::Sin
                | BuiltinFn::Cos
                | BuiltinFn::Tan
                | BuiltinFn::Asin
                | BuiltinFn::Acos
                | BuiltinFn::Atan
                | BuiltinFn::Atan2
                | BuiltinFn::Sinh
                | BuiltinFn::Cosh
                | BuiltinFn::Tanh
                | BuiltinFn::Exp
                | BuiltinFn::Ln
                | BuiltinFn::Log2
                | BuiltinFn::Log10
                | BuiltinFn::Log
                | BuiltinFn::Sqrt
                | BuiltinFn::Cbrt
                | BuiltinFn::Abs
                | BuiltinFn::Floor
                | BuiltinFn::Ceil
                | BuiltinFn::Round
                | BuiltinFn::Min
                | BuiltinFn::Max => {}
            }
        }
        assert_all_distinct(&fns);
        assert_eq!(fns.len(), 23);
    }
}