use super::multiplication::simplify_multiplication;
use super::Simplify;
use crate::core::commutativity::Commutativity;
use crate::core::{Expression, Number};
use num_bigint::BigInt;
use num_rational::BigRational;
use std::sync::Arc;
/// Simplifies the power `base^exp`.
///
/// Both operands are simplified first, then the following rewrites are tried
/// in order:
///
/// * `x^0 → 1` (including `0^0`), `x^1 → x`, `1^x → 1`;
/// * `0^n → 0` for a positive integer `n`, and `0^n → undefined()` for any
///   negative integer `n`;
/// * an integer or rational base raised to an integer exponent is evaluated
///   exactly, provided the exponent's magnitude fits in a `u32` (larger
///   exponents are left unevaluated rather than silently truncated);
/// * `sqrt(x)^2 → x`;
/// * `(b^e)^n → b^(e·n)` for an integer `n`, with the result re-simplified;
/// * `(f1·f2·…)^n → f1^n·f2^n·…` for a positive integer `n` when every factor
///   commutes (never for matrices, operators, quaternions, …).
///
/// Anything else is returned as an unevaluated `Pow` of the simplified
/// operands.
pub fn simplify_power(base: &Expression, exp: &Expression) -> Expression {
    let simplified_base = base.simplify();
    let simplified_exp = exp.simplify();

    // |exp| as a `u32`, when `exp` is an integer small enough to evaluate
    // exactly. A plain `as u32` cast would silently truncate huge exponents,
    // and negating the most negative integer would overflow.
    let small_exp_magnitude: Option<u32> = match &simplified_exp {
        Expression::Number(Number::Integer(n)) => u32::try_from(n.unsigned_abs()).ok(),
        _ => None,
    };

    match (&simplified_base, &simplified_exp, small_exp_magnitude) {
        (_, Expression::Number(Number::Integer(0)), _) => Expression::integer(1),
        (_, Expression::Number(Number::Integer(1)), _) => simplified_base,
        (Expression::Number(Number::Integer(1)), _, _) => Expression::integer(1),
        (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(n)), _)
            if *n > 0 =>
        {
            Expression::integer(0)
        }
        // 0^n for *any* negative n is a division by zero, not just 0^-1.
        (Expression::Number(Number::Integer(0)), Expression::Number(Number::Integer(n)), _)
            if *n < 0 =>
        {
            Expression::function("undefined", vec![])
        }
        // a^k for a positive exponent k.
        (
            Expression::Number(Number::Integer(a)),
            Expression::Number(Number::Integer(n)),
            Some(k),
        ) if *a != 0 && *n > 0 => match (*a).checked_pow(k) {
            Some(result) => Expression::integer(result),
            // The machine integer overflowed: redo the computation exactly.
            None => Expression::Number(Number::rational(BigRational::new(
                BigInt::from(*a).pow(k),
                BigInt::from(1),
            ))),
        },
        // a^-k = 1 / a^k; the guard ensures a != 0.
        (
            Expression::Number(Number::Integer(a)),
            Expression::Number(Number::Integer(n)),
            Some(k),
        ) if *a != 0 && *n < 0 => Expression::Number(Number::rational(BigRational::new(
            BigInt::from(1),
            BigInt::from(*a).pow(k),
        ))),
        // (p/q)^k = p^k / q^k.
        (
            Expression::Number(Number::Rational(r)),
            Expression::Number(Number::Integer(n)),
            Some(k),
        ) if *n > 0 => Expression::Number(Number::rational(BigRational::new(
            r.numer().pow(k),
            r.denom().pow(k),
        ))),
        // (p/q)^-k = q^k / p^k, which needs p != 0.
        (
            Expression::Number(Number::Rational(r)),
            Expression::Number(Number::Integer(n)),
            Some(k),
        ) if *n < 0 => {
            if *r.numer() == BigInt::from(0) {
                // Reciprocal of zero: `BigRational::new` would panic on a zero
                // denominator, so report it like 0^n above.
                Expression::function("undefined", vec![])
            } else {
                Expression::Number(Number::rational(BigRational::new(
                    r.denom().pow(k),
                    r.numer().pow(k),
                )))
            }
        }
        (Expression::Function { name, args }, Expression::Number(Number::Integer(2)), _)
            if name.as_ref() == "sqrt" && args.len() == 1 =>
        {
            args[0].clone()
        }
        // (b^e)^n = b^(e·n) holds in general only for an integer n; for
        // example, (x^2)^(1/2) is |x|, not x. Re-simplify the result so that
        // cases such as (x^(1/2))^2 collapse all the way to x.
        (Expression::Pow(b, e), Expression::Number(Number::Integer(_)), _) => {
            let new_exp = simplify_multiplication(&[e.as_ref().clone(), simplified_exp.clone()]);
            simplify_power(b, &new_exp)
        }
        // (f1·f2·…)^n = f1^n·f2^n·… only when the factors commute:
        // for matrices, (AB)^2 = ABAB ≠ A²B².
        (Expression::Mul(factors), Expression::Number(Number::Integer(n)), _)
            if *n > 0
                && Commutativity::combine(factors.iter().map(|f| f.commutativity()))
                    .can_sort() =>
        {
            // Simplify each factor's power too, so (2x)^2 yields 4·x^2 rather than 2^2·x^2.
            let powered_factors: Vec<Expression> = factors
                .iter()
                .map(|f| simplify_power(f, &simplified_exp))
                .collect();
            simplify_multiplication(&powered_factors)
        }
        _ => Expression::Pow(Arc::new(simplified_base), Arc::new(simplified_exp)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simplify::Simplify;
    use crate::symbol;
    use crate::Expression;

    /// Asserts that `expr` is an unevaluated square `base^2` and returns `base`.
    fn expect_square(expr: &Expression) -> &Expression {
        if let Expression::Pow(base, exp) = expr {
            assert_eq!(exp.as_ref(), &Expression::integer(2));
            base.as_ref()
        } else {
            panic!("Expected Pow, got {:?}", expr)
        }
    }

    #[test]
    fn test_power_simplification() {
        let x = Expression::symbol(symbol!(x));
        // An exponent of 0 collapses anything to 1.
        assert_eq!(
            simplify_power(&x, &Expression::integer(0)),
            Expression::integer(1)
        );
        // An exponent of 1 is the identity.
        assert_eq!(simplify_power(&x, &Expression::integer(1)), x);
    }

    #[test]
    fn test_scalar_power_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(y)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();
        let factors = match &simplified {
            Expression::Mul(factors) => factors,
            other => panic!("Expected Mul, got {:?}", other),
        };
        assert_eq!(factors.len(), 2);

        let two = Expression::integer(2);
        // True when some factor is exactly `target^2`.
        let contains_square_of = |target: Expression| {
            factors.iter().any(|factor| match factor {
                Expression::Pow(base, exp) => base.as_ref() == &target && exp.as_ref() == &two,
                _ => false,
            })
        };
        assert!(
            contains_square_of(Expression::symbol(symbol!(x))),
            "Expected x^2 in factors"
        );
        assert!(
            contains_square_of(Expression::symbol(symbol!(y))),
            "Expected y^2 in factors"
        );
    }

    #[test]
    fn test_matrix_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(A; matrix)),
            Expression::symbol(symbol!(B; matrix)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();
        let base = expect_square(&simplified);
        let factors = match base {
            Expression::Mul(factors) => factors,
            _ => panic!("Expected Mul base, got {:?}", base),
        };
        assert_eq!(factors.len(), 2);
        assert!(factors.iter().all(|factor| matches!(
            factor,
            Expression::Symbol(s)
                if s.symbol_type() == crate::core::symbol::SymbolType::Matrix
        )));
    }

    #[test]
    fn test_operator_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(P; operator)),
            Expression::symbol(symbol!(Q; operator)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();
        match expect_square(&simplified) {
            Expression::Mul(factors) => assert_eq!(factors.len(), 2),
            other => panic!("Expected Mul base, got {:?}", other),
        }
    }

    #[test]
    fn test_quaternion_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(i; quaternion)),
            Expression::symbol(symbol!(j; quaternion)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();
        expect_square(&simplified);
    }

    #[test]
    fn test_three_scalar_factors_power_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(y)),
            Expression::symbol(symbol!(z)),
        ]);
        match Expression::pow(product, Expression::integer(3)).simplify() {
            Expression::Mul(factors) => assert_eq!(factors.len(), 3),
            other => panic!("Expected Mul, got {:?}", other),
        }
    }

    #[test]
    fn test_mixed_scalar_matrix_power_not_distributed() {
        let product = Expression::mul(vec![
            Expression::symbol(symbol!(x)),
            Expression::symbol(symbol!(A; matrix)),
        ]);
        let simplified = Expression::pow(product, Expression::integer(2)).simplify();
        // A single non-commuting factor is enough to block distribution.
        expect_square(&simplified);
    }

    #[test]
    fn test_numeric_power_distributed() {
        let six_as_product = Expression::mul(vec![Expression::integer(2), Expression::integer(3)]);
        assert_eq!(
            Expression::pow(six_as_product, Expression::integer(2)).simplify(),
            Expression::integer(36)
        );
    }
}