use crate::expr::Expr;
use crate::simplify::simplify;
/// Expands `expr` and then simplifies it.
///
/// Expansion distributes products over sums and multiplies out small
/// integer powers (see `expand_inner`); the expanded tree is then passed
/// through the crate's `simplify`.
#[must_use]
pub fn expand(expr: &Expr) -> Expr {
    simplify(&expand_inner(expr))
}
/// Largest integer exponent that `expand_inner` multiplies out. Multiplying
/// a sum of `k` terms out to the `n`-th power produces `k^n` products, so
/// higher powers are left as `Pow` nodes.
const MAX_EXPANDED_POWER: u32 = 8;

/// Recursively expands `expr` without simplifying.
///
/// - Products are distributed over sums via `distribute`.
/// - Powers whose (expanded) exponent is a constant whole number in
///   `2..=MAX_EXPANDED_POWER` are multiplied out into repeated products.
/// - Negations are pushed into sums (`-(a + b)` becomes `(-a) + (-b)`) and
///   double negations cancel, so an enclosing product can distribute over
///   the resulting sum.
/// - Function arguments are expanded; the function node itself is kept.
fn expand_inner(expr: &Expr) -> Expr {
    match expr {
        Expr::Const(_) | Expr::Var(_) => expr.clone(),
        Expr::Add(a, b) => Expr::Add(Box::new(expand_inner(a)), Box::new(expand_inner(b))),
        Expr::Mul(a, b) => distribute(expand_inner(a), expand_inner(b)),
        Expr::Pow(base, exp) => {
            let base = expand_inner(base);
            let exp = expand_inner(exp);
            match small_integer_exponent(&exp) {
                // base^n == base * base * ... * base (n factors), distributed pairwise.
                Some(n) => (1..n).fold(base.clone(), |acc, _| distribute(acc, base.clone())),
                None => Expr::Pow(Box::new(base), Box::new(exp)),
            }
        }
        // Without pushing the negation inward, `a * -(b + c)` would stay an
        // undistributed product because `distribute` only splits `Add` nodes.
        Expr::Neg(inner) => push_negation(expand_inner(inner)),
        Expr::Fn(func, arg) => Expr::Fn(*func, Box::new(expand_inner(arg))),
    }
}

/// Returns `Some(n)` when `exp` is a constant whole number in
/// `2..=MAX_EXPANDED_POWER`, otherwise `None`.
fn small_integer_exponent(exp: &Expr) -> Option<u32> {
    let n = exp.as_const()?;
    // In this range adjacent doubles are more than `f64::EPSILON` apart, so the
    // `fract` comparison is an exact whole-number test. NaN fails `contains`.
    if !(2.0..=f64::from(MAX_EXPANDED_POWER)).contains(&n) || n.fract() >= f64::EPSILON {
        return None;
    }
    // `n` was just checked to be a whole number in 2..=MAX_EXPANDED_POWER,
    // so the cast neither truncates a fraction nor loses a sign.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let whole = n as u32;
    Some(whole)
}

