mathhook_core/algebra/simplification/
elementary.rs1use super::strategy::SimplificationStrategy;
6use crate::core::{Expression, Number};
7use num_bigint::BigInt;
8use num_traits::{ToPrimitive, Zero};
9
10pub struct SqrtSimplificationStrategy;
12
13impl SqrtSimplificationStrategy {
14 fn integer_sqrt(&self, n: &BigInt) -> Option<BigInt> {
15 if n < &BigInt::zero() {
16 return None;
17 }
18
19 if let Some(val) = n.to_i64() {
20 let sqrt_val = (val as f64).sqrt() as i64;
21 let sqrt_bigint = BigInt::from(sqrt_val);
22
23 if &(&sqrt_bigint * &sqrt_bigint) == n {
24 Some(sqrt_bigint)
25 } else {
26 None
27 }
28 } else {
29 None
30 }
31 }
32}
33
34impl SimplificationStrategy for SqrtSimplificationStrategy {
35 fn simplify(&self, args: &[Expression]) -> Expression {
36 if args.len() == 1 {
37 match &args[0] {
38 Expression::Number(Number::Integer(n)) => {
39 if n.is_zero() {
40 Expression::integer(0)
41 } else if *n == 1 {
42 Expression::integer(1)
43 } else if let Some(sqrt_val) = self.integer_sqrt(&BigInt::from(*n)) {
44 Expression::big_integer(sqrt_val)
45 } else {
46 Expression::function("sqrt", args.to_vec())
47 }
48 }
49
50 Expression::Pow(base, exp) => {
51 if exp.as_ref() == &Expression::integer(2) {
52 base.as_ref().clone()
53 } else {
54 Expression::function("sqrt", args.to_vec())
55 }
56 }
57
58 _ => Expression::function("sqrt", args.to_vec()),
59 }
60 } else {
61 Expression::function("sqrt", args.to_vec())
62 }
63 }
64
65 fn applies_to(&self, args: &[Expression]) -> bool {
66 args.len() == 1
67 }
68
69 fn name(&self) -> &str {
70 "SqrtSimplificationStrategy"
71 }
72}
73
74pub struct AbsSimplificationStrategy;
76
77impl SimplificationStrategy for AbsSimplificationStrategy {
78 fn simplify(&self, args: &[Expression]) -> Expression {
79 if args.len() == 1 {
80 match &args[0] {
81 Expression::Number(Number::Integer(n)) => Expression::integer(n.abs()),
82 Expression::Number(Number::Float(f)) => Expression::number(Number::float(f.abs())),
83 _ => Expression::function("abs", args.to_vec()),
84 }
85 } else {
86 Expression::function("abs", args.to_vec())
87 }
88 }
89
90 fn applies_to(&self, args: &[Expression]) -> bool {
91 args.len() == 1
92 }
93
94 fn name(&self) -> &str {
95 "AbsSimplificationStrategy"
96 }
97}
98
99pub struct ExpSimplificationStrategy;
101
102impl SimplificationStrategy for ExpSimplificationStrategy {
103 fn simplify(&self, args: &[Expression]) -> Expression {
104 if args.len() == 1 {
105 match &args[0] {
106 Expression::Number(Number::Integer(n)) if n.is_zero() => Expression::integer(1),
107
108 Expression::Function {
109 name,
110 args: inner_args,
111 } if name.as_ref() == "ln" && inner_args.len() == 1 => inner_args[0].clone(),
112
113 _ => Expression::function("exp", args.to_vec()),
114 }
115 } else {
116 Expression::function("exp", args.to_vec())
117 }
118 }
119
120 fn applies_to(&self, args: &[Expression]) -> bool {
121 args.len() == 1
122 }
123
124 fn name(&self) -> &str {
125 "ExpSimplificationStrategy"
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    #[test]
    fn test_sqrt_of_zero() {
        let strategy = SqrtSimplificationStrategy;
        let result = strategy.simplify(&[expr!(0)]);
        assert_eq!(result, expr!(0));
    }

    #[test]
    fn test_sqrt_of_one() {
        let strategy = SqrtSimplificationStrategy;
        let result = strategy.simplify(&[expr!(1)]);
        assert_eq!(result, expr!(1));
    }

    #[test]
    fn test_sqrt_of_four() {
        let strategy = SqrtSimplificationStrategy;
        let result = strategy.simplify(&[expr!(4)]);
        assert_eq!(result, expr!(2));
    }

    #[test]
    fn test_sqrt_of_nine() {
        let strategy = SqrtSimplificationStrategy;
        let result = strategy.simplify(&[expr!(9)]);
        assert_eq!(result, expr!(3));
    }

    #[test]
    fn test_sqrt_of_power() {
        let strategy = SqrtSimplificationStrategy;
        let x = symbol!(x);
        let result = strategy.simplify(&[expr!(x ^ 2)]);
        assert_eq!(result, x.into());
    }

    #[test]
    fn test_sqrt_of_non_square_is_unevaluated() {
        let strategy = SqrtSimplificationStrategy;
        let result = strategy.simplify(&[expr!(2)]);
        assert_eq!(result, Expression::function("sqrt", vec![expr!(2)]));
    }

    #[test]
    fn test_sqrt_of_negative_integer_is_unevaluated() {
        let strategy = SqrtSimplificationStrategy;
        let result = strategy.simplify(&[expr!(-4)]);
        assert_eq!(result, Expression::function("sqrt", vec![expr!(-4)]));
    }

    #[test]
    fn test_sqrt_arity() {
        let strategy = SqrtSimplificationStrategy;
        let no_args: Vec<Expression> = Vec::new();
        assert!(!strategy.applies_to(&no_args));
        assert!(strategy.applies_to(&[expr!(4)]));
        assert!(!strategy.applies_to(&[expr!(4), expr!(9)]));
        assert_eq!(
            strategy.simplify(&no_args),
            Expression::function("sqrt", no_args.clone())
        );
    }

    #[test]
    fn test_abs_of_positive_integer() {
        let strategy = AbsSimplificationStrategy;
        let result = strategy.simplify(&[expr!(5)]);
        assert_eq!(result, expr!(5));
    }

    #[test]
    fn test_abs_of_negative_integer() {
        let strategy = AbsSimplificationStrategy;
        let result = strategy.simplify(&[expr!(-5)]);
        assert_eq!(result, expr!(5));
    }

    #[test]
    fn test_abs_of_negative_float() {
        let strategy = AbsSimplificationStrategy;
        let result = strategy.simplify(&[Expression::number(Number::float(-2.5))]);
        assert_eq!(result, Expression::number(Number::float(2.5)));
    }

    #[test]
    fn test_abs_of_symbol_is_unevaluated() {
        let strategy = AbsSimplificationStrategy;
        let x: Expression = symbol!(x).into();
        let result = strategy.simplify(&[x.clone()]);
        assert_eq!(result, Expression::function("abs", vec![x]));
    }

    #[test]
    fn test_exp_of_zero() {
        let strategy = ExpSimplificationStrategy;
        let result = strategy.simplify(&[expr!(0)]);
        assert_eq!(result, expr!(1));
    }

    #[test]
    fn test_exp_of_ln() {
        let strategy = ExpSimplificationStrategy;
        let x = symbol!(x);
        let ln_x = Expression::function("ln", vec![x.clone().into()]);
        let result = strategy.simplify(&[ln_x]);
        assert_eq!(result, x.into());
    }

    #[test]
    fn test_exp_of_nonzero_integer_is_unevaluated() {
        let strategy = ExpSimplificationStrategy;
        let result = strategy.simplify(&[expr!(1)]);
        assert_eq!(result, Expression::function("exp", vec![expr!(1)]));
    }

    #[test]
    fn test_exp_of_symbol_is_unevaluated() {
        let strategy = ExpSimplificationStrategy;
        let x: Expression = symbol!(x).into();
        let result = strategy.simplify(&[x.clone()]);
        assert_eq!(result, Expression::function("exp", vec![x]));
    }

    #[test]
    fn test_strategy_names() {
        assert_eq!(SqrtSimplificationStrategy.name(), "SqrtSimplificationStrategy");
        assert_eq!(AbsSimplificationStrategy.name(), "AbsSimplificationStrategy");
        assert_eq!(ExpSimplificationStrategy.name(), "ExpSimplificationStrategy");
    }
}