mathhook_core/simplify/arithmetic/
power.rs

1//! Power simplification operations
2
3use super::multiplication::simplify_multiplication;
4use super::Simplify;
5use crate::core::commutativity::Commutativity;
6use crate::core::{Expression, Number};
7use num_bigint::BigInt;
8use num_rational::BigRational;
9
10/// Power simplification
11pub fn simplify_power(base: &Expression, exp: &Expression) -> Expression {
12    // First, simplify both base and exponent for better pattern matching
13    let simplified_base = base.simplify();
14    let simplified_exp = exp.simplify();
15
16    match (&simplified_base, &simplified_exp) {
17        // x^0 = 1
18        (_, Expression::Number(Number::Integer(0))) => Expression::integer(1),
19        // x^1 = x (use already simplified base)
20        (_, Expression::Number(Number::Integer(1))) => simplified_base,
21        // 1^x = 1
22        (Expression::Number(Number::Integer(1)), _) => Expression::integer(1),
23        // 0^x = 0 (for x > 0)
24        (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(n)))
25            if *n > 0 =>
26        {
27            Expression::integer(0)
28        }
29        // 0^(-1) = undefined (division by zero)
30        (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(-1))) => {
31            Expression::function("undefined".to_owned(), vec![])
32        }
33        // a^n = a^n for positive integers a and n (compute the power)
34        (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
35            if *n > 0 && *a != 0 =>
36        {
37            // Use checked_pow to prevent overflow, promote to BigInt on overflow
38            if let Some(result) = (*a).checked_pow(*n as u32) {
39                Expression::integer(result)
40            } else {
41                // Overflow - use BigInt for arbitrary precision
42                let base_big = BigInt::from(*a);
43                let result_big = base_big.pow(*n as u32);
44                Expression::Number(Number::rational(BigRational::new(
45                    result_big,
46                    BigInt::from(1),
47                )))
48            }
49        }
50        // a^(-1) = 1/a (convert to rational for integers)
51        (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(-1)))
52            if *a != 0 =>
53        {
54            Expression::Number(Number::rational(BigRational::new(
55                BigInt::from(1),
56                BigInt::from(*a),
57            )))
58        }
59        // (a/b)^(-1) = b/a (reciprocal of rational)
60        (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(-1))) => {
61            Expression::Number(Number::rational(BigRational::new(
62                r.denom().clone(),
63                r.numer().clone(),
64            )))
65        }
66        // (a/b)^n = a^n/b^n for positive integers n
67        (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(n)))
68            if *n > 0 =>
69        {
70            let exp = *n as u32;
71            let numerator = r.numer().pow(exp);
72            let denominator = r.denom().pow(exp);
73            Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
74        }
75        // a^(-n) = 1/(a^n) for positive integers a and n
76        (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
77            if *n < 0 && *a != 0 =>
78        {
79            let positive_exp = (-n) as u32;
80            let numerator = BigInt::from(1);
81            let denominator = BigInt::from(*a).pow(positive_exp);
82            Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
83        }
84        // sqrt(x)^2 = x (inverse function)
85        (Expression::Function { name, args }, Expression::Number(Number::Integer(2)))
86            if name == "sqrt" && args.len() == 1 =>
87        {
88            args[0].clone()
89        }
90        // (a^b)^c = a^(b*c)
91        (Expression::Pow(b, e), c) => {
92            let new_exp = simplify_multiplication(&[(**e).clone(), c.clone()]);
93            Expression::Pow(Box::new((**b).clone()), Box::new(new_exp))
94        }
95        // (a*b)^n = a^n * b^n ONLY if commutative
96        (Expression::Mul(factors), Expression::Number(Number::Integer(n))) if *n > 0 => {
97            let commutativity = Commutativity::combine(factors.iter().map(|f| f.commutativity()));
98
99            if commutativity.can_sort() {
100                // Safe to distribute - all factors commutative
101                let powered_factors: Vec<Expression> = factors
102                    .iter()
103                    .map(|f| Expression::pow(f.clone(), simplified_exp.clone()))
104                    .collect();
105                simplify_multiplication(&powered_factors)
106            } else {
107                // Noncommutative - keep as (a*b)^n
108                Expression::Pow(Box::new(simplified_base), Box::new(simplified_exp))
109            }
110        }
111        _ => Expression::Pow(Box::new(simplified_base), Box::new(simplified_exp)),
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Builds `(f_1 * ... * f_k)^exponent` and runs the full simplifier on it.
    fn simplify_product_power(factors: Vec<Expression>, exponent: Expression) -> Expression {
        Expression::pow(Expression::mul(factors), exponent).simplify()
    }

    /// Asserts that `expr` is an unevaluated power with exponent 2 and returns its base.
    fn expect_square(expr: &Expression) -> &Expression {
        match expr {
            Expression::Pow(base, exp) => {
                assert_eq!(**exp, Expression::integer(2));
                &**base
            }
            other => panic!("Expected Pow, got {:?}", other),
        }
    }

    #[test]
    fn test_power_simplification() {
        let x = Expression::symbol(symbol!(x));

        // x^0 = 1
        assert_eq!(
            simplify_power(&x, &Expression::integer(0)),
            Expression::integer(1)
        );

        // x^1 = x
        assert_eq!(simplify_power(&x, &Expression::integer(1)), x);
    }

    #[test]
    fn test_scalar_power_distributed() {
        let two = Expression::integer(2);
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));

        let simplified = simplify_product_power(vec![x.clone(), y.clone()], two.clone());

        match &simplified {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                let has_square_of = |target: &Expression| {
                    factors.iter().any(|factor| {
                        matches!(factor, Expression::Pow(b, e) if **b == *target && **e == two)
                    })
                };
                assert!(has_square_of(&x), "Expected x^2 in factors");
                assert!(has_square_of(&y), "Expected y^2 in factors");
            }
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_matrix_power_not_distributed() {
        let a = Expression::symbol(symbol!(A; matrix));
        let b = Expression::symbol(symbol!(B; matrix));

        let simplified = simplify_product_power(vec![a, b], Expression::integer(2));

        match expect_square(&simplified) {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                let all_matrices = factors.iter().all(|factor| {
                    matches!(
                        factor,
                        Expression::Symbol(s)
                            if s.symbol_type() == crate::core::symbol::SymbolType::Matrix
                    )
                });
                assert!(all_matrices);
            }
            other => panic!("Expected Mul base, got {:?}", other),
        }
    }

    #[test]
    fn test_operator_power_not_distributed() {
        let op_p = Expression::symbol(symbol!(P; operator));
        let op_q = Expression::symbol(symbol!(Q; operator));

        let simplified = simplify_product_power(vec![op_p, op_q], Expression::integer(2));

        match expect_square(&simplified) {
            Expression::Mul(factors) => assert_eq!(factors.len(), 2),
            other => panic!("Expected Mul base, got {:?}", other),
        }
    }

    #[test]
    fn test_quaternion_power_not_distributed() {
        let qi = Expression::symbol(symbol!(i; quaternion));
        let qj = Expression::symbol(symbol!(j; quaternion));

        let simplified = simplify_product_power(vec![qi, qj], Expression::integer(2));

        expect_square(&simplified);
    }

    #[test]
    fn test_three_scalar_factors_power_distributed() {
        let factors = vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(z)),
        ];

        let simplified = simplify_product_power(factors, Expression::integer(3));

        match &simplified {
            Expression::Mul(powered) => assert_eq!(powered.len(), 3),
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_mixed_scalar_matrix_power_not_distributed() {
        let scalar = Expression::symbol(symbol!(x));
        let matrix = Expression::symbol(symbol!(A; matrix));

        let simplified = simplify_product_power(vec![scalar, matrix], Expression::integer(2));

        expect_square(&simplified);
    }

    #[test]
    fn test_numeric_power_distributed() {
        let simplified = simplify_product_power(
            vec![Expression::integer(2), Expression::integer(3)],
            Expression::integer(2),
        );

        assert_eq!(simplified, Expression::integer(36));
    }
}