1use std::fmt;
4
/// A symbolic mathematical expression tree.
///
/// Child expressions are boxed so the enum has a finite size. Numeric
/// evaluation lives in `Expr::eval`; rendering lives in the `Display` impl.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A named variable; `eval` reads its value from the supplied map.
    Var(String),
    /// A numeric literal.
    Const(f64),
    /// Unary negation, `-a`.
    Neg(Box<Expr>),
    /// Addition, `a + b`.
    Add(Box<Expr>, Box<Expr>),
    /// Subtraction, `a - b`.
    Sub(Box<Expr>, Box<Expr>),
    /// Multiplication, `a * b`.
    Mul(Box<Expr>, Box<Expr>),
    /// Division, `a / b`.
    Div(Box<Expr>, Box<Expr>),
    /// Exponentiation; the first field is the base, the second the exponent.
    Pow(Box<Expr>, Box<Expr>),
    /// Sine (evaluated with `f64::sin`, i.e. radians).
    Sin(Box<Expr>),
    /// Cosine (evaluated with `f64::cos`).
    Cos(Box<Expr>),
    /// Tangent (evaluated with `f64::tan`).
    Tan(Box<Expr>),
    /// Natural logarithm (evaluated with `f64::ln`).
    Ln(Box<Expr>),
    /// Exponential function `e^a` (evaluated with `f64::exp`).
    Exp(Box<Expr>),
    /// Square root.
    Sqrt(Box<Expr>),
    /// Absolute value.
    Abs(Box<Expr>),
    /// Round toward negative infinity.
    Floor(Box<Expr>),
    /// Round toward positive infinity.
    Ceil(Box<Expr>),
    /// Single-argument arctangent.
    Atan(Box<Expr>),
    /// Two-argument arctangent `atan2(y, x)`; the first field is `y`.
    Atan2(Box<Expr>, Box<Expr>),
    /// Sum of `body` for each integer value of `var` in `from..=to`.
    Sum {
        /// Summand, evaluated once per index value.
        body: Box<Expr>,
        /// Name of the index variable bound inside `body`.
        var: String,
        /// Inclusive lower bound.
        from: Box<Expr>,
        /// Inclusive upper bound.
        to: Box<Expr>,
    },
    /// Product of `body` for each integer value of `var` in `from..=to`.
    Product {
        /// Factor, evaluated once per index value.
        body: Box<Expr>,
        /// Name of the index variable bound inside `body`.
        var: String,
        /// Inclusive lower bound.
        from: Box<Expr>,
        /// Inclusive upper bound.
        to: Box<Expr>,
    },
    /// Symbolic integral of `body` with respect to `var` (no bounds);
    /// `eval` returns NaN for it.
    Integral {
        /// Integrand.
        body: Box<Expr>,
        /// Integration variable.
        var: String,
    },
    /// Symbolic derivative of `body` with respect to `var`;
    /// `eval` returns NaN for it.
    Derivative {
        /// Expression being differentiated.
        body: Box<Expr>,
        /// Differentiation variable.
        var: String,
    },
}
55
56pub fn Var(name: &str) -> Expr { Expr::Var(name.to_string()) }
58pub fn Const(val: f64) -> Expr { Expr::Const(val) }
59
60impl Expr {
61 pub fn var(name: &str) -> Self { Self::Var(name.to_string()) }
62 pub fn c(val: f64) -> Self { Self::Const(val) }
63 pub fn zero() -> Self { Self::Const(0.0) }
64 pub fn one() -> Self { Self::Const(1.0) }
65 pub fn pi() -> Self { Self::Const(std::f64::consts::PI) }
66 pub fn e() -> Self { Self::Const(std::f64::consts::E) }
67
68 pub fn add(self, other: Expr) -> Expr { Expr::Add(Box::new(self), Box::new(other)) }
70 pub fn sub(self, other: Expr) -> Expr { Expr::Sub(Box::new(self), Box::new(other)) }
71 pub fn mul(self, other: Expr) -> Expr { Expr::Mul(Box::new(self), Box::new(other)) }
72 pub fn div(self, other: Expr) -> Expr { Expr::Div(Box::new(self), Box::new(other)) }
73 pub fn pow(self, exp: Expr) -> Expr { Expr::Pow(Box::new(self), Box::new(exp)) }
74
75 pub fn neg(self) -> Expr { Expr::Neg(Box::new(self)) }
77 pub fn sin(self) -> Expr { Expr::Sin(Box::new(self)) }
78 pub fn cos(self) -> Expr { Expr::Cos(Box::new(self)) }
79 pub fn tan(self) -> Expr { Expr::Tan(Box::new(self)) }
80 pub fn ln(self) -> Expr { Expr::Ln(Box::new(self)) }
81 pub fn exp(self) -> Expr { Expr::Exp(Box::new(self)) }
82 pub fn sqrt(self) -> Expr { Expr::Sqrt(Box::new(self)) }
83 pub fn abs(self) -> Expr { Expr::Abs(Box::new(self)) }
84
85 pub fn eval(&self, vars: &std::collections::HashMap<String, f64>) -> f64 {
87 match self {
88 Expr::Var(name) => *vars.get(name).unwrap_or(&0.0),
89 Expr::Const(v) => *v,
90 Expr::Neg(a) => -a.eval(vars),
91 Expr::Add(a, b) => a.eval(vars) + b.eval(vars),
92 Expr::Sub(a, b) => a.eval(vars) - b.eval(vars),
93 Expr::Mul(a, b) => a.eval(vars) * b.eval(vars),
94 Expr::Div(a, b) => { let d = b.eval(vars); if d.abs() < 1e-15 { f64::NAN } else { a.eval(vars) / d } }
95 Expr::Pow(a, b) => a.eval(vars).powf(b.eval(vars)),
96 Expr::Sin(a) => a.eval(vars).sin(),
97 Expr::Cos(a) => a.eval(vars).cos(),
98 Expr::Tan(a) => a.eval(vars).tan(),
99 Expr::Ln(a) => a.eval(vars).ln(),
100 Expr::Exp(a) => a.eval(vars).exp(),
101 Expr::Sqrt(a) => a.eval(vars).sqrt(),
102 Expr::Abs(a) => a.eval(vars).abs(),
103 Expr::Floor(a) => a.eval(vars).floor(),
104 Expr::Ceil(a) => a.eval(vars).ceil(),
105 Expr::Atan(a) => a.eval(vars).atan(),
106 Expr::Atan2(y, x) => y.eval(vars).atan2(x.eval(vars)),
107 Expr::Sum { body, var, from, to } => {
108 let f = from.eval(vars) as i64;
109 let t = to.eval(vars) as i64;
110 let mut sum = 0.0;
111 let mut local = vars.clone();
112 for i in f..=t {
113 local.insert(var.clone(), i as f64);
114 sum += body.eval(&local);
115 }
116 sum
117 }
118 Expr::Product { body, var, from, to } => {
119 let f = from.eval(vars) as i64;
120 let t = to.eval(vars) as i64;
121 let mut prod = 1.0;
122 let mut local = vars.clone();
123 for i in f..=t {
124 local.insert(var.clone(), i as f64);
125 prod *= body.eval(&local);
126 }
127 prod
128 }
129 Expr::Integral { .. } => f64::NAN, Expr::Derivative { .. } => f64::NAN,
131 }
132 }
133
134 pub fn contains_var(&self, var: &str) -> bool {
136 match self {
137 Expr::Var(name) => name == var,
138 Expr::Const(_) => false,
139 Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a) |
140 Expr::Ln(a) | Expr::Exp(a) | Expr::Sqrt(a) | Expr::Abs(a) |
141 Expr::Floor(a) | Expr::Ceil(a) | Expr::Atan(a) => a.contains_var(var),
142 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) |
143 Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
144 a.contains_var(var) || b.contains_var(var)
145 }
146 Expr::Sum { body, .. } | Expr::Product { body, .. } |
147 Expr::Integral { body, .. } | Expr::Derivative { body, .. } => {
148 body.contains_var(var)
149 }
150 }
151 }
152
153 pub fn is_constant(&self) -> bool {
155 matches!(self, Expr::Const(_))
156 }
157
158 pub fn substitute(&self, var: &str, replacement: &Expr) -> Expr {
160 match self {
161 Expr::Var(name) if name == var => replacement.clone(),
162 Expr::Var(_) | Expr::Const(_) => self.clone(),
163 Expr::Neg(a) => Expr::Neg(Box::new(a.substitute(var, replacement))),
164 Expr::Add(a, b) => Expr::Add(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
165 Expr::Sub(a, b) => Expr::Sub(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
166 Expr::Mul(a, b) => Expr::Mul(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
167 Expr::Div(a, b) => Expr::Div(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
168 Expr::Pow(a, b) => Expr::Pow(Box::new(a.substitute(var, replacement)), Box::new(b.substitute(var, replacement))),
169 Expr::Sin(a) => Expr::Sin(Box::new(a.substitute(var, replacement))),
170 Expr::Cos(a) => Expr::Cos(Box::new(a.substitute(var, replacement))),
171 Expr::Tan(a) => Expr::Tan(Box::new(a.substitute(var, replacement))),
172 Expr::Ln(a) => Expr::Ln(Box::new(a.substitute(var, replacement))),
173 Expr::Exp(a) => Expr::Exp(Box::new(a.substitute(var, replacement))),
174 Expr::Sqrt(a) => Expr::Sqrt(Box::new(a.substitute(var, replacement))),
175 Expr::Abs(a) => Expr::Abs(Box::new(a.substitute(var, replacement))),
176 _ => self.clone(), }
178 }
179
180 pub fn node_count(&self) -> usize {
182 match self {
183 Expr::Var(_) | Expr::Const(_) => 1,
184 Expr::Neg(a) | Expr::Sin(a) | Expr::Cos(a) | Expr::Tan(a) |
185 Expr::Ln(a) | Expr::Exp(a) | Expr::Sqrt(a) | Expr::Abs(a) |
186 Expr::Floor(a) | Expr::Ceil(a) | Expr::Atan(a) => 1 + a.node_count(),
187 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) |
188 Expr::Div(a, b) | Expr::Pow(a, b) | Expr::Atan2(a, b) => {
189 1 + a.node_count() + b.node_count()
190 }
191 Expr::Sum { body, from, to, .. } | Expr::Product { body, from, to, .. } => {
192 1 + body.node_count() + from.node_count() + to.node_count()
193 }
194 Expr::Integral { body, .. } | Expr::Derivative { body, .. } => 1 + body.node_count(),
195 }
196 }
197}
198
199impl fmt::Display for Expr {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 match self {
202 Expr::Var(name) => write!(f, "{name}"),
203 Expr::Const(v) => {
204 if v.fract() == 0.0 && v.abs() < 1e12 { write!(f, "{}", *v as i64) }
205 else { write!(f, "{v:.4}") }
206 }
207 Expr::Neg(a) => write!(f, "(-{a})"),
208 Expr::Add(a, b) => write!(f, "({a} + {b})"),
209 Expr::Sub(a, b) => write!(f, "({a} - {b})"),
210 Expr::Mul(a, b) => write!(f, "({a} * {b})"),
211 Expr::Div(a, b) => write!(f, "({a} / {b})"),
212 Expr::Pow(a, b) => write!(f, "({a}^{b})"),
213 Expr::Sin(a) => write!(f, "sin({a})"),
214 Expr::Cos(a) => write!(f, "cos({a})"),
215 Expr::Tan(a) => write!(f, "tan({a})"),
216 Expr::Ln(a) => write!(f, "ln({a})"),
217 Expr::Exp(a) => write!(f, "exp({a})"),
218 Expr::Sqrt(a) => write!(f, "√({a})"),
219 Expr::Abs(a) => write!(f, "|{a}|"),
220 Expr::Floor(a) => write!(f, "⌊{a}⌋"),
221 Expr::Ceil(a) => write!(f, "⌈{a}⌉"),
222 Expr::Atan(a) => write!(f, "atan({a})"),
223 Expr::Atan2(y, x) => write!(f, "atan2({y}, {x})"),
224 Expr::Sum { body, var, from, to } => write!(f, "Σ({var}={from}..{to}){body}"),
225 Expr::Product { body, var, from, to } => write!(f, "Π({var}={from}..{to}){body}"),
226 Expr::Integral { body, var } => write!(f, "∫{body} d{var}"),
227 Expr::Derivative { body, var } => write!(f, "d/d{var}({body})"),
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a variable map from `(name, value)` pairs.
    fn env(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn eval_constant() {
        assert_eq!(Expr::c(42.0).eval(&HashMap::new()), 42.0);
    }

    #[test]
    fn eval_variable() {
        assert_eq!(Expr::var("x").eval(&env(&[("x", 3.0)])), 3.0);
    }

    #[test]
    fn eval_missing_variable_defaults_to_zero() {
        let e = Expr::var("y").add(Expr::c(2.0));
        assert_eq!(e.eval(&HashMap::new()), 2.0);
    }

    #[test]
    fn eval_arithmetic() {
        let e = Expr::var("x").add(Expr::c(1.0)).mul(Expr::c(2.0));
        assert_eq!(e.eval(&env(&[("x", 4.0)])), 10.0);
    }

    #[test]
    fn eval_trig() {
        let vars = HashMap::new();
        assert!(Expr::c(0.0).sin().eval(&vars).abs() < 1e-10);
        // sin(π/2) = 1 exercises a non-trivial value, unlike sin(0).
        let half_pi = Expr::pi().div(Expr::c(2.0));
        assert!((half_pi.sin().eval(&vars) - 1.0).abs() < 1e-10);
        assert!((Expr::c(0.0).cos().eval(&vars) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn eval_division_by_zero_is_nan() {
        let e = Expr::c(1.0).div(Expr::c(0.0));
        assert!(e.eval(&HashMap::new()).is_nan());
    }

    #[test]
    fn eval_sum() {
        let e = Expr::Sum {
            body: Box::new(Expr::var("i")),
            var: "i".to_string(),
            from: Box::new(Expr::c(1.0)),
            to: Box::new(Expr::c(3.0)),
        };
        assert_eq!(e.eval(&HashMap::new()), 6.0);
    }

    #[test]
    fn eval_sum_empty_range_is_zero() {
        let e = Expr::Sum {
            body: Box::new(Expr::var("i")),
            var: "i".to_string(),
            from: Box::new(Expr::c(3.0)),
            to: Box::new(Expr::c(1.0)),
        };
        assert_eq!(e.eval(&HashMap::new()), 0.0);
    }

    #[test]
    fn eval_product() {
        let e = Expr::Product {
            body: Box::new(Expr::var("i")),
            var: "i".to_string(),
            from: Box::new(Expr::c(1.0)),
            to: Box::new(Expr::c(4.0)),
        };
        assert_eq!(e.eval(&HashMap::new()), 24.0);
    }

    #[test]
    fn eval_symbolic_nodes_are_nan() {
        let vars = HashMap::new();
        let integral = Expr::Integral { body: Box::new(Expr::var("x")), var: "x".to_string() };
        let derivative = Expr::Derivative { body: Box::new(Expr::var("x")), var: "x".to_string() };
        assert!(integral.eval(&vars).is_nan());
        assert!(derivative.eval(&vars).is_nan());
    }

    #[test]
    fn contains_var_works() {
        let e = Expr::var("x").add(Expr::c(1.0));
        assert!(e.contains_var("x"));
        assert!(!e.contains_var("y"));
    }

    #[test]
    fn substitute_works() {
        let e = Expr::var("x").add(Expr::c(1.0));
        let replaced = e.substitute("x", &Expr::c(5.0));
        assert_eq!(replaced.eval(&HashMap::new()), 6.0);
    }

    #[test]
    fn node_count_counts_every_node() {
        assert_eq!(Expr::var("x").add(Expr::c(1.0)).node_count(), 3);
        assert_eq!(Expr::var("x").sin().node_count(), 2);
    }

    #[test]
    fn display_format() {
        // Pin the exact rendering so parenthesisation regressions are caught.
        let e = Expr::var("x").pow(Expr::c(2.0)).add(Expr::c(1.0));
        assert_eq!(format!("{e}"), "((x^2) + 1)");
        assert_eq!(Expr::c(0.5).to_string(), "0.5000");
        assert_eq!(Expr::var("x").sqrt().neg().to_string(), "(-√(x))");
    }
}