use thiserror::Error;
use crate::numerics::utils::{is_near_zero, is_near_zero_eps};
/// Reasons for which [`solve_real`] rejects its input coefficients.
#[derive(Debug, Error, PartialEq)]
pub enum CubicError {
/// The leading coefficient `a` was classified as near zero by `is_near_zero`,
/// so the equation is not treated as a cubic. Carries the rejected `a`.
#[error("leading coefficient is near zero ({0}); not a cubic equation")]
NotCubic(f64),
/// At least one of the four coefficients was NaN or infinite.
/// Carries all four coefficients exactly as they were supplied.
#[error("non-finite coefficient: a={a}, b={b}, c={c}, d={d}")]
NotFinite { a: f64, b: f64, c: f64, d: f64 },
}
/// Solves `a·x³ + b·x² + c·x + d = 0` for its real roots.
///
/// The equation is divided by `a` and reduced, through the substitution
/// `x = y - b/(3a)`, to the depressed cubic `y³ + p·y + q = 0`. The
/// discriminant `Δ = (q/2)² + (p/3)³` then selects the method:
///
/// * `Δ` near zero — repeated roots, delegated to `repeated_roots`;
/// * `Δ > 0` — a single real root from Cardano's formula;
/// * `Δ < 0` — three real roots from the trigonometric form.
///
/// The returned roots are sorted in ascending order and neighbouring values
/// that differ by at most `1e-12` are merged by `dedup_close`.
///
/// # Errors
///
/// * [`CubicError::NotFinite`] if any coefficient is NaN or infinite.
/// * [`CubicError::NotCubic`] if `a` is near zero according to `is_near_zero`.
///
/// # Numerical notes
///
/// Finite coefficients with extreme ratios can still overflow while `p` and
/// `q` are formed. In that case the returned roots may be non-finite; the
/// function does not panic.
pub fn solve_real(a: f64, b: f64, c: f64, d: f64) -> Result<Vec<f64>, CubicError> {
    if !(a.is_finite() && b.is_finite() && c.is_finite() && d.is_finite()) {
        return Err(CubicError::NotFinite { a, b, c, d });
    }
    if is_near_zero(a) {
        return Err(CubicError::NotCubic(a));
    }

    // Monic form: x³ + b1·x² + c1·x + d1 = 0.
    let b1 = b / a;
    let c1 = c / a;
    let d1 = d / a;

    // Depressed form y³ + p·y + q = 0, with x = y - shift.
    let shift = b1 / 3.0;
    let p = c1 - b1 * b1 / 3.0;
    let q = 2.0 * b1.powi(3) / 27.0 - b1 * c1 / 3.0 + d1;

    let half_q = q / 2.0;
    let third_p = p / 3.0;
    let delta = half_q * half_q + third_p * third_p * third_p;

    // Tolerance for "Δ is zero", relative to the larger of the two terms that
    // make up Δ, but never below 1e-14.
    // NOTE(review): because of the 1.0 floor the tolerance is absolute whenever
    // both (q/2)² and |p/3|³ are below 1, so a depressed cubic with small but
    // non-zero p and q presumably takes the repeated-root branch. Confirm this
    // is intended for tightly clustered (but distinct) roots.
    let scale = (half_q * half_q).max(third_p.abs().powi(3)).max(1.0);
    let delta_eps = 1e-14 * scale;

    let mut ys: Vec<f64> = if is_near_zero_eps(delta, delta_eps) {
        repeated_roots(p, q)
    } else if delta > 0.0 {
        let sqrt_delta = delta.sqrt();
        // Cardano: y = u + v with u³, v³ = -q/2 ± √Δ and u·v = -p/3.
        // When |p| is small, one of the two radicands subtracts nearly equal
        // numbers and loses most of its digits. Evaluate only the radicand
        // whose terms share a sign, then recover the other cube root from the
        // product identity. `u` cannot be zero here because √Δ > 0.
        let radicand = if half_q >= 0.0 {
            -half_q - sqrt_delta
        } else {
            -half_q + sqrt_delta
        };
        let u = radicand.cbrt();
        let v = -third_p / u;
        vec![u + v]
    } else {
        // Δ < 0 forces p < 0, so both square roots below have positive arguments.
        let neg_third_p = -third_p;
        let r = (-third_p.powi(3)).sqrt();
        // Clamp guards `acos` against rounding that pushes the ratio just past ±1.
        let arg = (-half_q / r).clamp(-1.0, 1.0);
        let phi = arg.acos();
        let amp = 2.0 * neg_third_p.sqrt();
        let two_pi = 2.0 * std::f64::consts::PI;
        vec![
            amp * (phi / 3.0).cos(),
            amp * ((phi - two_pi) / 3.0).cos(),
            amp * ((phi + two_pi) / 3.0).cos(),
        ]
    };

    // Undo the substitution to get roots of the original equation.
    for y in ys.iter_mut() {
        *y -= shift;
    }

    // `partial_cmp` keeps the usual numeric order; the `total_cmp` fallback is
    // only reached when a NaN is present, so overflow yields a result instead
    // of a panic and the comparator remains a total order for `sort_by`.
    ys.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or_else(|| lhs.total_cmp(rhs)));
    dedup_close(&mut ys, 1e-12);
    Ok(ys)
}
/// Returns the real roots of `y³ + p·y + q = 0` for the case in which
/// `solve_real` has classified the discriminant as near zero.
///
/// * `p` not near zero: a double root `-3q/(2p)` followed by the simple root
///   `3q/p`, in that order (not sorted).
/// * `p` near zero: a single root, `0.0` when `q` is near zero as well and
///   `cbrt(-q)` otherwise.
fn repeated_roots(p: f64, q: f64) -> Vec<f64> {
    if !is_near_zero(p) {
        let twice = -3.0 * q / (2.0 * p);
        let once = 3.0 * q / p;
        return vec![twice, once];
    }

    // With p negligible the equation reduces to y³ = -q.
    let sole = if is_near_zero(q) { 0.0 } else { (-q).cbrt() };
    vec![sole]
}
/// Removes, in place, every element that lies within `eps` (inclusive) of the
/// most recently kept element before it.
///
/// Only consecutive elements are compared, so the input should already be
/// sorted for this to act as a tolerance-based de-duplication.
fn dedup_close(vs: &mut Vec<f64>, eps: f64) {
    let mut last_kept: Option<f64> = None;
    vs.retain(|&candidate| {
        let is_repeat = matches!(last_kept, Some(prev) if (candidate - prev).abs() <= eps);
        if !is_repeat {
            last_kept = Some(candidate);
        }
        !is_repeat
    });
}
/// Returns the first element of `roots`, or `None` when the slice is empty.
///
/// This is the smallest root only if `roots` is in ascending order, which is
/// how `solve_real` returns its result.
#[inline]
pub fn select_smallest(roots: &[f64]) -> Option<f64> {
    match roots {
        [] => None,
        [head, ..] => Some(*head),
    }
}
/// Returns the last element of `roots`, or `None` when the slice is empty.
///
/// This is the largest root only if `roots` is in ascending order, which is
/// how `solve_real` returns its result.
#[inline]
pub fn select_largest(roots: &[f64]) -> Option<f64> {
    match roots {
        [] => None,
        [.., tail] => Some(*tail),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates `a·x³ + b·x² + c·x + d` in nested (Horner) form.
    fn cubic_eval(a: f64, b: f64, c: f64, d: f64, x: f64) -> f64 {
        [b, c, d].iter().fold(a, |acc, &coeff| acc * x + coeff)
    }

    /// Fails unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff < tol,
            "expected {expected}, got {actual} (diff {})",
            diff
        );
    }

    /// Checks the number of roots first, then each root against its expected value.
    fn assert_roots(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (&got, &want) in actual.iter().zip(expected) {
            assert_close(got, want, tol);
        }
    }

    #[test]
    fn rejects_non_cubic() {
        let outcome = solve_real(0.0, 1.0, 2.0, 3.0);
        assert!(matches!(outcome, Err(CubicError::NotCubic(_))));
    }

    #[test]
    fn rejects_nan_coefficient() {
        let outcome = solve_real(1.0, f64::NAN, 0.0, 0.0);
        assert!(matches!(outcome, Err(CubicError::NotFinite { .. })));
    }

    #[test]
    fn one_real_root_via_cardano() {
        // x³ + x + 1 has a single real root.
        let roots = solve_real(1.0, 0.0, 1.0, 1.0).unwrap();
        assert_roots(&roots, &[-0.682_327_803_828_019_3], 1e-12);
        let residual = cubic_eval(1.0, 0.0, 1.0, 1.0, roots[0]);
        assert!(residual.abs() < 1e-10);
    }

    #[test]
    fn three_real_roots_via_trigonometric_form() {
        // (x - 1)(x - 2)(x - 3)
        let roots = solve_real(1.0, -6.0, 11.0, -6.0).unwrap();
        assert_roots(&roots, &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn double_plus_single_root() {
        // (x + 1)(x - 2)²
        let roots = solve_real(1.0, -3.0, 0.0, 4.0).unwrap();
        assert_roots(&roots, &[-1.0, 2.0], 1e-9);
    }

    #[test]
    fn triple_root() {
        // (x - 5)³
        let roots = solve_real(1.0, -15.0, 75.0, -125.0).unwrap();
        assert_eq!(roots.len(), 1, "triple root collapses after dedup");
        assert_close(roots[0], 5.0, 1e-6);
    }

    #[test]
    fn vle_z_factor_style_cubic() {
        // (x - 0.10)(x - 0.15)(x - 0.80)
        let roots = solve_real(1.0, -1.05, 0.215, -0.012).unwrap();
        assert_roots(&roots, &[0.10, 0.15, 0.80], 1e-10);

        let smallest = select_smallest(&roots).unwrap();
        let largest = select_largest(&roots).unwrap();
        assert_close(smallest, 0.10, 1e-10);
        assert_close(largest, 0.80, 1e-10);
    }

    #[test]
    fn near_degenerate_discriminant_does_not_panic() {
        let roots = solve_real(1.0, 1e-16, 0.0, 0.0).unwrap();
        assert!(!roots.is_empty());
        roots
            .iter()
            .for_each(|r| assert!(r.is_finite(), "root must be finite, got {r}"));
    }
}