Skip to main content

proof_engine/symbolic/
compile.rs

//! Symbolic-to-numeric compilation — compile expression trees to fast evaluators.

3use super::expr::Expr;
4use std::collections::HashMap;
5
6/// Compiled expression for fast repeated evaluation.
7/// Converts the recursive Expr tree into a flat stack-based instruction sequence.
8pub struct JitExpr {
9    instructions: Vec<Instruction>,
10    var_indices: HashMap<String, usize>,
11    stack: Vec<f64>,
12}
13
/// One step of the postfix program executed by `JitExpr::eval`.
///
/// Push variants add one value to the stack. Unary operators pop one operand and
/// push the result. Binary operators pop the right operand, then the left one, and
/// push the result.
///
/// NOTE(review): `Floor`, `Ceil`, `Atan` and `Atan2` are executed by `eval` but are
/// never produced by `emit` (unmatched `Expr` nodes compile to NaN). Presumably they
/// await matching `Expr` variants — confirm.
#[derive(Debug, Clone)]
enum Instruction {
    /// Push a literal value.
    PushConst(f64),
    /// Push the value of variable slot `i` (`0.0` if the `vars` slice is too short).
    PushVar(usize),
    /// Negate the top value.
    Neg,
    /// `left + right`.
    Add,
    /// `left - right`.
    Sub,
    /// `left * right`.
    Mul,
    /// `left / right`, or NaN when `|right| < 1e-15`.
    Div,
    /// `left.powf(right)`.
    Pow,
    /// Sine of the top value.
    Sin,
    /// Cosine of the top value.
    Cos,
    /// Tangent of the top value.
    Tan,
    /// Natural logarithm of the top value.
    Ln,
    /// `e` raised to the top value.
    Exp,
    /// Square root of the top value.
    Sqrt,
    /// Absolute value of the top value.
    Abs,
    /// Round the top value toward negative infinity.
    Floor,
    /// Round the top value toward positive infinity.
    Ceil,
    /// Arctangent of the top value.
    Atan,
    /// `y.atan2(x)`, where `x` is the top of the stack and `y` the value beneath it.
    Atan2,
}

22impl JitExpr {
23    /// Compile an expression. Variables are mapped to indices for fast lookup.
24    pub fn compile(expr: &Expr, var_names: &[&str]) -> Self {
25        let var_indices: HashMap<String, usize> = var_names.iter().enumerate()
26            .map(|(i, &name)| (name.to_string(), i))
27            .collect();
28        let mut instructions = Vec::new();
29        Self::emit(expr, &var_indices, &mut instructions);
30        Self { instructions, var_indices, stack: Vec::with_capacity(32) }
31    }
32
33    fn emit(expr: &Expr, vars: &HashMap<String, usize>, out: &mut Vec<Instruction>) {
34        match expr {
35            Expr::Const(v) => out.push(Instruction::PushConst(*v)),
36            Expr::Var(name) => {
37                let idx = vars.get(name).copied().unwrap_or(0);
38                out.push(Instruction::PushVar(idx));
39            }
40            Expr::Neg(a) => { Self::emit(a, vars, out); out.push(Instruction::Neg); }
41            Expr::Add(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Add); }
42            Expr::Sub(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Sub); }
43            Expr::Mul(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Mul); }
44            Expr::Div(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Div); }
45            Expr::Pow(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Pow); }
46            Expr::Sin(a) => { Self::emit(a, vars, out); out.push(Instruction::Sin); }
47            Expr::Cos(a) => { Self::emit(a, vars, out); out.push(Instruction::Cos); }
48            Expr::Tan(a) => { Self::emit(a, vars, out); out.push(Instruction::Tan); }
49            Expr::Ln(a) => { Self::emit(a, vars, out); out.push(Instruction::Ln); }
50            Expr::Exp(a) => { Self::emit(a, vars, out); out.push(Instruction::Exp); }
51            Expr::Sqrt(a) => { Self::emit(a, vars, out); out.push(Instruction::Sqrt); }
52            Expr::Abs(a) => { Self::emit(a, vars, out); out.push(Instruction::Abs); }
53            _ => out.push(Instruction::PushConst(f64::NAN)),
54        }
55    }
56
57    /// Evaluate with the given variable values (indexed same as var_names in compile).
58    pub fn eval(&mut self, vars: &[f64]) -> f64 {
59        self.stack.clear();
60        for inst in &self.instructions {
61            match inst {
62                Instruction::PushConst(v) => self.stack.push(*v),
63                Instruction::PushVar(i) => self.stack.push(vars.get(*i).copied().unwrap_or(0.0)),
64                Instruction::Neg => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(-a); }
65                Instruction::Add => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a + b); }
66                Instruction::Sub => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a - b); }
67                Instruction::Mul => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a * b); }
68                Instruction::Div => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(if b.abs() < 1e-15 { f64::NAN } else { a / b }); }
69                Instruction::Pow => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.powf(b)); }
70                Instruction::Sin => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.sin()); }
71                Instruction::Cos => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.cos()); }
72                Instruction::Tan => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.tan()); }
73                Instruction::Ln => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.ln()); }
74                Instruction::Exp => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.exp()); }
75                Instruction::Sqrt => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.sqrt()); }
76                Instruction::Abs => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.abs()); }
77                Instruction::Floor => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.floor()); }
78                Instruction::Ceil => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.ceil()); }
79                Instruction::Atan => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.atan()); }
80                Instruction::Atan2 => { let x = self.stack.pop().unwrap_or(0.0); let y = self.stack.pop().unwrap_or(0.0); self.stack.push(y.atan2(x)); }
81            }
82        }
83        self.stack.pop().unwrap_or(f64::NAN)
84    }
85
86    pub fn instruction_count(&self) -> usize { self.instructions.len() }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison used by every test in this module.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    #[test]
    fn compile_and_eval() {
        // x^2 + 1
        let squared_plus_one = Expr::var("x").pow(Expr::c(2.0)).add(Expr::c(1.0));
        let mut compiled = JitExpr::compile(&squared_plus_one, &["x"]);
        for (x, expected) in [(3.0, 10.0), (0.0, 1.0)] {
            assert!(close(compiled.eval(&[x]), expected));
        }
    }

    #[test]
    fn compile_trig() {
        let sine = Expr::var("x").sin();
        let mut compiled = JitExpr::compile(&sine, &["x"]);
        for (x, expected) in [(0.0, 0.0), (std::f64::consts::FRAC_PI_2, 1.0)] {
            assert!(close(compiled.eval(&[x]), expected));
        }
    }

    #[test]
    fn compile_multi_var() {
        let sum = Expr::var("x").add(Expr::var("y"));
        let mut compiled = JitExpr::compile(&sum, &["x", "y"]);
        assert!(close(compiled.eval(&[3.0, 4.0]), 7.0));
    }
}