use std::collections::HashMap;
use std::ops::Neg;
use num_traits::ToPrimitive;
use crate::symbolic::core::Expr;
use crate::symbolic::core::SparsePolynomial;
use crate::symbolic::polynomial::differentiate_poly;
use crate::symbolic::polynomial::gcd;
use crate::symbolic::simplify::as_f64;
/// Builds the Sturm sequence of `poly` with respect to `var`.
///
/// Repeated factors are removed first by dividing `poly` by `gcd(poly, poly')`, so the first
/// element is the square-free part `p0` of `poly` and the second is its derivative `p0'`.
/// Every further element is the negated remainder of dividing the element two places back by
/// the previous one. Generation stops once the last element is zero or has degree 0, or when a
/// division leaves a zero remainder.
///
/// Returns an empty vector for the zero polynomial, and `[p0]` alone when `p0'` is zero.
#[must_use]
pub fn sturm_sequence(
    poly: &SparsePolynomial,
    var: &str,
) -> Vec<SparsePolynomial> {
    if poly.terms.is_empty() {
        return Vec::new();
    }
    // Dividing by gcd(p, p') keeps every distinct root of `p` but makes each one simple.
    let square_free = {
        let divisor = gcd(poly.clone(), differentiate_poly(poly, var), var);
        poly.clone().long_division(&divisor, var).0
    };
    let derivative = differentiate_poly(&square_free, var);
    let mut chain = vec![square_free];
    if derivative.terms.is_empty() {
        return chain;
    }
    chain.push(derivative);
    // The chain always holds at least two members here; keep dividing the last two.
    while let [.., prev, last] = chain.as_slice() {
        let has_positive_degree = !last.terms.is_empty() && last.degree(var) > 0;
        if !has_positive_degree {
            break;
        }
        let (_, remainder) = prev.clone().long_division(last, var);
        if remainder.terms.is_empty() {
            break;
        }
        chain.push(remainder.neg());
    }
    chain
}
/// Counts the sign changes in `sequence` when every polynomial is evaluated at `var = point`.
///
/// Values with magnitude not exceeding `1e-9` are treated as zero and skipped entirely, so a
/// change is only counted between consecutive *non-zero* values of opposite sign.
pub(crate) fn count_sign_changes(
    sequence: &[SparsePolynomial],
    point: f64,
    var: &str,
) -> usize {
    let tolerance = 1e-9;
    let assignment = HashMap::from([(var.to_string(), point)]);
    // Sign of each member at `point`, with (near-)zero values dropped.
    let signs: Vec<i8> = sequence
        .iter()
        .map(|member| member.eval(&assignment))
        .filter_map(|value| {
            if value > tolerance {
                Some(1)
            } else if value < -tolerance {
                Some(-1)
            } else {
                None
            }
        })
        .collect();
    signs.windows(2).filter(|pair| pair[0] != pair[1]).count()
}
/// Counts the distinct real roots of `poly` (in `var`) lying between `a` and `b`.
///
/// Uses Sturm's theorem: with `lo <= hi` the ordered endpoints, the result is
/// `V(lo) - V(hi)`, where `V(x)` is the number of sign changes of the Sturm sequence of `poly`
/// evaluated at `x`. Because members within `1e-9` of zero are skipped when counting sign
/// changes, this counts roots in the half-open interval `(lo, hi]`, up to that tolerance.
/// The endpoints may be given in either order.
///
/// # Errors
///
/// Currently never returns an error; the `Result` is part of the public signature.
pub fn count_real_roots_in_interval(
    poly: &SparsePolynomial,
    var: &str,
    a: f64,
    b: f64,
) -> Result<usize, String> {
    // The sign-change count is non-increasing in x, so reversed endpoints would make the
    // subtraction below saturate to 0 and silently report "no roots".
    let (lo, hi) = if a <= b { (a, b) } else { (b, a) };
    let seq = sturm_sequence(poly, var);
    let changes_lo = count_sign_changes(&seq, lo, var);
    let changes_hi = count_sign_changes(&seq, hi, var);
    Ok(changes_lo.saturating_sub(changes_hi))
}
/// Isolates the distinct real roots of `poly` (in `var`) into small intervals.
///
/// The search starts from `(-bound, bound]`, where `bound` is the [`root_bound`] of the
/// square-free part of `poly`, and proceeds by bisection. The Sturm sign-change difference
/// `V(a) - V(b)` gives the number of distinct roots in each half-open interval `(a, b]`:
///
/// * intervals containing no root are discarded;
/// * intervals containing exactly one root are narrowed by bisection until their width is at
///   most `precision` (or until `f64` can no longer split them);
/// * intervals containing several roots are split in half, unless they are already no wider
///   than `precision`; such a cluster of closely spaced roots is then reported as one
///   interval instead of being dropped.
///
/// Each returned `(low, high)` pair brackets at least one root in `(low, high]`, and the list
/// is sorted by `low`.
///
/// # Errors
///
/// Returns an error if `precision` is not a finite, strictly positive number, or if
/// [`root_bound`] fails (e.g. for non-numerical coefficients).
pub fn isolate_real_roots(
    poly: &SparsePolynomial,
    var: &str,
    precision: f64,
) -> Result<Vec<(f64, f64)>, String> {
    // A zero, negative, or NaN precision would make the bisection loops below never finish.
    if !(precision.is_finite() && precision > 0.0) {
        return Err("Precision must be a finite positive number.".to_string());
    }
    // `sturm_sequence` already divides out repeated factors, so its first member is the
    // square-free part of `poly`; reuse it instead of recomputing the gcd here.
    let seq = sturm_sequence(poly, var);
    let Some(sq_free) = seq.first() else {
        // Zero polynomial: there is nothing to isolate.
        return Ok(Vec::new());
    };
    let bound = root_bound(sq_free, var)?;
    let mut roots = Vec::new();
    let mut stack = vec![(-bound, bound)];
    while let Some((a, b)) = stack.pop() {
        // Count roots BEFORE looking at the width, so narrow intervals that still hold roots
        // are never discarded.
        let changes_a = count_sign_changes(&seq, a, var);
        let changes_b = count_sign_changes(&seq, b, var);
        let num_roots = changes_a.saturating_sub(changes_b);
        if num_roots == 0 {
            continue;
        }
        if num_roots == 1 {
            let (mut low, mut high) = (a, b);
            // Cache V(low) so each step evaluates the sequence only once, at `mid`.
            let mut changes_low = changes_a;
            while high - low > precision {
                let mid = f64::midpoint(low, high);
                // Stop once `f64` resolution no longer allows a strict split.
                if mid <= low || mid >= high {
                    break;
                }
                let changes_mid = count_sign_changes(&seq, mid, var);
                // Floating-point noise could make V(mid) exceed V(low); a plain `-` on usize
                // would then panic (debug) or wrap (release).
                if changes_low.saturating_sub(changes_mid) > 0 {
                    high = mid;
                } else {
                    low = mid;
                    changes_low = changes_mid;
                }
            }
            roots.push((low, high));
            continue;
        }
        let mid = f64::midpoint(a, b);
        if b - a > precision && a < mid && mid < b {
            stack.push((a, mid));
            stack.push((mid, b));
        } else {
            // Several distinct roots closer together than `precision`: report the cluster
            // rather than silently losing it.
            roots.push((a, b));
        }
    }
    roots.sort_by(|x, y| x.0.total_cmp(&y.0));
    Ok(roots)
}
pub(crate) fn root_bound(
poly: &SparsePolynomial,
var: &str,
) -> Result<f64, String> {
let coeffs = poly.get_coeffs_as_vec(var);
if coeffs.is_empty() {
return Ok(1.0);
}
let leading_coeff_expr = match coeffs.first() {
| Some(c) => c,
| None => unreachable!(),
};
let simplified_lc = crate::symbolic::simplify_dag::simplify(&leading_coeff_expr.clone());
let lc = as_f64(&simplified_lc).ok_or(
"Leading coefficient is \
not numerical.",
)?;
if lc == 0.0 {
return Err("Leading coefficient \
cannot be zero."
.to_string());
}
let max_coeff = coeffs
.iter()
.skip(1)
.map(|c| {
as_f64(&crate::symbolic::simplify_dag::simplify(&c.clone()))
.unwrap_or(0.0)
.abs()
})
.fold(0.0, f64::max);
Ok(1.0 + max_coeff / lc.abs())
}
/// Numerically evaluates `expr`, taking variable values from `vars`.
///
/// Supports constants, big integers, variables, negation, `+`, `-`, `*`, `/`, and powers,
/// using ordinary `f64` arithmetic. Variables missing from `vars` evaluate to `0.0`, as do
/// big integers that cannot be converted to `f64` and every other expression variant.
/// `Dag` nodes are first converted back to a tree expression.
///
/// # Panics
///
/// Panics with `"Dag Eval Expr"` if a `Dag` node cannot be converted to an expression.
#[must_use]
pub fn eval_expr<S: std::hash::BuildHasher>(
    expr: &Expr,
    vars: &HashMap<String, f64, S>,
) -> f64 {
    let eval = |sub: &Expr| eval_expr(sub, vars);
    match expr {
        | Expr::Constant(value) => *value,
        | Expr::BigInt(big) => big.to_f64().unwrap_or(0.0),
        | Expr::Variable(name) => vars.get(name).copied().unwrap_or(0.0),
        | Expr::Neg(inner) => -eval(inner),
        | Expr::Add(lhs, rhs) => eval(lhs) + eval(rhs),
        | Expr::Sub(lhs, rhs) => eval(lhs) - eval(rhs),
        | Expr::Mul(lhs, rhs) => eval(lhs) * eval(rhs),
        | Expr::Div(lhs, rhs) => eval(lhs) / eval(rhs),
        | Expr::Power(base, exponent) => eval(base).powf(eval(exponent)),
        | Expr::Dag(node) => {
            let tree = node.to_expr().expect("Dag Eval Expr");
            eval(&tree)
        },
        | _ => 0.0,
    }
}