// proof_engine/symbolic/expr.rs

//! Expression tree representation — AST for mathematical expressions.

use std::collections::HashMap;
use std::fmt;

/// A node in a symbolic mathematical expression tree.
///
/// Leaves are named variables and numeric constants. Every other variant owns
/// its operands through `Box`, so trees of arbitrary depth can be built.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A named variable such as `x`, `y` or `t`.
    Var(String),
    /// A numeric literal.
    Const(f64),
    /// Unary minus: `-a`.
    Neg(Box<Expr>),
    /// Sum of two operands: `a + b`.
    Add(Box<Expr>, Box<Expr>),
    /// Difference of two operands: `a - b`.
    Sub(Box<Expr>, Box<Expr>),
    /// Product of two operands: `a * b`.
    Mul(Box<Expr>, Box<Expr>),
    /// Quotient of two operands: `a / b`.
    Div(Box<Expr>, Box<Expr>),
    /// Exponentiation: `a^b`.
    Pow(Box<Expr>, Box<Expr>),
    /// `sin(a)`.
    Sin(Box<Expr>),
    /// `cos(a)`.
    Cos(Box<Expr>),
    /// `tan(a)`.
    Tan(Box<Expr>),
    /// Natural logarithm `ln(a)`.
    Ln(Box<Expr>),
    /// Natural exponential `e^a`.
    Exp(Box<Expr>),
    /// Square root `√a`.
    Sqrt(Box<Expr>),
    /// Absolute value `|a|`.
    Abs(Box<Expr>),
    /// Floor `⌊a⌋`.
    Floor(Box<Expr>),
    /// Ceiling `⌈a⌉`.
    Ceil(Box<Expr>),
    /// Single-argument arctangent `atan(a)`.
    Atan(Box<Expr>),
    /// Two-argument arctangent; the first operand is `y`, the second is `x`.
    Atan2(Box<Expr>, Box<Expr>),
    /// Summation of `body` as the index variable `var` runs from `from` to `to`.
    Sum {
        body: Box<Expr>,
        var: String,
        from: Box<Expr>,
        to: Box<Expr>,
    },
    /// Product of `body` as the index variable `var` runs from `from` to `to`.
    Product {
        body: Box<Expr>,
        var: String,
        from: Box<Expr>,
        to: Box<Expr>,
    },
    /// Integral of `body` with respect to `var`.
    Integral {
        body: Box<Expr>,
        var: String,
    },
    /// Derivative of `body` with respect to `var`.
    Derivative {
        body: Box<Expr>,
        var: String,
    },
}

