use crate::diff::diff;
use crate::error::{Result, SymError};
use crate::expr::{Expr, constant};
use crate::simplify::simplify;
/// Solves `expr = 0` for `var`, assuming `expr` is linear in `var`.
///
/// The expression is treated as `a*var + b`:
/// - `a` is the derivative of `expr` with respect to `var`.
/// - `b` is `expr` with `var` substituted by `0`.
///
/// The returned expression is the simplified form of `-b * a^-1`. Both
/// coefficients may still contain other symbols.
///
/// # Errors
///
/// Returns [`SymError::SolveFailure`] in two cases:
/// - The derivative still depends on `var`, so the expression is not linear.
/// - The coefficient `a` evaluates to (numerically) zero. The equation then
///   has either no solution or infinitely many, and dividing by `a` would be
///   meaningless.
pub fn solve_linear(expr: &Expr, var: &str) -> Result<Expr> {
    let expr = simplify(expr);
    let a = diff(&expr, var);
    if a.free_variables().contains(var) {
        return Err(SymError::SolveFailure {
            reason: "expression is not linear in the variable",
        });
    }
    // Guard against building `b * 0^-1`. The check only applies when `a` is
    // fully numeric. If `a` still mentions other symbols, evaluating it
    // without bindings presumably fails (the same assumption solve_quadratic
    // makes), and the symbolic result is returned as before.
    if let Ok(a_val) = a.eval(&std::collections::HashMap::new()) {
        if a_val.abs() < f64::EPSILON {
            return Err(SymError::SolveFailure {
                reason: "coefficient of the variable is zero; no unique solution",
            });
        }
    }
    let b = simplify(&expr.substitute(var, &constant(0.0)));
    // var = -b / a, expressed as -b * a^-1.
    let result = Expr::Mul(
        Box::new(Expr::Neg(Box::new(b))),
        Box::new(Expr::Pow(Box::new(a), Box::new(constant(-1.0)))),
    );
    Ok(simplify(&result))
}
/// Finds the real roots of `expr = 0`, where `expr` is quadratic in `var`.
///
/// Coefficients are recovered from derivatives:
/// - `c` is `expr` at `var = 0`.
/// - `b` is `expr'` at `var = 0`.
/// - `a` is `expr'' / 2`.
///
/// All three must evaluate to plain numbers with no variable bindings.
/// Roots are returned as constant expressions in ascending order. The result
/// holds two roots, a single entry for a repeated root, or nothing when there
/// are no real roots.
///
/// # Errors
///
/// Returns [`SymError::SolveFailure`] in three cases:
/// - The second derivative still depends on `var`.
/// - A coefficient cannot be evaluated numerically.
/// - The leading coefficient `a` is (numerically) zero.
pub fn solve_quadratic(expr: &Expr, var: &str) -> Result<Vec<Expr>> {
    let expr = simplify(expr);
    let d1 = diff(&expr, var);
    let d2 = diff(&d1, var);
    if d2.free_variables().contains(var) {
        return Err(SymError::SolveFailure {
            reason: "expression is not quadratic in the variable",
        });
    }
    let c = simplify(&expr.substitute(var, &constant(0.0)));
    let b = simplify(&d1.substitute(var, &constant(0.0)));

    // Evaluate a coefficient with no bindings. Any evaluation error becomes a
    // SolveFailure carrying the given reason.
    let empty = std::collections::HashMap::new();
    let eval_coeff = |e: &Expr, reason: &'static str| {
        e.eval(&empty).map_err(|_| SymError::SolveFailure { reason })
    };
    let a_val = eval_coeff(&d2, "could not evaluate quadratic coefficient a")? / 2.0;
    let b_val = eval_coeff(&b, "could not evaluate quadratic coefficient b")?;
    let c_val = eval_coeff(&c, "could not evaluate quadratic coefficient c")?;

    if a_val.abs() < f64::EPSILON {
        return Err(SymError::SolveFailure {
            reason: "leading coefficient is zero; not a quadratic",
        });
    }
    Ok(real_quadratic_roots(a_val, b_val, c_val)
        .into_iter()
        .map(constant)
        .collect())
}

