1use super::expr::Expr;
4use std::collections::HashMap;
5
6pub struct JitExpr {
9 instructions: Vec<Instruction>,
10 var_indices: HashMap<String, usize>,
11 stack: Vec<f64>,
12}
13
/// A single step of a `JitExpr` program.
///
/// Operands are popped from, and results pushed onto, the evaluation stack. For binary
/// operators the right operand is on top, so they pop `b` first and then `a`.
#[derive(Debug, Clone)]
enum Instruction {
    /// Push a literal value.
    PushConst(f64),
    /// Push `vars[i]`, or `0.0` when that slot is missing.
    PushVar(usize),
    /// Replace the top value with its negation.
    Neg,
    /// Pop `b`, pop `a`, push `a + b`.
    Add,
    /// Pop `b`, pop `a`, push `a - b`.
    Sub,
    /// Pop `b`, pop `a`, push `a * b`.
    Mul,
    /// Pop `b`, pop `a`, push `a / b`, or NaN when `|b| < 1e-15`.
    Div,
    /// Pop `b`, pop `a`, push `a.powf(b)`.
    Pow,
    /// Replace the top value with its sine.
    Sin,
    /// Replace the top value with its cosine.
    Cos,
    /// Replace the top value with its tangent.
    Tan,
    /// Replace the top value with its natural logarithm.
    Ln,
    /// Replace the top value with `e` raised to it.
    Exp,
    /// Replace the top value with its square root.
    Sqrt,
    /// Replace the top value with its absolute value.
    Abs,
    /// Replace the top value with its floor (not currently emitted by `compile`).
    Floor,
    /// Replace the top value with its ceiling (not currently emitted by `compile`).
    Ceil,
    /// Replace the top value with its arctangent (not currently emitted by `compile`).
    Atan,
    /// Pop `x`, pop `y`, push `y.atan2(x)` (not currently emitted by `compile`).
    Atan2,
}
21
22impl JitExpr {
23 pub fn compile(expr: &Expr, var_names: &[&str]) -> Self {
25 let var_indices: HashMap<String, usize> = var_names.iter().enumerate()
26 .map(|(i, &name)| (name.to_string(), i))
27 .collect();
28 let mut instructions = Vec::new();
29 Self::emit(expr, &var_indices, &mut instructions);
30 Self { instructions, var_indices, stack: Vec::with_capacity(32) }
31 }
32
33 fn emit(expr: &Expr, vars: &HashMap<String, usize>, out: &mut Vec<Instruction>) {
34 match expr {
35 Expr::Const(v) => out.push(Instruction::PushConst(*v)),
36 Expr::Var(name) => {
37 let idx = vars.get(name).copied().unwrap_or(0);
38 out.push(Instruction::PushVar(idx));
39 }
40 Expr::Neg(a) => { Self::emit(a, vars, out); out.push(Instruction::Neg); }
41 Expr::Add(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Add); }
42 Expr::Sub(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Sub); }
43 Expr::Mul(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Mul); }
44 Expr::Div(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Div); }
45 Expr::Pow(a, b) => { Self::emit(a, vars, out); Self::emit(b, vars, out); out.push(Instruction::Pow); }
46 Expr::Sin(a) => { Self::emit(a, vars, out); out.push(Instruction::Sin); }
47 Expr::Cos(a) => { Self::emit(a, vars, out); out.push(Instruction::Cos); }
48 Expr::Tan(a) => { Self::emit(a, vars, out); out.push(Instruction::Tan); }
49 Expr::Ln(a) => { Self::emit(a, vars, out); out.push(Instruction::Ln); }
50 Expr::Exp(a) => { Self::emit(a, vars, out); out.push(Instruction::Exp); }
51 Expr::Sqrt(a) => { Self::emit(a, vars, out); out.push(Instruction::Sqrt); }
52 Expr::Abs(a) => { Self::emit(a, vars, out); out.push(Instruction::Abs); }
53 _ => out.push(Instruction::PushConst(f64::NAN)),
54 }
55 }
56
57 pub fn eval(&mut self, vars: &[f64]) -> f64 {
59 self.stack.clear();
60 for inst in &self.instructions {
61 match inst {
62 Instruction::PushConst(v) => self.stack.push(*v),
63 Instruction::PushVar(i) => self.stack.push(vars.get(*i).copied().unwrap_or(0.0)),
64 Instruction::Neg => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(-a); }
65 Instruction::Add => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a + b); }
66 Instruction::Sub => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a - b); }
67 Instruction::Mul => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a * b); }
68 Instruction::Div => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(if b.abs() < 1e-15 { f64::NAN } else { a / b }); }
69 Instruction::Pow => { let b = self.stack.pop().unwrap_or(0.0); let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.powf(b)); }
70 Instruction::Sin => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.sin()); }
71 Instruction::Cos => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.cos()); }
72 Instruction::Tan => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.tan()); }
73 Instruction::Ln => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.ln()); }
74 Instruction::Exp => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.exp()); }
75 Instruction::Sqrt => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.sqrt()); }
76 Instruction::Abs => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.abs()); }
77 Instruction::Floor => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.floor()); }
78 Instruction::Ceil => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.ceil()); }
79 Instruction::Atan => { let a = self.stack.pop().unwrap_or(0.0); self.stack.push(a.atan()); }
80 Instruction::Atan2 => { let x = self.stack.pop().unwrap_or(0.0); let y = self.stack.pop().unwrap_or(0.0); self.stack.push(y.atan2(x)); }
81 }
82 }
83 self.stack.pop().unwrap_or(f64::NAN)
84 }
85
86 pub fn instruction_count(&self) -> usize { self.instructions.len() }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by every floating-point comparison below.
    const TOLERANCE: f64 = 1e-10;

    /// True when `actual` lies strictly within `TOLERANCE` of `expected`.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    #[test]
    fn compile_and_eval() {
        // f(x) = x^2 + 1
        let expr = Expr::var("x").pow(Expr::c(2.0)).add(Expr::c(1.0));
        let mut compiled = JitExpr::compile(&expr, &["x"]);
        for (input, expected) in [(3.0, 10.0), (0.0, 1.0)] {
            assert!(approx(compiled.eval(&[input]), expected));
        }
    }

    #[test]
    fn compile_trig() {
        let expr = Expr::var("x").sin();
        let mut compiled = JitExpr::compile(&expr, &["x"]);
        let cases = [(0.0, 0.0), (std::f64::consts::FRAC_PI_2, 1.0)];
        for (input, expected) in cases {
            assert!(approx(compiled.eval(&[input]), expected));
        }
    }

    #[test]
    fn compile_multi_var() {
        let sum = Expr::var("x").add(Expr::var("y"));
        let mut compiled = JitExpr::compile(&sum, &["x", "y"]);
        assert!(approx(compiled.eval(&[3.0, 4.0]), 7.0));
    }
}