mathhook_core/algebra/simplification/elementary.rs

//! Elementary Function Simplification Strategies
//!
//! Implements algebraic rewrite rules for elementary functions (sqrt, abs, exp).

5use super::strategy::SimplificationStrategy;
6use crate::core::{Expression, Number};
7use num_bigint::BigInt;
8use num_traits::{ToPrimitive, Zero};
9
10/// Square root simplification strategy
11pub struct SqrtSimplificationStrategy;
12
13impl SqrtSimplificationStrategy {
14    fn integer_sqrt(&self, n: &BigInt) -> Option<BigInt> {
15        if n < &BigInt::zero() {
16            return None;
17        }
18
19        if let Some(val) = n.to_i64() {
20            let sqrt_val = (val as f64).sqrt() as i64;
21            let sqrt_bigint = BigInt::from(sqrt_val);
22
23            if &(&sqrt_bigint * &sqrt_bigint) == n {
24                Some(sqrt_bigint)
25            } else {
26                None
27            }
28        } else {
29            None
30        }
31    }
32}
33
34impl SimplificationStrategy for SqrtSimplificationStrategy {
35    fn simplify(&self, args: &[Expression]) -> Expression {
36        if args.len() == 1 {
37            match &args[0] {
38                Expression::Number(Number::Integer(n)) => {
39                    if n.is_zero() {
40                        Expression::integer(0)
41                    } else if *n == 1 {
42                        Expression::integer(1)
43                    } else if let Some(sqrt_val) = self.integer_sqrt(&BigInt::from(*n)) {
44                        Expression::big_integer(sqrt_val)
45                    } else {
46                        Expression::function("sqrt", args.to_vec())
47                    }
48                }
49
50                Expression::Pow(base, exp) => {
51                    if exp.as_ref() == &Expression::integer(2) {
52                        base.as_ref().clone()
53                    } else {
54                        Expression::function("sqrt", args.to_vec())
55                    }
56                }
57
58                _ => Expression::function("sqrt", args.to_vec()),
59            }
60        } else {
61            Expression::function("sqrt", args.to_vec())
62        }
63    }
64
65    fn applies_to(&self, args: &[Expression]) -> bool {
66        args.len() == 1
67    }
68
69    fn name(&self) -> &str {
70        "SqrtSimplificationStrategy"
71    }
72}
73
74/// Absolute value simplification strategy
75pub struct AbsSimplificationStrategy;
76
77impl SimplificationStrategy for AbsSimplificationStrategy {
78    fn simplify(&self, args: &[Expression]) -> Expression {
79        if args.len() == 1 {
80            match &args[0] {
81                Expression::Number(Number::Integer(n)) => Expression::integer(n.abs()),
82                Expression::Number(Number::Float(f)) => Expression::number(Number::float(f.abs())),
83                _ => Expression::function("abs", args.to_vec()),
84            }
85        } else {
86            Expression::function("abs", args.to_vec())
87        }
88    }
89
90    fn applies_to(&self, args: &[Expression]) -> bool {
91        args.len() == 1
92    }
93
94    fn name(&self) -> &str {
95        "AbsSimplificationStrategy"
96    }
97}
98
99/// Exponential function simplification strategy
100pub struct ExpSimplificationStrategy;
101
102impl SimplificationStrategy for ExpSimplificationStrategy {
103    fn simplify(&self, args: &[Expression]) -> Expression {
104        if args.len() == 1 {
105            match &args[0] {
106                Expression::Number(Number::Integer(n)) if n.is_zero() => Expression::integer(1),
107
108                Expression::Function {
109                    name,
110                    args: inner_args,
111                } if name == "ln" && inner_args.len() == 1 => inner_args[0].clone(),
112
113                _ => Expression::function("exp", args.to_vec()),
114            }
115        } else {
116            Expression::function("exp", args.to_vec())
117        }
118    }
119
120    fn applies_to(&self, args: &[Expression]) -> bool {
121        args.len() == 1
122    }
123
124    fn name(&self) -> &str {
125        "ExpSimplificationStrategy"
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    // --- sqrt --------------------------------------------------------------

    #[test]
    fn test_sqrt_of_zero() {
        let simplified = SqrtSimplificationStrategy.simplify(&[expr!(0)]);
        assert_eq!(simplified, expr!(0));
    }

    #[test]
    fn test_sqrt_of_one() {
        let simplified = SqrtSimplificationStrategy.simplify(&[expr!(1)]);
        assert_eq!(simplified, expr!(1));
    }

    #[test]
    fn test_sqrt_of_four() {
        let simplified = SqrtSimplificationStrategy.simplify(&[expr!(4)]);
        assert_eq!(simplified, expr!(2));
    }

    #[test]
    fn test_sqrt_of_nine() {
        let simplified = SqrtSimplificationStrategy.simplify(&[expr!(9)]);
        assert_eq!(simplified, expr!(3));
    }

    #[test]
    fn test_sqrt_of_power() {
        let x = symbol!(x);
        let simplified = SqrtSimplificationStrategy.simplify(&[expr!(x ^ 2)]);
        assert_eq!(simplified, x.into());
    }

    // --- abs ---------------------------------------------------------------

    #[test]
    fn test_abs_of_positive_integer() {
        let simplified = AbsSimplificationStrategy.simplify(&[expr!(5)]);
        assert_eq!(simplified, expr!(5));
    }

    #[test]
    fn test_abs_of_negative_integer() {
        let simplified = AbsSimplificationStrategy.simplify(&[expr!(-5)]);
        assert_eq!(simplified, expr!(5));
    }

    // --- exp ---------------------------------------------------------------

    #[test]
    fn test_exp_of_zero() {
        let simplified = ExpSimplificationStrategy.simplify(&[expr!(0)]);
        assert_eq!(simplified, expr!(1));
    }

    #[test]
    fn test_exp_of_ln() {
        let x = symbol!(x);
        let ln_x = Expression::function("ln", vec![x.clone().into()]);
        let simplified = ExpSimplificationStrategy.simplify(&[ln_x]);
        assert_eq!(simplified, x.into());
    }
}