/// Negates `expr`, distributing the negation over sums and cancelling a
/// double negation, e.g. `-(a + -b)` becomes `(-a) + b`.
fn push_negation(expr: Expr) -> Expr {
    match expr {
        Expr::Add(a, b) => Expr::Add(Box::new(push_negation(*a)), Box::new(push_negation(*b))),
        Expr::Neg(inner) => *inner,
        other => Expr::Neg(Box::new(other)),
    }
}
/// Multiplies `a` by `b`, distributing the product over any top-level sums.
///
/// A sum on the left is split first: `(l + r) * b` becomes `l*b + r*b`.
/// Then a sum on the right: `a * (l + r)` becomes `a*l + a*r`. Each half is
/// handled recursively, and two non-sum operands produce a single `Mul`.
fn distribute(a: Expr, b: Expr) -> Expr {
    match a {
        Expr::Add(lhs, rhs) => Expr::Add(
            Box::new(distribute(*lhs, b.clone())),
            Box::new(distribute(*rhs, b)),
        ),
        factor => match b {
            Expr::Add(lhs, rhs) => Expr::Add(
                Box::new(distribute(factor.clone(), *lhs)),
                Box::new(distribute(factor, *rhs)),
            ),
            other => Expr::Mul(Box::new(factor), Box::new(other)),
        },
    }
}
/// Factors `term` out of every addend of `expr`.
///
/// The addends of `expr` (its top-level sum operands, left to right) are
/// each divided by `term` via `try_divide`. On success the result is
/// `simplify(term * (q1 + q2 + ...))`. If any addend cannot be divided,
/// `expr` is returned unchanged, without simplification.
#[must_use]
pub fn factor_out(expr: &Expr, term: &Expr) -> Expr {
    // Collecting into `Option<Vec<_>>` stops at the first addend that does
    // not contain `term`.
    let quotients: Option<Vec<Expr>> = collect_addends(expr)
        .iter()
        .map(|addend| try_divide(addend, term))
        .collect();

    match quotients {
        None => expr.clone(),
        Some(quotients) => {
            let mut rest = quotients.into_iter();
            // Left-fold the quotients into a sum; an empty list becomes 0.
            let sum = match rest.next() {
                Some(first) => rest.fold(first, |acc, q| Expr::Add(Box::new(acc), Box::new(q))),
                None => Expr::Const(0.0),
            };
            simplify(&Expr::Mul(Box::new(term.clone()), Box::new(sum)))
        }
    }
}
/// Flattens nested `Add` nodes of `expr` into a list of cloned addends.
///
/// Operands are listed in left-to-right order. A non-`Add` expression yields
/// a single-element list containing a clone of itself.
fn collect_addends(expr: &Expr) -> Vec<Expr> {
    let mut terms = Vec::new();
    let mut pending: Vec<&Expr> = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expr::Add(lhs, rhs) => {
                // Push the right operand first so the left one is visited first.
                pending.push(rhs.as_ref());
                pending.push(lhs.as_ref());
            }
            leaf => terms.push(leaf.clone()),
        }
    }
    terms
}
/// Structurally divides `expr` by `term`.
///
/// Returns a quotient `q` such that `term * q` is algebraically equal to
/// `expr`, or `None` if `term` cannot be found as a factor. Recognised shapes:
/// - `expr == term` gives `1`;
/// - a product with `term` as a direct operand gives the other operand;
/// - a product with `term` inside a nested operand, e.g. `(term * a) * b`,
///   gives that operand's quotient multiplied by the other operand;
/// - a negation `-e` gives `-(e / term)`. This lets differences such as
///   `x*a + -(x*b)` be factored, since subtraction is expressed with `Neg`.
fn try_divide(expr: &Expr, term: &Expr) -> Option<Expr> {
    if expr == term {
        return Some(Expr::Const(1.0));
    }
    match expr {
        Expr::Mul(a, b) => {
            if a.as_ref() == term {
                Some(b.as_ref().clone())
            } else if b.as_ref() == term {
                Some(a.as_ref().clone())
            } else {
                // term * q == a  implies  term * (q * b) == a * b, and symmetrically for b.
                try_divide(a, term)
                    .map(|q| Expr::Mul(Box::new(q), b.clone()))
                    .or_else(|| try_divide(b, term).map(|q| Expr::Mul(a.clone(), Box::new(q))))
            }
        }
        Expr::Neg(inner) => try_divide(inner, term).map(|q| Expr::Neg(Box::new(q))),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{constant, var};
    use std::collections::HashMap;

    /// Absolute tolerance for comparing evaluated expressions. `f64::EPSILON`
    /// is only the spacing of doubles near 1.0, so as an absolute tolerance it
    /// amounts to exact equality for values of larger magnitude.
    const TOLERANCE: f64 = 1e-10;

    /// Evaluates `e` with `x = x_val` and fixed bindings `a = 2`, `b = 3`, `c = 4`.
    fn eval(e: &Expr, x_val: f64) -> f64 {
        let mut vars = HashMap::new();
        vars.insert("x".into(), x_val);
        vars.insert("a".into(), 2.0);
        vars.insert("b".into(), 3.0);
        vars.insert("c".into(), 4.0);
        e.eval(&vars).unwrap()
    }

    /// Asserts that `lhs` and `rhs` evaluate to the same value (within
    /// `TOLERANCE`) at several non-negative sample points for `x`.
    fn assert_same_values(lhs: &Expr, rhs: &Expr) {
        for x in [0.0, 0.5, 1.0, 3.0, 5.0] {
            let (l, r) = (eval(lhs, x), eval(rhs, x));
            assert!((l - r).abs() < TOLERANCE, "x = {}: {} != {}", x, l, r);
        }
    }

    #[test]
    fn expand_distributes() {
        let e = var("a") * (var("b") + var("c"));
        let expanded = expand(&e);
        assert_same_values(&e, &expanded);
    }

    #[test]
    fn expand_distributes_over_negated_sum() {
        let e = var("a") * Expr::Neg(Box::new(var("b") + var("c")));
        let expanded = expand(&e);
        assert_same_values(&e, &expanded);
    }

    #[test]
    fn expand_square_of_sum() {
        let e = Expr::Pow(Box::new(var("x") + constant(1.0)), Box::new(constant(2.0)));
        let expanded = expand(&e);
        assert!((eval(&expanded, 3.0) - 16.0).abs() < TOLERANCE);
        assert_same_values(&e, &expanded);
    }

    #[test]
    fn factor_out_common_term() {
        let x = var("x");
        let e = Expr::Add(
            Box::new(Expr::Mul(Box::new(x.clone()), Box::new(var("a")))),
            Box::new(Expr::Mul(Box::new(x.clone()), Box::new(var("b")))),
        );
        let factored = factor_out(&e, &x);
        assert_same_values(&e, &factored);
    }

    #[test]
    fn factor_out_fails_gracefully() {
        let e = var("x") + var("y");
        let factored = factor_out(&e, &var("x"));
        assert_eq!(factored, e);
    }
}