mathhook_core/simplify/arithmetic/
power.rs1use super::multiplication::simplify_multiplication;
4use super::Simplify;
5use crate::core::commutativity::Commutativity;
6use crate::core::{Expression, Number};
7use num_bigint::BigInt;
8use num_rational::BigRational;
9
10pub fn simplify_power(base: &Expression, exp: &Expression) -> Expression {
12 let simplified_base = base.simplify();
14 let simplified_exp = exp.simplify();
15
16 match (&simplified_base, &simplified_exp) {
17 (_, Expression::Number(Number::Integer(0))) => Expression::integer(1),
19 (_, Expression::Number(Number::Integer(1))) => simplified_base,
21 (Expression::Number(Number::Integer(1)), _) => Expression::integer(1),
23 (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(n)))
25 if *n > 0 =>
26 {
27 Expression::integer(0)
28 }
29 (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(-1))) => {
31 Expression::function("undefined".to_owned(), vec![])
32 }
33 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
35 if *n > 0 && *a != 0 =>
36 {
37 if let Some(result) = (*a).checked_pow(*n as u32) {
39 Expression::integer(result)
40 } else {
41 let base_big = BigInt::from(*a);
43 let result_big = base_big.pow(*n as u32);
44 Expression::Number(Number::rational(BigRational::new(
45 result_big,
46 BigInt::from(1),
47 )))
48 }
49 }
50 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(-1)))
52 if *a != 0 =>
53 {
54 Expression::Number(Number::rational(BigRational::new(
55 BigInt::from(1),
56 BigInt::from(*a),
57 )))
58 }
59 (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(-1))) => {
61 Expression::Number(Number::rational(BigRational::new(
62 r.denom().clone(),
63 r.numer().clone(),
64 )))
65 }
66 (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(n)))
68 if *n > 0 =>
69 {
70 let exp = *n as u32;
71 let numerator = r.numer().pow(exp);
72 let denominator = r.denom().pow(exp);
73 Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
74 }
75 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
77 if *n < 0 && *a != 0 =>
78 {
79 let positive_exp = (-n) as u32;
80 let numerator = BigInt::from(1);
81 let denominator = BigInt::from(*a).pow(positive_exp);
82 Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
83 }
84 (Expression::Function { name, args }, Expression::Number(Number::Integer(2)))
86 if name == "sqrt" && args.len() == 1 =>
87 {
88 args[0].clone()
89 }
90 (Expression::Pow(b, e), c) => {
92 let new_exp = simplify_multiplication(&[(**e).clone(), c.clone()]);
93 Expression::Pow(Box::new((**b).clone()), Box::new(new_exp))
94 }
95 (Expression::Mul(factors), Expression::Number(Number::Integer(n))) if *n > 0 => {
97 let commutativity = Commutativity::combine(factors.iter().map(|f| f.commutativity()));
98
99 if commutativity.can_sort() {
100 let powered_factors: Vec<Expression> = factors
102 .iter()
103 .map(|f| Expression::pow(f.clone(), simplified_exp.clone()))
104 .collect();
105 simplify_multiplication(&powered_factors)
106 } else {
107 Expression::Pow(Box::new(simplified_base), Box::new(simplified_exp))
109 }
110 }
111 _ => Expression::Pow(Box::new(simplified_base), Box::new(simplified_exp)),
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Returns `true` when `candidate` is literally the power `base ^ exponent`.
    fn is_power_of(candidate: &Expression, base: &Expression, exponent: &Expression) -> bool {
        matches!(candidate, Expression::Pow(b, e) if **b == *base && **e == *exponent)
    }

    /// Panics unless `expr` is a `Pow` whose exponent equals `exponent`.
    /// Returns the base for further inspection.
    fn expect_pow_with_exponent<'a>(
        expr: &'a Expression,
        exponent: &Expression,
    ) -> &'a Expression {
        match expr {
            Expression::Pow(base, exp) => {
                assert_eq!(**exp, *exponent);
                &**base
            }
            other => panic!("Expected Pow, got {:?}", other),
        }
    }

    #[test]
    fn test_power_simplification() {
        let x_expr = Expression::symbol(symbol!(x));

        let zeroth = simplify_power(&x_expr, &Expression::integer(0));
        assert_eq!(zeroth, Expression::integer(1));

        let first = simplify_power(&x_expr, &Expression::integer(1));
        assert_eq!(first, x_expr);
    }

    #[test]
    fn test_scalar_power_distributed() {
        let x_expr = Expression::symbol(symbol!(x));
        let y_expr = Expression::symbol(symbol!(y));
        let two = Expression::integer(2);

        let product = Expression::mul(vec![x_expr.clone(), y_expr.clone()]);
        let simplified = Expression::pow(product, two.clone()).simplify();

        match &simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                assert!(
                    factors.iter().any(|f| is_power_of(f, &x_expr, &two)),
                    "Expected x^2 in factors"
                );
                assert!(
                    factors.iter().any(|f| is_power_of(f, &y_expr, &two)),
                    "Expected y^2 in factors"
                );
            }
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_matrix_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(A; matrix)),
            Expression::symbol(symbol!(B; matrix)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();

        match expect_pow_with_exponent(&simplified, &Expression::integer(2)) {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                let every_factor_is_matrix = factors.iter().all(|f| {
                    matches!(
                        f,
                        Expression::Symbol(s)
                            if s.symbol_type() == crate::core::symbol::SymbolType::Matrix
                    )
                });
                assert!(every_factor_is_matrix);
            }
            other => panic!("Expected Mul base, got {:?}", other),
        }
    }

    #[test]
    fn test_operator_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(P; operator)),
            Expression::symbol(symbol!(Q; operator)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();

        match expect_pow_with_exponent(&simplified, &Expression::integer(2)) {
            Expression::Mul(factors) => assert_eq!(factors.len(), 2),
            other => panic!("Expected Mul base, got {:?}", other),
        }
    }

    #[test]
    fn test_quaternion_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(i; quaternion)),
            Expression::symbol(symbol!(j; quaternion)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();

        expect_pow_with_exponent(&simplified, &Expression::integer(2));
    }

    #[test]
    fn test_three_scalar_factors_power_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(z)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(3)).simplify();

        match &simplified {
            Expression::Mul(factors) => assert_eq!(factors.len(), 3),
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_mixed_scalar_matrix_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(A; matrix)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();

        expect_pow_with_exponent(&simplified, &Expression::integer(2));
    }

    #[test]
    fn test_numeric_power_distributed() {
        let product = Expression::mul(vec![Expression::integer(2), Expression::integer(3)]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();

        assert_eq!(simplified, Expression::integer(36));
    }
}
293}