// scivex_sym/expr.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::ops;
4
5use crate::error::{Result, SymError};
6
/// Built-in mathematical functions.
///
/// Each variant is a unary function applied to a sub-expression through
/// `Expr::Fn`. Its `Display` form is the lowercase function name.
///
/// # Examples
///
/// ```
/// # use scivex_sym::MathFn;
/// let f = MathFn::Sin;
/// assert_eq!(format!("{f}"), "sin");
/// ```
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MathFn {
    /// Sine; evaluated with `f64::sin`, so the argument is in radians.
    Sin,
    /// Cosine; evaluated with `f64::cos`, so the argument is in radians.
    Cos,
    /// Tangent; evaluated with `f64::tan`, so the argument is in radians.
    Tan,
    /// Natural exponential `e^x`.
    Exp,
    /// Natural logarithm.
    Ln,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
}

31impl fmt::Display for MathFn {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        let name = match self {
34            Self::Sin => "sin",
35            Self::Cos => "cos",
36            Self::Tan => "tan",
37            Self::Exp => "exp",
38            Self::Ln => "ln",
39            Self::Sqrt => "sqrt",
40            Self::Abs => "abs",
41        };
42        f.write_str(name)
43    }
44}
45
46/// A symbolic expression AST.
47///
48/// `Sub` is represented as `Add(a, Neg(b))` and `Div` as `Mul(a, Pow(b, Const(-1)))`.
49///
50/// # Examples
51///
52/// ```
53/// # use scivex_sym::{var, constant};
54/// # use std::collections::HashMap;
55/// let expr = var("x") + constant(1.0);
56/// let mut vars = HashMap::new();
57/// vars.insert("x".to_string(), 2.0);
58/// assert!((expr.eval(&vars).unwrap() - 3.0).abs() < 1e-10);
59/// ```
60#[cfg_attr(
61    feature = "serde-support",
62    derive(serde::Serialize, serde::Deserialize)
63)]
64#[derive(Debug, Clone, PartialEq)]
65pub enum Expr {
66    /// Numeric constant.
67    Const(f64),
68    /// Named variable.
69    Var(String),
70    /// Addition: `lhs + rhs`.
71    Add(Box<Expr>, Box<Expr>),
72    /// Multiplication: `lhs * rhs`.
73    Mul(Box<Expr>, Box<Expr>),
74    /// Exponentiation: `base ^ exp`.
75    Pow(Box<Expr>, Box<Expr>),
76    /// Negation: `-expr`.
77    Neg(Box<Expr>),
78    /// Function application: `f(arg)`.
79    Fn(MathFn, Box<Expr>),
80}
81
82// ---------------------------------------------------------------------------
83// Constructors
84// ---------------------------------------------------------------------------
85
86/// Create a constant expression.
87///
88/// # Examples
89///
90/// ```
91/// # use scivex_sym::constant;
92/// # use std::collections::HashMap;
93/// let five = constant(5.0);
94/// let val = five.eval(&HashMap::new()).unwrap();
95/// assert!((val - 5.0).abs() < 1e-10);
96/// ```
97#[must_use]
98pub fn constant(v: f64) -> Expr {
99    Expr::Const(v)
100}
101
102/// Create a variable expression.
103///
104/// # Examples
105///
106/// ```
107/// # use scivex_sym::var;
108/// # use std::collections::HashMap;
109/// let x = var("x");
110/// let mut vars = HashMap::new();
111/// vars.insert("x".to_string(), 7.0);
112/// assert!((x.eval(&vars).unwrap() - 7.0).abs() < 1e-10);
113/// ```
114#[must_use]
115pub fn var(name: &str) -> Expr {
116    Expr::Var(name.to_owned())
117}
118
119/// The additive identity.
120///
121/// # Examples
122///
123/// ```
124/// # use scivex_sym::zero;
125/// assert!(zero().is_zero());
126/// ```
127#[must_use]
128pub fn zero() -> Expr {
129    Expr::Const(0.0)
130}
131
132/// The multiplicative identity.
133///
134/// # Examples
135///
136/// ```
137/// # use scivex_sym::one;
138/// assert!(one().is_one());
139/// ```
140#[must_use]
141pub fn one() -> Expr {
142    Expr::Const(1.0)
143}
144
145/// The constant pi.
146///
147/// # Examples
148///
149/// ```
150/// # use scivex_sym::pi;
151/// assert!(pi().as_const().unwrap() > 3.14);
152/// ```
153#[must_use]
154pub fn pi() -> Expr {
155    Expr::Const(std::f64::consts::PI)
156}
157
158/// The constant e.
159///
160/// # Examples
161///
162/// ```
163/// # use scivex_sym::e;
164/// assert!(e().as_const().unwrap() > 2.71);
165/// ```
166#[must_use]
167pub fn e() -> Expr {
168    Expr::Const(std::f64::consts::E)
169}
170
171/// `sin(expr)`
172///
173/// # Examples
174///
175/// ```
176/// # use scivex_sym::{sin, pi};
177/// # use std::collections::HashMap;
178/// let val = sin(pi()).eval(&HashMap::new()).unwrap();
179/// assert!(val.abs() < 1e-10);
180/// ```
181#[must_use]
182pub fn sin(expr: Expr) -> Expr {
183    Expr::Fn(MathFn::Sin, Box::new(expr))
184}
185
186/// `cos(expr)`
187///
188/// # Examples
189///
190/// ```
191/// # use scivex_sym::{cos, constant};
192/// # use std::collections::HashMap;
193/// let val = cos(constant(0.0)).eval(&HashMap::new()).unwrap();
194/// assert!((val - 1.0).abs() < 1e-10);
195/// ```
196#[must_use]
197pub fn cos(expr: Expr) -> Expr {
198    Expr::Fn(MathFn::Cos, Box::new(expr))
199}
200
201/// `tan(expr)`
202///
203/// # Examples
204///
205/// ```
206/// # use scivex_sym::{tan, constant};
207/// # use std::collections::HashMap;
208/// let val = tan(constant(0.0)).eval(&HashMap::new()).unwrap();
209/// assert!(val.abs() < 1e-10);
210/// ```
211#[must_use]
212pub fn tan(expr: Expr) -> Expr {
213    Expr::Fn(MathFn::Tan, Box::new(expr))
214}
215
216/// `exp(expr)`
217///
218/// # Examples
219///
220/// ```
221/// # use scivex_sym::{exp, constant};
222/// # use std::collections::HashMap;
223/// let val = exp(constant(0.0)).eval(&HashMap::new()).unwrap();
224/// assert!((val - 1.0).abs() < 1e-10);
225/// ```
226#[must_use]
227pub fn exp(expr: Expr) -> Expr {
228    Expr::Fn(MathFn::Exp, Box::new(expr))
229}
230
231/// `ln(expr)`
232///
233/// # Examples
234///
235/// ```
236/// # use scivex_sym::{ln, e};
237/// # use std::collections::HashMap;
238/// let val = ln(e()).eval(&HashMap::new()).unwrap();
239/// assert!((val - 1.0).abs() < 1e-10);
240/// ```
241#[must_use]
242pub fn ln(expr: Expr) -> Expr {
243    Expr::Fn(MathFn::Ln, Box::new(expr))
244}
245
246/// `sqrt(expr)`
247///
248/// # Examples
249///
250/// ```
251/// # use scivex_sym::{sqrt, constant};
252/// # use std::collections::HashMap;
253/// let val = sqrt(constant(4.0)).eval(&HashMap::new()).unwrap();
254/// assert!((val - 2.0).abs() < 1e-10);
255/// ```
256#[must_use]
257pub fn sqrt(expr: Expr) -> Expr {
258    Expr::Fn(MathFn::Sqrt, Box::new(expr))
259}
260
261/// `|expr|`
262///
263/// # Examples
264///
265/// ```
266/// # use scivex_sym::{abs, constant};
267/// # use std::collections::HashMap;
268/// let val = abs(constant(-3.0)).eval(&HashMap::new()).unwrap();
269/// assert!((val - 3.0).abs() < 1e-10);
270/// ```
271#[must_use]
272pub fn abs(expr: Expr) -> Expr {
273    Expr::Fn(MathFn::Abs, Box::new(expr))
274}
275
276// ---------------------------------------------------------------------------
277// Core methods
278// ---------------------------------------------------------------------------
279
280impl Expr {
281    /// Evaluate the expression given concrete variable bindings.
282    ///
283    /// # Examples
284    ///
285    /// ```
286    /// # use scivex_sym::expr::{var, constant};
287    /// # use std::collections::HashMap;
288    /// let expr = constant(2.0) * var("x") + constant(1.0);
289    /// let vars = HashMap::from([("x".to_string(), 3.0)]);
290    /// assert!((expr.eval(&vars).unwrap() - 7.0).abs() < 1e-10);
291    /// ```
292    pub fn eval(&self, vars: &HashMap<String, f64>) -> Result<f64> {
293        match self {
294            Self::Const(v) => Ok(*v),
295            Self::Var(name) => vars
296                .get(name)
297                .copied()
298                .ok_or_else(|| SymError::UndefinedVariable { name: name.clone() }),
299            Self::Add(a, b) => Ok(a.eval(vars)? + b.eval(vars)?),
300            Self::Mul(a, b) => {
301                let av = a.eval(vars)?;
302                let bv = b.eval(vars)?;
303                Ok(av * bv)
304            }
305            Self::Pow(base, exp) => {
306                let bv = base.eval(vars)?;
307                let ev = exp.eval(vars)?;
308                // Check for 0^negative (division by zero).
309                if bv == 0.0 && ev < 0.0 {
310                    return Err(SymError::DivisionByZero);
311                }
312                Ok(bv.powf(ev))
313            }
314            Self::Neg(inner) => Ok(-inner.eval(vars)?),
315            Self::Fn(func, arg) => {
316                let v = arg.eval(vars)?;
317                Ok(match func {
318                    MathFn::Sin => v.sin(),
319                    MathFn::Cos => v.cos(),
320                    MathFn::Tan => v.tan(),
321                    MathFn::Exp => v.exp(),
322                    MathFn::Ln => v.ln(),
323                    MathFn::Sqrt => v.sqrt(),
324                    MathFn::Abs => v.abs(),
325                })
326            }
327        }
328    }
329
330    /// Replace every occurrence of `var` with `replacement`.
331    ///
332    /// # Examples
333    ///
334    /// ```
335    /// # use scivex_sym::expr::{var, constant};
336    /// # use std::collections::HashMap;
337    /// let expr = var("x") + constant(1.0);
338    /// let replaced = expr.substitute("x", &constant(5.0));
339    /// assert!((replaced.eval(&HashMap::new()).unwrap() - 6.0).abs() < 1e-10);
340    /// ```
341    #[must_use]
342    pub fn substitute(&self, var: &str, replacement: &Expr) -> Expr {
343        match self {
344            Self::Const(_) => self.clone(),
345            Self::Var(name) => {
346                if name == var {
347                    replacement.clone()
348                } else {
349                    self.clone()
350                }
351            }
352            Self::Add(a, b) => Expr::Add(
353                Box::new(a.substitute(var, replacement)),
354                Box::new(b.substitute(var, replacement)),
355            ),
356            Self::Mul(a, b) => Expr::Mul(
357                Box::new(a.substitute(var, replacement)),
358                Box::new(b.substitute(var, replacement)),
359            ),
360            Self::Pow(base, exp) => Expr::Pow(
361                Box::new(base.substitute(var, replacement)),
362                Box::new(exp.substitute(var, replacement)),
363            ),
364            Self::Neg(inner) => Expr::Neg(Box::new(inner.substitute(var, replacement))),
365            Self::Fn(func, arg) => Expr::Fn(*func, Box::new(arg.substitute(var, replacement))),
366        }
367    }
368
369    /// Collect all free variable names in the expression.
370    ///
371    /// # Examples
372    ///
373    /// ```
374    /// # use scivex_sym::expr::{var, constant};
375    /// let expr = var("x") + var("y") * constant(2.0);
376    /// let vars = expr.free_variables();
377    /// assert!(vars.contains("x"));
378    /// assert!(vars.contains("y"));
379    /// assert_eq!(vars.len(), 2);
380    /// ```
381    #[must_use]
382    pub fn free_variables(&self) -> HashSet<String> {
383        let mut set = HashSet::new();
384        self.collect_vars(&mut set);
385        set
386    }
387
388    fn collect_vars(&self, set: &mut HashSet<String>) {
389        match self {
390            Self::Const(_) => {}
391            Self::Var(name) => {
392                set.insert(name.clone());
393            }
394            Self::Add(a, b) | Self::Mul(a, b) | Self::Pow(a, b) => {
395                a.collect_vars(set);
396                b.collect_vars(set);
397            }
398            Self::Neg(inner) | Self::Fn(_, inner) => inner.collect_vars(set),
399        }
400    }
401
402    /// Returns `true` if the expression is `Const(0.0)`.
403    ///
404    /// # Examples
405    ///
406    /// ```
407    /// # use scivex_sym::expr::constant;
408    /// assert!(constant(0.0).is_zero());
409    /// assert!(!constant(1.0).is_zero());
410    /// ```
411    #[must_use]
412    pub fn is_zero(&self) -> bool {
413        matches!(self, Self::Const(v) if *v == 0.0)
414    }
415
416    /// Returns `true` if the expression is `Const(1.0)`.
417    ///
418    /// # Examples
419    ///
420    /// ```
421    /// # use scivex_sym::expr::constant;
422    /// assert!(constant(1.0).is_one());
423    /// assert!(!constant(2.0).is_one());
424    /// ```
425    #[must_use]
426    pub fn is_one(&self) -> bool {
427        matches!(self, Self::Const(v) if (*v - 1.0).abs() < f64::EPSILON)
428    }
429
430    /// Returns `true` if the expression is a constant.
431    ///
432    /// # Examples
433    ///
434    /// ```
435    /// # use scivex_sym::expr::{constant, var};
436    /// assert!(constant(3.14).is_const());
437    /// assert!(!var("x").is_const());
438    /// ```
439    #[must_use]
440    pub fn is_const(&self) -> bool {
441        matches!(self, Self::Const(_))
442    }
443
444    /// If the expression is a constant, return its value.
445    ///
446    /// # Examples
447    ///
448    /// ```
449    /// # use scivex_sym::expr::{constant, var};
450    /// assert_eq!(constant(42.0).as_const(), Some(42.0));
451    /// assert_eq!(var("x").as_const(), None);
452    /// ```
453    #[must_use]
454    pub fn as_const(&self) -> Option<f64> {
455        match self {
456            Self::Const(v) => Some(*v),
457            _ => None,
458        }
459    }
460}
461
462// ---------------------------------------------------------------------------
463// Operator overloading
464// ---------------------------------------------------------------------------
465
466impl ops::Add for Expr {
467    type Output = Self;
468    fn add(self, rhs: Self) -> Self {
469        Expr::Add(Box::new(self), Box::new(rhs))
470    }
471}
472
473impl ops::Sub for Expr {
474    type Output = Self;
475    fn sub(self, rhs: Self) -> Self {
476        Expr::Add(Box::new(self), Box::new(Expr::Neg(Box::new(rhs))))
477    }
478}
479
480impl ops::Mul for Expr {
481    type Output = Self;
482    fn mul(self, rhs: Self) -> Self {
483        Expr::Mul(Box::new(self), Box::new(rhs))
484    }
485}
486
487impl ops::Div for Expr {
488    type Output = Self;
489    fn div(self, rhs: Self) -> Self {
490        Expr::Mul(
491            Box::new(self),
492            Box::new(Expr::Pow(Box::new(rhs), Box::new(Expr::Const(-1.0)))),
493        )
494    }
495}
496
497impl ops::Neg for Expr {
498    type Output = Self;
499    fn neg(self) -> Self {
500        Expr::Neg(Box::new(self))
501    }
502}
503
504// ---------------------------------------------------------------------------
505// Display
506// ---------------------------------------------------------------------------
507
508impl fmt::Display for Expr {
509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510        match self {
511            Self::Const(v) => {
512                if (*v - std::f64::consts::PI).abs() < f64::EPSILON {
513                    write!(f, "pi")
514                } else if *v < 0.0 {
515                    write!(f, "({v})")
516                } else {
517                    write!(f, "{v}")
518                }
519            }
520            Self::Var(name) => f.write_str(name),
521            Self::Add(a, b) => write!(f, "({a} + {b})"),
522            Self::Mul(a, b) => write!(f, "({a} * {b})"),
523            Self::Pow(base, exp) => write!(f, "({base}^{exp})"),
524            Self::Neg(inner) => write!(f, "(-{inner})"),
525            Self::Fn(func, arg) => write!(f, "{func}({arg})"),
526        }
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with the same `f64::EPSILON` tolerance used throughout.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn eval_const_and_var() {
        let answer = constant(42.0).eval(&HashMap::new()).unwrap();
        assert!(close(answer, 42.0));

        let bindings = HashMap::from([("x".to_string(), 3.0)]);
        assert!(close(var("x").eval(&bindings).unwrap(), 3.0));
    }

    #[test]
    fn eval_undefined_variable() {
        let err = var("x").eval(&HashMap::new()).unwrap_err();
        assert!(matches!(err, SymError::UndefinedVariable { name } if name == "x"));
    }

    #[test]
    fn eval_division_by_zero() {
        // `1 / 0` is encoded as `1 * 0^(-1)`, which must be rejected.
        let quotient = constant(1.0) / constant(0.0);
        let err = quotient.eval(&HashMap::new()).unwrap_err();
        assert!(matches!(err, SymError::DivisionByZero));
    }

    #[test]
    fn eval_arithmetic() {
        // (x + 3) * 4 at x = 2 is 20.
        let bindings = HashMap::from([("x".to_string(), 2.0)]);
        let product = (var("x") + constant(3.0)) * constant(4.0);
        assert!(close(product.eval(&bindings).unwrap(), 20.0));
    }

    #[test]
    fn eval_functions() {
        let bindings = HashMap::new();
        let cases = [
            (sin(constant(0.0)), 0.0),
            (cos(constant(0.0)), 1.0),
            (exp(constant(0.0)), 1.0),
            (ln(constant(1.0)), 0.0),
        ];
        for (expr, expected) in cases {
            assert!(
                close(expr.eval(&bindings).unwrap(), expected),
                "{expr} should evaluate to {expected}"
            );
        }
    }

    #[test]
    fn substitute_works() {
        let shifted = var("x") + constant(1.0);
        let replaced = shifted.substitute("x", &constant(5.0));
        assert!(close(replaced.eval(&HashMap::new()).unwrap(), 6.0));
    }

    #[test]
    fn free_variables_collected() {
        let expr = var("x") * var("y") + sin(var("x"));
        let names = expr.free_variables();
        assert_eq!(names.len(), 2);
        assert!(["x", "y"].iter().all(|n| names.contains(*n)));
    }

    #[test]
    fn display_formatting() {
        let expr = var("x") + constant(1.0);
        assert_eq!(expr.to_string(), "(x + 1)");
    }

    #[test]
    fn is_predicates() {
        assert!(zero().is_zero());
        assert!(one().is_one());

        assert!(constant(3.5).is_const());
        assert!(!var("x").is_const());

        assert_eq!(constant(7.0).as_const(), Some(7.0));
        assert_eq!(var("x").as_const(), None);
    }
}