/// Returns the real roots of `a*x^2 + b*x + c` (caller ensures `a != 0`), sorted ascending.
///
/// A repeated root is reported once. An empty vector means there are no real roots.
fn real_quadratic_roots(a: f64, b: f64, c: f64) -> Vec<f64> {
    let discriminant = b * b - 4.0 * a * c;
    // Rounding error in `b*b - 4ac` grows with the size of its terms, so the
    // zero test is scaled by them. Flooring the scale at 1.0 keeps the plain
    // f64::EPSILON tolerance for small coefficients.
    let tolerance = f64::EPSILON * (b * b).max((4.0 * a * c).abs()).max(1.0);
    if discriminant < -tolerance {
        return Vec::new();
    }
    // `<=` closes the band, so a discriminant in [-tol, tol] is treated as
    // zero and `sqrt` below only ever sees a strictly positive value.
    if discriminant.abs() <= tolerance {
        return vec![-b / (2.0 * a)];
    }
    // Numerically stable quadratic formula. `q` gets the sign of -b, so `b`
    // and `sqrt(d)` add in magnitude instead of cancelling. The second root
    // then follows from Vieta's relation x1 * x2 = c / a. Here sqrt(d) > 0,
    // so `q` is never zero.
    let q = -0.5 * (b + discriminant.sqrt().copysign(b));
    let mut roots = vec![q / a, c / q];
    roots.sort_by(f64::total_cmp);
    roots
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{constant, var};
    use std::collections::HashMap;

    /// Absolute tolerance used when comparing numeric roots.
    const TOL: f64 = 1e-10;

    /// Builds `base^2` as an explicit `Pow` node.
    fn squared(base: Expr) -> Expr {
        Expr::Pow(Box::new(base), Box::new(constant(2.0)))
    }

    /// Evaluates an expression with no variable bindings, panicking on failure.
    fn eval_closed(e: &Expr) -> f64 {
        e.eval(&HashMap::new()).unwrap()
    }

    /// Asserts that `actual` lies strictly within `TOL` of `expected`.
    fn assert_near(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn solve_linear_simple() {
        // 2x - 6 = 0  =>  x = 3
        let equation = constant(2.0) * var("x") + constant(-6.0);
        let solution = solve_linear(&equation, "x").unwrap();
        assert_near(eval_closed(&solution), 3.0);
    }

    #[test]
    fn solve_linear_with_offset() {
        // x + 5 = 0  =>  x = -5
        let equation = var("x") + constant(5.0);
        let solution = solve_linear(&equation, "x").unwrap();
        assert_near(eval_closed(&solution), -5.0);
    }

    #[test]
    fn solve_linear_non_linear_fails() {
        let equation = squared(var("x"));
        let failure = solve_linear(&equation, "x").unwrap_err();
        assert!(matches!(failure, SymError::SolveFailure { .. }));
    }

    #[test]
    fn solve_quadratic_two_roots() {
        // x^2 - 5x + 6 = (x - 2)(x - 3)
        let x = var("x");
        let equation = squared(x.clone()) + constant(-5.0) * x + constant(6.0);
        let roots = solve_quadratic(&equation, "x").unwrap();
        assert_eq!(roots.len(), 2);
        let values: Vec<f64> = roots.iter().map(eval_closed).collect();
        assert_near(values[0], 2.0);
        assert_near(values[1], 3.0);
    }

    #[test]
    fn solve_quadratic_one_root() {
        // x^2 - 4x + 4 = (x - 2)^2
        let x = var("x");
        let equation = squared(x.clone()) + constant(-4.0) * x + constant(4.0);
        let roots = solve_quadratic(&equation, "x").unwrap();
        assert_eq!(roots.len(), 1);
        assert_near(eval_closed(&roots[0]), 2.0);
    }

    #[test]
    fn solve_quadratic_no_real_roots() {
        // x^2 + 1 = 0 has no real solutions
        let equation = squared(var("x")) + constant(1.0);
        let roots = solve_quadratic(&equation, "x").unwrap();
        assert!(roots.is_empty());
    }
}