mathhook_core/simplify/arithmetic/
power.rs1use super::multiplication::simplify_multiplication;
4use super::Simplify;
5use crate::core::commutativity::Commutativity;
6use crate::core::{Expression, Number};
7use num_bigint::BigInt;
8use num_rational::BigRational;
9use std::sync::Arc;
10
11pub fn simplify_power(base: &Expression, exp: &Expression) -> Expression {
13 let simplified_base = base.simplify();
14 let simplified_exp = exp.simplify();
15
16 match (&simplified_base, &simplified_exp) {
17 (_, Expression::Number(Number::Integer(0))) => Expression::integer(1),
19 (_, Expression::Number(Number::Integer(1))) => simplified_base,
21 (Expression::Number(Number::Integer(1)), _) => Expression::integer(1),
23 (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(n)))
25 if *n > 0 =>
26 {
27 Expression::integer(0)
28 }
29 (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(-1))) => {
31 Expression::function("undefined", vec![])
32 }
33 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
35 if *n > 0 && *a != 0 =>
36 {
37 if let Some(result) = (*a).checked_pow(*n as u32) {
39 Expression::integer(result)
40 } else {
41 let base_big = BigInt::from(*a);
43 let result_big = base_big.pow(*n as u32);
44 Expression::Number(Number::rational(BigRational::new(
45 result_big,
46 BigInt::from(1),
47 )))
48 }
49 }
50 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(-1)))
52 if *a != 0 =>
53 {
54 Expression::Number(Number::rational(BigRational::new(
55 BigInt::from(1),
56 BigInt::from(*a),
57 )))
58 }
59 (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(-1))) => {
61 Expression::Number(Number::rational(BigRational::new(
62 r.denom().clone(),
63 r.numer().clone(),
64 )))
65 }
66 (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(n)))
68 if *n > 0 =>
69 {
70 let exp = *n as u32;
71 let numerator = r.numer().pow(exp);
72 let denominator = r.denom().pow(exp);
73 Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
74 }
75 (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
77 if *n < 0 && *a != 0 =>
78 {
79 let positive_exp = (-n) as u32;
80 let numerator = BigInt::from(1);
81 let denominator = BigInt::from(*a).pow(positive_exp);
82 Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
83 }
84 (Expression::Function { name, args }, Expression::Number(Number::Integer(2)))
86 if name.as_ref() == "sqrt" && args.len() == 1 =>
87 {
88 args[0].clone()
89 }
90 (Expression::Pow(b, e), c) => {
92 let new_exp = simplify_multiplication(&[e.as_ref().clone(), c.clone()]);
93 Expression::Pow(Arc::new(b.as_ref().clone()), Arc::new(new_exp))
94 }
95 (Expression::Mul(factors), Expression::Number(Number::Integer(n))) if *n > 0 => {
97 let commutativity = Commutativity::combine(factors.iter().map(|f| f.commutativity()));
98
99 if commutativity.can_sort() {
100 let powered_factors: Vec<Expression> = factors
101 .iter()
102 .map(|f| Expression::pow(f.clone(), simplified_exp.clone()))
103 .collect();
104 simplify_multiplication(&powered_factors)
105 } else {
106 Expression::Pow(Arc::new(simplified_base), Arc::new(simplified_exp))
107 }
108 }
109 _ => Expression::Pow(Arc::new(simplified_base), Arc::new(simplified_exp)),
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Builds the unsimplified product of `factors`, raises it to `exponent`, and runs the
    /// full `Simplify` pipeline on the resulting power.
    fn simplify_product_power(factors: Vec<Expression>, exponent: Expression) -> Expression {
        Expression::pow(Expression::mul(factors), exponent).simplify()
    }

    /// Reports whether `expr` is a `Pow` node whose base and exponent equal the given ones.
    fn is_power_of(expr: &Expression, base: &Expression, exponent: &Expression) -> bool {
        matches!(expr, Expression::Pow(b, e) if b.as_ref() == base && e.as_ref() == exponent)
    }

    #[test]
    fn test_power_simplification() {
        let x_expr = Expression::symbol(symbol!(x));

        // x^0 -> 1
        let zeroth = simplify_power(&x_expr, &Expression::integer(0));
        assert_eq!(zeroth, Expression::integer(1));

        // x^1 -> x
        let first = simplify_power(&x_expr, &Expression::integer(1));
        assert_eq!(first, x_expr);
    }

    #[test]
    fn test_scalar_power_distributed() {
        let x = symbol!(x);
        let y = symbol!(y);
        let two = Expression::integer(2);

        let simplified = simplify_product_power(
            vec![Expression::symbol(x), Expression::symbol(y)],
            two.clone(),
        );

        // Commuting scalars: (x*y)^2 should become x^2 * y^2.
        if let Expression::Mul(factors) = &simplified {
            assert_eq!(factors.len(), 2);
            let x_squared = factors
                .iter()
                .any(|f| is_power_of(f, &Expression::symbol(symbol!(x)), &two));
            let y_squared = factors
                .iter()
                .any(|f| is_power_of(f, &Expression::symbol(symbol!(y)), &two));
            assert!(x_squared, "Expected x^2 in factors");
            assert!(y_squared, "Expected y^2 in factors");
        } else {
            panic!("Expected Mul, got {:?}", simplified);
        }
    }

    #[test]
    fn test_matrix_power_not_distributed() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);

        let simplified = simplify_product_power(
            vec![Expression::symbol(a), Expression::symbol(b)],
            Expression::integer(2),
        );

        // Matrices do not commute, so the power must stay on the whole product.
        match &simplified {
            Expression::Pow(base, exp) => {
                assert_eq!(exp.as_ref(), &Expression::integer(2));
                if let Expression::Mul(factors) = base.as_ref() {
                    assert_eq!(factors.len(), 2);
                    assert!(factors.iter().all(|f| matches!(
                        f,
                        Expression::Symbol(s)
                            if s.symbol_type() == crate::core::symbol::SymbolType::Matrix
                    )));
                } else {
                    panic!("Expected Mul base, got {:?}", base);
                }
            }
            _ => panic!("Expected Pow, got {:?}", simplified),
        }
    }

    #[test]
    fn test_operator_power_not_distributed() {
        let p = symbol!(P; operator);
        let q = symbol!(Q; operator);

        let simplified = simplify_product_power(
            vec![Expression::symbol(p), Expression::symbol(q)],
            Expression::integer(2),
        );

        match &simplified {
            Expression::Pow(base, exp) => {
                assert_eq!(exp.as_ref(), &Expression::integer(2));
                match base.as_ref() {
                    Expression::Mul(factors) => assert_eq!(factors.len(), 2),
                    _ => panic!("Expected Mul base, got {:?}", base),
                }
            }
            _ => panic!("Expected Pow, got {:?}", simplified),
        }
    }

    #[test]
    fn test_quaternion_power_not_distributed() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);

        let simplified = simplify_product_power(
            vec![Expression::symbol(i), Expression::symbol(j)],
            Expression::integer(2),
        );

        match &simplified {
            Expression::Pow(_, exp) => assert_eq!(exp.as_ref(), &Expression::integer(2)),
            _ => panic!("Expected Pow, got {:?}", simplified),
        }
    }

    #[test]
    fn test_three_scalar_factors_power_distributed() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        let simplified = simplify_product_power(
            vec![
                Expression::symbol(x),
                Expression::symbol(y),
                Expression::symbol(z),
            ],
            Expression::integer(3),
        );

        match &simplified {
            Expression::Mul(factors) => assert_eq!(factors.len(), 3),
            _ => panic!("Expected Mul, got {:?}", simplified),
        }
    }

    #[test]
    fn test_mixed_scalar_matrix_power_not_distributed() {
        let x = symbol!(x);
        let a = symbol!(A; matrix);

        let simplified = simplify_product_power(
            vec![Expression::symbol(x), Expression::symbol(a)],
            Expression::integer(2),
        );

        // A single non-commuting factor is enough to block distribution.
        match &simplified {
            Expression::Pow(_, exp) => assert_eq!(exp.as_ref(), &Expression::integer(2)),
            _ => panic!("Expected Pow, got {:?}", simplified),
        }
    }

    #[test]
    fn test_numeric_power_distributed() {
        let simplified = simplify_product_power(
            vec![Expression::integer(2), Expression::integer(3)],
            Expression::integer(2),
        );

        assert_eq!(simplified, Expression::integer(36));
    }
}
291}