mathhook_core/simplify/arithmetic/
power.rs

1//! Power simplification operations
2
3use super::multiplication::simplify_multiplication;
4use super::Simplify;
5use crate::core::commutativity::Commutativity;
6use crate::core::{Expression, Number};
7use num_bigint::BigInt;
8use num_rational::BigRational;
9use std::sync::Arc;
10
11/// Power simplification
12pub fn simplify_power(base: &Expression, exp: &Expression) -> Expression {
13    let simplified_base = base.simplify();
14    let simplified_exp = exp.simplify();
15
16    match (&simplified_base, &simplified_exp) {
17        // x^0 = 1
18        (_, Expression::Number(Number::Integer(0))) => Expression::integer(1),
19        // x^1 = x (use already simplified base)
20        (_, Expression::Number(Number::Integer(1))) => simplified_base,
21        // 1^x = 1
22        (Expression::Number(Number::Integer(1)), _) => Expression::integer(1),
23        // 0^x = 0 (for x > 0)
24        (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(n)))
25            if *n > 0 =>
26        {
27            Expression::integer(0)
28        }
29        // 0^(-1) = undefined (division by zero)
30        (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(-1))) => {
31            Expression::function("undefined", vec![])
32        }
33        // a^n = a^n for positive integers a and n (compute the power)
34        (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
35            if *n > 0 && *a != 0 =>
36        {
37            // Use checked_pow to prevent overflow, promote to BigInt on overflow
38            if let Some(result) = (*a).checked_pow(*n as u32) {
39                Expression::integer(result)
40            } else {
41                // Overflow - use BigInt for arbitrary precision
42                let base_big = BigInt::from(*a);
43                let result_big = base_big.pow(*n as u32);
44                Expression::Number(Number::rational(BigRational::new(
45                    result_big,
46                    BigInt::from(1),
47                )))
48            }
49        }
50        // a^(-1) = 1/a (convert to rational for integers)
51        (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(-1)))
52            if *a != 0 =>
53        {
54            Expression::Number(Number::rational(BigRational::new(
55                BigInt::from(1),
56                BigInt::from(*a),
57            )))
58        }
59        // (a/b)^(-1) = b/a (reciprocal of rational)
60        (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(-1))) => {
61            Expression::Number(Number::rational(BigRational::new(
62                r.denom().clone(),
63                r.numer().clone(),
64            )))
65        }
66        // (a/b)^n = a^n/b^n for positive integers n
67        (Expression::Number(Number::Rational(r)), Expression::Number(Number::Integer(n)))
68            if *n > 0 =>
69        {
70            let exp = *n as u32;
71            let numerator = r.numer().pow(exp);
72            let denominator = r.denom().pow(exp);
73            Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
74        }
75        // a^(-n) = 1/(a^n) for positive integers a and n
76        (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(n)))
77            if *n < 0 && *a != 0 =>
78        {
79            let positive_exp = (-n) as u32;
80            let numerator = BigInt::from(1);
81            let denominator = BigInt::from(*a).pow(positive_exp);
82            Expression::Number(Number::rational(BigRational::new(numerator, denominator)))
83        }
84        // sqrt(x)^2 = x (inverse function)
85        (Expression::Function { name, args }, Expression::Number(Number::Integer(2)))
86            if name.as_ref() == "sqrt" && args.len() == 1 =>
87        {
88            args[0].clone()
89        }
90        // (a^b)^c = a^(b*c)
91        (Expression::Pow(b, e), c) => {
92            let new_exp = simplify_multiplication(&[e.as_ref().clone(), c.clone()]);
93            Expression::Pow(Arc::new(b.as_ref().clone()), Arc::new(new_exp))
94        }
95        // (a*b)^n = a^n * b^n ONLY if commutative
96        (Expression::Mul(factors), Expression::Number(Number::Integer(n))) if *n > 0 => {
97            let commutativity = Commutativity::combine(factors.iter().map(|f| f.commutativity()));
98
99            if commutativity.can_sort() {
100                let powered_factors: Vec<Expression> = factors
101                    .iter()
102                    .map(|f| Expression::pow(f.clone(), simplified_exp.clone()))
103                    .collect();
104                simplify_multiplication(&powered_factors)
105            } else {
106                Expression::Pow(Arc::new(simplified_base), Arc::new(simplified_exp))
107            }
108        }
109        _ => Expression::Pow(Arc::new(simplified_base), Arc::new(simplified_exp)),
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Builds the exact rational `numer / denom` the same way `simplify_power` does, so the
    /// comparison does not depend on how `Number::rational` normalizes its input.
    fn rational(numer: BigInt, denom: BigInt) -> Expression {
        Expression::Number(Number::rational(BigRational::new(numer, denom)))
    }

    /// Returns true when `expr` is exactly `base^2`.
    fn is_square_of(expr: &Expression, base: &Expression) -> bool {
        matches!(expr, Expression::Pow(b, e)
            if b.as_ref() == base && e.as_ref() == &Expression::integer(2))
    }

    /// Asserts that `expr` is an unevaluated `base^2` and returns the base.
    fn expect_square(expr: &Expression) -> &Expression {
        match expr {
            Expression::Pow(base, exp) => {
                assert_eq!(exp.as_ref(), &Expression::integer(2));
                base.as_ref()
            }
            _ => panic!("Expected Pow, got {:?}", expr),
        }
    }

    /// Asserts that `expr` is the unevaluated square of a two-factor product and returns
    /// copies of the two factors.
    fn expect_square_of_pair(expr: &Expression) -> Vec<Expression> {
        match expect_square(expr) {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                factors.iter().cloned().collect()
            }
            base => panic!("Expected Mul base, got {:?}", base),
        }
    }

    #[test]
    fn test_power_simplification() {
        let x = Expression::symbol(symbol!(x));

        // x^0 = 1
        assert_eq!(simplify_power(&x, &Expression::integer(0)), Expression::integer(1));
        // x^1 = x
        assert_eq!(simplify_power(&x, &Expression::integer(1)), x);
    }

    #[test]
    fn test_trivial_bases() {
        let x = Expression::symbol(symbol!(x));

        // 1^x = 1
        assert_eq!(simplify_power(&Expression::integer(1), &x), Expression::integer(1));
        // 0^3 = 0
        assert_eq!(
            simplify_power(&Expression::integer(0), &Expression::integer(3)),
            Expression::integer(0)
        );
        // 0^(-1) is a division by zero
        assert_eq!(
            simplify_power(&Expression::integer(0), &Expression::integer(-1)),
            Expression::function("undefined", vec![])
        );
    }

    #[test]
    fn test_integer_power_evaluated() {
        assert_eq!(
            simplify_power(&Expression::integer(2), &Expression::integer(10)),
            Expression::integer(1024)
        );
        assert_eq!(
            simplify_power(&Expression::integer(-3), &Expression::integer(3)),
            Expression::integer(-27)
        );
    }

    #[test]
    fn test_integer_power_overflow_promotes_to_bigint() {
        // 2^200 overflows any machine integer type.
        assert_eq!(
            simplify_power(&Expression::integer(2), &Expression::integer(200)),
            rational(BigInt::from(2).pow(200), BigInt::from(1))
        );
    }

    #[test]
    fn test_negative_integer_exponent_gives_reciprocal() {
        assert_eq!(
            simplify_power(&Expression::integer(2), &Expression::integer(-1)),
            rational(BigInt::from(1), BigInt::from(2))
        );
        assert_eq!(
            simplify_power(&Expression::integer(2), &Expression::integer(-3)),
            rational(BigInt::from(1), BigInt::from(8))
        );
    }

    #[test]
    fn test_rational_base_integer_exponent() {
        let two_thirds = rational(BigInt::from(2), BigInt::from(3));

        assert_eq!(
            simplify_power(&two_thirds, &Expression::integer(2)),
            rational(BigInt::from(4), BigInt::from(9))
        );
        assert_eq!(
            simplify_power(&two_thirds, &Expression::integer(-1)),
            rational(BigInt::from(3), BigInt::from(2))
        );
    }

    #[test]
    fn test_scalar_power_distributed() {
        let x = Expression::symbol(symbol!(x));
        let y = Expression::symbol(symbol!(y));
        let expr = Expression::pow(
            Expression::mul(vec![x.clone(), y.clone()]),
            Expression::integer(2),
        );

        match expr.simplify() {
            Expression::Mul(factors) => {
                assert_eq!(factors.len(), 2);
                assert!(
                    factors.iter().any(|f| is_square_of(f, &x)),
                    "Expected x^2 in factors"
                );
                assert!(
                    factors.iter().any(|f| is_square_of(f, &y)),
                    "Expected y^2 in factors"
                );
            }
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_matrix_power_not_distributed() {
        let matrix_a = Expression::symbol(symbol!(A; matrix));
        let matrix_b = Expression::symbol(symbol!(B; matrix));
        let expr = Expression::pow(
            Expression::mul(vec![matrix_a, matrix_b]),
            Expression::integer(2),
        );

        let factors = expect_square_of_pair(&expr.simplify());
        assert!(factors.iter().all(|f| matches!(
            f,
            Expression::Symbol(s) if s.symbol_type() == crate::core::symbol::SymbolType::Matrix
        )));
    }

    #[test]
    fn test_operator_power_not_distributed() {
        let operator_p = Expression::symbol(symbol!(P; operator));
        let operator_q = Expression::symbol(symbol!(Q; operator));
        let expr = Expression::pow(
            Expression::mul(vec![operator_p, operator_q]),
            Expression::integer(2),
        );

        expect_square_of_pair(&expr.simplify());
    }

    #[test]
    fn test_quaternion_power_not_distributed() {
        let i = Expression::symbol(symbol!(i; quaternion));
        let j = Expression::symbol(symbol!(j; quaternion));
        let expr = Expression::pow(Expression::mul(vec![i, j]), Expression::integer(2));

        expect_square(&expr.simplify());
    }

    #[test]
    fn test_three_scalar_factors_power_distributed() {
        let xyz = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(z)),
        ]);
        let expr = Expression::pow(xyz, Expression::integer(3));

        match expr.simplify() {
            Expression::Mul(factors) => assert_eq!(factors.len(), 3),
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_mixed_scalar_matrix_power_not_distributed() {
        let xa = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(A; matrix)),
        ]);
        let expr = Expression::pow(xa, Expression::integer(2));

        expect_square(&expr.simplify());
    }

    #[test]
    fn test_numeric_power_distributed() {
        let expr = Expression::pow(
            Expression::mul(vec![Expression::integer(2), Expression::integer(3)]),
            Expression::integer(2),
        );

        assert_eq!(expr.simplify(), Expression::integer(36));
    }
}