use num_complex::Complex64;
/// Absolute magnitude below which a leading coefficient is treated as zero, causing the
/// equation to be solved as a lower-degree polynomial instead.
const DEGENERATE_COEFF_EPS: f64 = 1.0e-12;
/// Filler for root slots that carry no value (e.g. the unused slots of a quadratic or linear
/// fallback). Both components are NaN, so it never compares equal to anything, itself included.
const MISSING_ROOT: Complex64 = Complex64::new(f64::NAN, f64::NAN);
/// Solves `alpha·x³ + beta·x² + gamma·x + delta = 0` and returns its three roots.
///
/// The equation is normalised to the monic form `x³ + a·x² + b·x + c` and solved in closed
/// form with `Q = (a² − 3b) / 9` and `R = (2a³ − 9ab + 27c) / 54`:
///
/// * `R² − Q³ < 0`: three distinct real roots (trigonometric method), all returned with a zero
///   imaginary part. They are not sorted.
/// * otherwise: the first entry is the real root; the other two are the conjugate pair
///   `re ± i·im`. For a (near-)repeated real root `im` is zero or a rounding-level value.
///
/// If `|alpha| < DEGENERATE_COEFF_EPS` the equation is delegated to `solve_quadratic_padded`,
/// which fills unused slots with `MISSING_ROOT` (NaN + NaN·i). The degeneracy test is an
/// absolute threshold, so a cubic whose coefficients are all tiny is also treated that way.
///
/// # Panics
///
/// Only if the internal invariant `Q > 0` fails in the three-real-roots branch. `R² − Q³ < 0`
/// already implies `Q³ > R² ≥ 0`, so this should be unreachable.
pub fn solve_cubic(alpha: f64, beta: f64, gamma: f64, delta: f64) -> [Complex64; 3] {
    use std::f64::consts::PI;
    if alpha.abs() < DEGENERATE_COEFF_EPS {
        return solve_quadratic_padded(beta, gamma, delta);
    }
    // Normalise to the monic cubic x³ + a·x² + b·x + c.
    let a = beta / alpha;
    let b = gamma / alpha;
    let c = delta / alpha;
    let q = (a * a - 3.0 * b) / 9.0;
    let r = (2.0 * a * a * a - 9.0 * a * b + 27.0 * c) / 54.0;
    let discriminant = r * r - q * q * q;
    if discriminant < 0.0 {
        // Three real roots. The check below guards an invariant implied by the branch condition.
        if q <= 0.0 {
            panic!("Invalid state: discriminant < 0 but Q ≤ 0");
        }
        let q_cubed = q * q * q;
        let sqrt_q_cubed = q_cubed.sqrt();
        // Clamp guards against |R / √Q³| drifting marginally outside [-1, 1] through rounding.
        let cos_arg = (r / sqrt_q_cubed).clamp(-1.0, 1.0);
        let theta = cos_arg.acos();
        let sqrt_q = q.sqrt();
        [
            Complex64::new(-2.0 * sqrt_q * (theta / 3.0).cos() - a / 3.0, 0.0),
            Complex64::new(
                -2.0 * sqrt_q * ((theta + 2.0 * PI) / 3.0).cos() - a / 3.0,
                0.0,
            ),
            Complex64::new(
                -2.0 * sqrt_q * ((theta - 2.0 * PI) / 3.0).cos() - a / 3.0,
                0.0,
            ),
        ]
    } else {
        // One real root plus a complex-conjugate pair (Cardano).
        let sqrt_discriminant = discriminant.sqrt();
        let a_val = -r.signum() * (r.abs() + sqrt_discriminant).cbrt();
        // B = Q / A. In this branch |A|³ ≥ |Q|^(3/2), hence |Q / A| ≤ |A|: the ratio cannot blow
        // up, and only an exactly-zero A (R = Q = 0, i.e. a triple root) needs guarding. An
        // absolute cut-off such as |A| < 1e-10 would break scale invariance and give wrong roots
        // for polynomials whose roots are themselves tiny.
        let b_val = if a_val == 0.0 { 0.0 } else { q / a_val };
        let real_part = -0.5 * (a_val + b_val) - a / 3.0;
        let imag_part = 3.0_f64.sqrt() * (a_val - b_val) / 2.0;
        [
            Complex64::new(a_val + b_val - a / 3.0, 0.0),
            Complex64::new(real_part, imag_part),
            Complex64::new(real_part, -imag_part),
        ]
    }
}
/// Solves `beta·x² + gamma·x + delta = 0` and pads the result to three slots with
/// `MISSING_ROOT`, matching the shape returned by `solve_cubic`.
///
/// * `|beta|` and `|gamma|` both below `DEGENERATE_COEFF_EPS`: no root is determined; all three
///   slots are `MISSING_ROOT`.
/// * only `|beta|` below the threshold: the linear root `-delta / gamma`, then two
///   `MISSING_ROOT`s.
/// * otherwise: the two quadratic roots, then one `MISSING_ROOT`. Real roots are ordered as
///   `(-gamma + √D) / 2beta` followed by `(-gamma - √D) / 2beta`, where `D = gamma² − 4·beta·delta`.
///   A negative `D` yields the conjugate pair `re + i·im`, `re − i·im`.
fn solve_quadratic_padded(beta: f64, gamma: f64, delta: f64) -> [Complex64; 3] {
    if beta.abs() < DEGENERATE_COEFF_EPS {
        if gamma.abs() < DEGENERATE_COEFF_EPS {
            return [MISSING_ROOT; 3];
        }
        return [
            Complex64::new(-delta / gamma, 0.0),
            MISSING_ROOT,
            MISSING_ROOT,
        ];
    }
    let discriminant = gamma * gamma - 4.0 * beta * delta;
    if discriminant >= 0.0 {
        let sqrt_disc = discriminant.sqrt();
        // Stable evaluation. `gamma` and `sign(gamma)·√D` have the same sign, so their sum never
        // cancels, unlike `-gamma ± √D` when gamma² ≫ 4·beta·delta. `q / beta` is the
        // larger-magnitude root. The other comes from Vieta's product `x₁·x₂ = delta / beta`,
        // i.e. `delta / q`.
        let q = -0.5 * (gamma + gamma.signum() * sqrt_disc);
        let (plus_root, minus_root) = if q == 0.0 {
            // gamma and √D are both zero (up to underflow): a double root at zero.
            (0.0, 0.0)
        } else if gamma.is_sign_negative() {
            // q = (-gamma + √D) / 2
            (q / beta, delta / q)
        } else {
            // q = (-gamma - √D) / 2
            (delta / q, q / beta)
        };
        [
            Complex64::new(plus_root, 0.0),
            Complex64::new(minus_root, 0.0),
            MISSING_ROOT,
        ]
    } else {
        let two_beta = 2.0 * beta;
        let real = -gamma / two_beta;
        let imag = (-discriminant).sqrt() / two_beta;
        [
            Complex64::new(real, imag),
            Complex64::new(real, -imag),
            MISSING_ROOT,
        ]
    }
}
/// Returns the real parts of those roots whose imaginary part is exactly zero, in input order.
///
/// The test is an exact comparison (`-0.0` counts as zero). Padding entries such as
/// `MISSING_ROOT` (NaN imaginary part) are therefore dropped, and so is any root carrying even a
/// rounding-level imaginary component. `extract_real_pairs` offers tolerance-based filtering.
pub fn extract_real_roots(roots: &[Complex64]) -> Vec<f64> {
    let mut real = Vec::with_capacity(roots.len());
    for root in roots {
        if root.im == 0.0 {
            real.push(root.re);
        }
    }
    real
}
/// Keeps the `(lambda, mu)` pairs in which both values are real to within `tolerance` and
/// returns their real parts, preserving input order.
///
/// A value counts as real when `|im| < tolerance` (strict). A non-positive `tolerance`
/// therefore rejects every pair, and a NaN imaginary part is always rejected.
pub fn extract_real_pairs(pairs: &[(Complex64, Complex64)], tolerance: f64) -> Vec<(f64, f64)> {
    let is_real = |z: &Complex64| z.im.abs() < tolerance;
    pairs
        .iter()
        .filter_map(|(lambda, mu)| {
            if is_real(lambda) && is_real(mu) {
                Some((lambda.re, mu.re))
            } else {
                None
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Real parts of every root, sorted ascending (panics if any is NaN).
    fn sorted_real_parts(roots: &[Complex64]) -> Vec<f64> {
        let mut re: Vec<f64> = roots.iter().map(|r| r.re).collect();
        re.sort_by(|a, b| a.partial_cmp(b).unwrap());
        re
    }

    /// Asserts that `actual` and `expected` have equal length and agree element-wise within `tol`.
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (got, want) in actual.iter().zip(expected) {
            assert!(
                approx_eq(*got, *want, tol),
                "expected {}, got {} (all: {:?})",
                want,
                got,
                actual
            );
        }
    }

    #[test]
    fn test_cubic_three_real_roots() {
        let roots = solve_cubic(1.0, -6.0, 11.0, -6.0);
        for root in &roots {
            assert!(root.im.abs() < 1e-10, "Root should be real: {:?}", root);
        }
        assert_all_close(&sorted_real_parts(&roots), &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn test_cubic_one_real_root() {
        let roots = solve_cubic(1.0, 0.0, 0.0, -1.0);
        let real_roots = extract_real_roots(&roots);
        assert_eq!(real_roots.len(), 1);
        assert!(approx_eq(real_roots[0], 1.0, 1e-10));
        let complex_count = roots.iter().filter(|r| r.im.abs() > 1e-10).count();
        assert_eq!(complex_count, 2);
    }

    #[test]
    fn test_cubic_repeated_root() {
        let roots = solve_cubic(1.0, -6.0, 12.0, -8.0);
        for root in &roots {
            assert!(
                approx_eq(root.re, 2.0, 1e-9),
                "Root should be 2.0, got {}",
                root.re
            );
            assert!(root.im.abs() < 1e-9, "Root should be real");
        }
    }

    #[test]
    fn test_extract_real_roots() {
        let mut real_roots = extract_real_roots(&solve_cubic(1.0, -6.0, 11.0, -6.0));
        assert_eq!(real_roots.len(), 3);
        real_roots.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_all_close(&real_roots, &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn test_extract_real_roots_skips_missing_and_near_real() {
        let roots = [
            Complex64::new(1.0, 0.0),
            MISSING_ROOT,
            Complex64::new(2.0, -0.0),
            Complex64::new(3.0, 1e-300),
        ];
        assert_eq!(extract_real_roots(&roots), vec![1.0, 2.0]);
    }

    #[test]
    fn test_extract_real_pairs_filters_by_tolerance() {
        let pairs = [
            (Complex64::new(1.0, 0.0), Complex64::new(2.0, 1e-12)),
            (Complex64::new(3.0, 0.5), Complex64::new(4.0, 0.0)),
            (Complex64::new(5.0, -1e-12), Complex64::new(6.0, 0.0)),
            (MISSING_ROOT, Complex64::new(7.0, 0.0)),
        ];
        assert_eq!(
            extract_real_pairs(&pairs, 1e-9),
            vec![(1.0, 2.0), (5.0, 6.0)]
        );
        assert!(extract_real_pairs(&pairs, 0.0).is_empty());
    }

    #[test]
    fn test_cubic_negative_leading_coefficient() {
        let roots = solve_cubic(-1.0, 6.0, -11.0, 6.0);
        assert_all_close(&sorted_real_parts(&roots), &[1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn test_cubic_with_zero_root() {
        let roots = solve_cubic(1.0, 0.0, -1.0, 0.0);
        assert_all_close(&sorted_real_parts(&roots), &[-1.0, 0.0, 1.0], 1e-10);
    }

    #[test]
    fn test_cubic_alpha_zero_falls_back_to_quadratic() {
        let real_roots = extract_real_roots(&solve_cubic(0.0, 1.0, -2.0, 1.0));
        assert_eq!(real_roots.len(), 2);
        for r in &real_roots {
            assert!(approx_eq(*r, 1.0, 1e-10));
        }
    }

    #[test]
    fn test_cubic_alpha_near_zero_falls_back_to_quadratic() {
        let real_roots = extract_real_roots(&solve_cubic(-7.3e-15, 2.0, -4.0, 2.0));
        assert!(!real_roots.is_empty(), "expected a real fallback root");
        for r in &real_roots {
            assert!(approx_eq(*r, 1.0, 1e-9), "got {}", r);
        }
    }

    #[test]
    fn test_cubic_alpha_zero_linear_fallback() {
        let roots = solve_cubic(0.0, 0.0, 2.0, 4.0);
        let real_roots = extract_real_roots(&roots);
        assert_eq!(real_roots.len(), 1);
        assert!(approx_eq(real_roots[0], -2.0, 1e-12));
        // The two unused slots are NaN padding.
        assert!(roots[1..].iter().all(|r| r.re.is_nan() && r.im.is_nan()));
    }

    #[test]
    fn test_cubic_fully_degenerate_has_no_roots() {
        let roots = solve_cubic(0.0, 0.0, 0.0, 5.0);
        assert!(roots.iter().all(|r| r.re.is_nan() && r.im.is_nan()));
        assert!(extract_real_roots(&roots).is_empty());
    }

    #[test]
    fn test_cubic_alpha_zero_complex_quadratic() {
        let real_roots = extract_real_roots(&solve_cubic(0.0, 1.0, 0.0, 1.0));
        assert!(real_roots.is_empty());
    }

    #[test]
    fn test_cubic_near_boundary() {
        let roots = solve_cubic(1.0, -4.0, 5.0, -2.0);
        assert_all_close(&sorted_real_parts(&roots), &[1.0, 1.0, 2.0], 1e-9);
    }
}