56/// Convenience constructors.
57pub fn Var(name: &str) -> Expr { Expr::Var(name.to_string()) }
58pub fn Const(val: f64) -> Expr { Expr::Const(val) }
59
60impl Expr {
61    pub fn var(name: &str) -> Self { Self::Var(name.to_string()) }
62    pub fn c(val: f64) -> Self { Self::Const(val) }
63    pub fn zero() -> Self { Self::Const(0.0) }
64    pub fn one() -> Self { Self::Const(1.0) }
65    pub fn pi() -> Self { Self::Const(std::f64::consts::PI) }
66    pub fn e() -> Self { Self::Const(std::f64::consts::E) }
67
68    // Binary operations
69    pub fn add(self, other: Expr) -> Expr { Expr::Add(Box::new(self), Box::new(other)) }
70    pub fn sub(self, other: Expr) -> Expr { Expr::Sub(Box::new(self), Box::new(other)) }
71    pub fn mul(self, other: Expr) -> Expr { Expr::Mul(Box::new(self), Box::new(other)) }
72    pub fn div(self, other: Expr) -> Expr { Expr::Div(Box::new(self), Box::new(other)) }
73    pub fn pow(self, exp: Expr) -> Expr { Expr::Pow(Box::new(self), Box::new(exp)) }
74
75    // Unary operations
76    pub fn neg(self) -> Expr { Expr::Neg(Box::new(self)) }
77    pub fn sin(self) -> Expr { Expr::Sin(Box::new(self)) }
78    pub fn cos(self) -> Expr { Expr::Cos(Box::new(self)) }
79    pub fn tan(self) -> Expr { Expr::Tan(Box::new(self)) }
80    pub fn ln(self) -> Expr { Expr::Ln(Box::new(self)) }
81    pub fn exp(self) -> Expr { Expr::Exp(Box::new(self)) }
82    pub fn sqrt(self) -> Expr { Expr::Sqrt(Box::new(self)) }
83    pub fn abs(self) -> Expr { Expr::Abs(Box::new(self)) }
84
85    /// Evaluate the expression with variable bindings.
86    pub fn eval(&self, vars: &std::collections::HashMap<String, f64>) -> f64 {
87        match self {
88            Expr::Var(name) => *vars.get(name).unwrap_or(&0.0),
89            Expr::Const(v) => *v,
90            Expr::Neg(a) => -a.eval(vars),
91            Expr::Add(a, b) => a.eval(vars) + b.eval(vars),
92            Expr::Sub(a, b) => a.eval(vars) - b.eval(vars),
93            Expr::Mul(a, b) => a.eval(vars) * b.eval(vars),
94            Expr::Div(a, b) => { let d = b.eval(vars); if d.abs() < 1e-15 { f64::NAN } else { a.eval(vars) / d } }
95            Expr::Pow(a, b) => a.eval(vars).powf(b.eval(vars)),
96            Expr::Sin(a) => a.eval(vars).sin(),
97            Expr::Cos(a) => a.eval(vars).cos(),
98            Expr::Tan(a) => a.eval(vars).tan(),
99            Expr::Ln(a) => a.eval(vars).ln(),
100            Expr::Exp(a) => a.eval(vars).exp(),
101            Expr::Sqrt(a) => a.eval(vars).sqrt(),
102            Expr::Abs(a) => a.eval(vars).abs(),
103            Expr::Floor(a) => a.eval(vars).floor(),
104            Expr::Ceil(a) => a.eval(vars).ceil(),
105            Expr::Atan(a) => a.eval(vars).atan(),
106            Expr::Atan2(y, x) => y.eval(vars).atan2(x.eval(vars)),
107            Expr::Sum { body, var, from, to } => {
108                let f = from.eval(vars) as i64;
109                let t = to.eval(vars) as i64;
110                let mut sum = 0.0;
111                let mut local = vars.clone();
112                for i in f..=t {
113                    local.insert(var.clone(), i as f64);
114                    sum += body.eval(&local);
115                }
116                sum
117            }
118            Expr::Product { body, var, from, to } => {
119                let f = from.eval(vars) as i64;
120                let t = to.eval(vars) as i64;
121                let mut prod = 1.0;
122                let mut local = vars.clone();
123                for i in f..=t {
124                    local.insert(var.clone(), i as f64);
125                    prod *= body.eval(&local);
126                }
127                prod
128            }
129            Expr::Integral { .. } => f64::NAN, // symbolic only
130            Expr::Derivative { .. } => f64::NAN,
131        }
132    }
133
134    /// Whether this expression contains the given variable.
135    pub fn contains_var(&self, var: &str) -> bool {
136        match self {
137            Expr::Var(name) => name == var,
138            Expr::Const(_) => false,
139            Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a) |
140            Expr::Ln(a) | Expr::Exp(a) | Expr::Sqrt(a) | Expr::Abs(a) |
141            Expr::Floor(a) | Expr::Ceil(a) | Expr::Atan(a) => a.contains_var(var),
142            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) |
143            Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
144                a.contains_var(var) || b.contains_var(var)
145            }
146            Expr::Sum { body, .. } | Expr::Product { body, .. } |
147            Expr::Integral { body, .. } | Expr::Derivative { body, .. } => {
148                body.contains_var(var)
149            }
150        }
151    }
152
153    /// Whether this is a constant (no variables).
154    pub fn is_constant(&self) -> bool {
155        matches!(self, Expr::Const(_))
156    }
157
158    /// Substitute a variable with an expression.
159    pub fn substitute(&self, var: &str, replacement: &Expr) -> Expr {
160        match self {
161            Expr::Var(name) if name == var => replacement.clone(),
162            Expr::Var(_) | Expr::Const(_) => self.clone(),
163            Expr::Neg(a) => Expr::Neg(Box::new(a.substitute(var, replacement))),
164            Expr::Add(a, b) => Expr::Add(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
165            Expr::Sub(a, b) => Expr::Sub(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
166            Expr::Mul(a, b) => Expr::Mul(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
167            Expr::Div(a, b) => Expr::Div(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
168            Expr::Pow(a, b) => Expr::Pow(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
169            Expr::Sin(a) => Expr::Sin(Box::new(a.substitute(var, replacement))),
170            Expr::Cos(a) => Expr::Cos(Box::new(a.substitute(var, replacement))),
171            Expr::Tan(a) => Expr::Tan(Box::new(a.substitute(var, replacement))),
172            Expr::Ln(a) => Expr::Ln(Box::new(a.substitute(var, replacement))),
173            Expr::Exp(a) => Expr::Exp(Box::new(a.substitute(var, replacement))),
174            Expr::Sqrt(a) => Expr::Sqrt(Box::new(a.substitute(var, replacement))),
175            Expr::Abs(a) => Expr::Abs(Box::new(a.substitute(var, replacement))),
176            _ => self.clone(), // Simplified: other variants not substituted
177        }
178    }
179
180    /// Count the number of nodes in the expression tree.
181    pub fn node_count(&self) -> usize {
182        match self {
183            Expr::Var(_) | Expr::Const(_) => 1,
184            Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a) |
185            Expr::Ln(a) | Expr::Exp(a) | Expr::Sqrt(a) | Expr::Abs(a) |
186            Expr::Floor(a) | Expr::Ceil(a) | Expr::Atan(a) => 1 + a.node_count(),
187            Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) |
188            Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
189                1 + a.node_count() + b.node_count()
190            }
191            Expr::Sum { body, from, to, .. } | Expr::Product { body, from, to, .. } => {
192                1 + body.node_count() + from.node_count() + to.node_count()
193            }
194            Expr::Integral { body, .. } | Expr::Derivative { body, .. } => 1 + body.node_count(),
195        }
196    }
197}
198
199impl fmt::Display for Expr {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        match self {
202            Expr::Var(name) => write!(f, "{name}"),
203            Expr::Const(v) => {
204                if v.fract() == 0.0 && v.abs() < 1e12 { write!(f, "{}", *v as i64) }
205                else { write!(f, "{v:.4}") }
206            }
207            Expr::Neg(a) => write!(f, "(-{a})"),
208            Expr::Add(a, b) => write!(f, "({a} + {b})"),
209            Expr::Sub(a, b) => write!(f, "({a} - {b})"),
210            Expr::Mul(a, b) => write!(f, "({a} * {b})"),
211            Expr::Div(a, b) => write!(f, "({a} / {b})"),
212            Expr::Pow(a, b) => write!(f, "({a}^{b})"),
213            Expr::Sin(a) => write!(f, "sin({a})"),
214            Expr::Cos(a) => write!(f, "cos({a})"),
215            Expr::Tan(a) => write!(f, "tan({a})"),
216            Expr::Ln(a) => write!(f, "ln({a})"),
217            Expr::Exp(a) => write!(f, "exp({a})"),
218            Expr::Sqrt(a) => write!(f, "√({a})"),
219            Expr::Abs(a) => write!(f, "|{a}|"),
220            Expr::Floor(a) => write!(f, "⌊{a}⌋"),
221            Expr::Ceil(a) => write!(f, "⌈{a}⌉"),
222            Expr::Atan(a) => write!(f, "atan({a})"),
223            Expr::Atan2(y, x) => write!(f, "atan2({y}, {x})"),
224            Expr::Sum { body, var, from, to } => write!(f, "Σ({var}={from}..{to}){body}"),
225            Expr::Product { body, var, from, to } => write!(f, "Π({var}={from}..{to}){body}"),
226            Expr::Integral { body, var } => write!(f, "∫{body} d{var}"),
227            Expr::Derivative { body, var } => write!(f, "d/d{var}({body})"),
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a variable-binding map from `(name, value)` pairs.
    fn bindings(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(name, value)| (name.to_string(), value)).collect()
    }

    /// `Σ(i=from..to) i` when `sum` is true, otherwise `Π(i=from..to) i`.
    fn index_range(sum: bool, from: f64, to: f64) -> Expr {
        let body = Box::new(Expr::var("i"));
        let var = "i".to_string();
        let (from, to) = (Box::new(Expr::c(from)), Box::new(Expr::c(to)));
        if sum {
            Expr::Sum { body, var, from, to }
        } else {
            Expr::Product { body, var, from, to }
        }
    }

    #[test]
    fn eval_constant() {
        let e = Expr::c(42.0);
        assert_eq!(e.eval(&HashMap::new()), 42.0);
    }

    #[test]
    fn eval_variable() {
        let e = Expr::var("x");
        assert_eq!(e.eval(&bindings(&[("x", 3.0)])), 3.0);
    }

    #[test]
    fn eval_unbound_variable_defaults_to_zero() {
        assert_eq!(Expr::var("y").eval(&HashMap::new()), 0.0);
    }

    #[test]
    fn eval_arithmetic() {
        // (x + 1) * 2 at x = 4
        let e = Expr::var("x").add(Expr::c(1.0)).mul(Expr::c(2.0));
        assert_eq!(e.eval(&bindings(&[("x", 4.0)])), 10.0);
    }

    #[test]
    fn eval_division_by_zero_is_nan() {
        let e = Expr::c(1.0).div(Expr::zero());
        assert!(e.eval(&HashMap::new()).is_nan());
    }

    #[test]
    fn eval_trig() {
        let empty = HashMap::new();
        assert!(Expr::c(0.0).sin().eval(&empty).abs() < 1e-10);
        assert!((Expr::c(0.0).cos().eval(&empty) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn eval_sum() {
        // Σ(i=1..3) i = 6
        assert_eq!(index_range(true, 1.0, 3.0).eval(&HashMap::new()), 6.0);
    }

    #[test]
    fn eval_product() {
        // Π(i=1..4) i = 24
        assert_eq!(index_range(false, 1.0, 4.0).eval(&HashMap::new()), 24.0);
    }

    #[test]
    fn eval_empty_range_yields_identity() {
        let empty = HashMap::new();
        assert_eq!(index_range(true, 3.0, 1.0).eval(&empty), 0.0);
        assert_eq!(index_range(false, 3.0, 1.0).eval(&empty), 1.0);
    }

    #[test]
    fn contains_var_works() {
        let e = Expr::var("x").add(Expr::c(1.0));
        assert!(e.contains_var("x"));
        assert!(!e.contains_var("y"));
    }

    #[test]
    fn is_constant_detects_literals() {
        assert!(Expr::c(1.0).is_constant());
        assert!(!Expr::var("x").is_constant());
    }

    #[test]
    fn substitute_works() {
        let e = Expr::var("x").add(Expr::c(1.0));
        let replaced = e.substitute("x", &Expr::c(5.0));
        assert_eq!(replaced.eval(&HashMap::new()), 6.0);
    }

    #[test]
    fn node_count_counts_every_node() {
        // (x^2) + 1 → Add, Pow, x, 2, 1
        let e = Expr::var("x").pow(Expr::c(2.0)).add(Expr::c(1.0));
        assert_eq!(e.node_count(), 5);
    }

    #[test]
    fn display_format() {
        let e = Expr::var("x").pow(Expr::c(2.0)).add(Expr::c(1.0));
        assert_eq!(format!("{e}"), "((x^2) + 1)");
    }

    #[test]
    fn display_non_integer_constant_uses_four_decimals() {
        assert_eq!(format!("{}", Expr::c(0.5)), "0.5000");
    }
}