use crate::StrError;
use std::f64::consts::PI;
/// Tolerance used by [`solve_cubic`]: the smallest accepted `|a|`, and the threshold below which
/// the discriminant (and `B/2`) of the rescaled depressed cubic is treated as zero.
const EPS: f64 = 1e-12;

/// Computes the real roots of the cubic equation `a·x³ + b·x² + c·x + d = 0`.
///
/// The equation is first made monic (`x³ + p·x² + q·x + r = 0`) and then rescaled with
/// `x = s·z`, where `s` is a power of two within a factor of two of `max(|p|, √|q|, ∛|r|)`.
/// The roots of the rescaled cubic therefore have magnitude of order one. The substitution
/// `z = y − p/3` gives the depressed cubic `y³ + A·y + B = 0`. The sign of
/// `Δ = (B/2)² + (A/3)³` then selects the method:
///
/// * `Δ > EPS`: one real root (Cardano's formula, in a cancellation-free form);
/// * `Δ < −EPS`: three distinct real roots (trigonometric method);
/// * otherwise: a repeated root. A double root is returned twice and a triple root three times.
///
/// The tolerance is applied only after rescaling, so the classification does not depend on the
/// absolute size of the roots. For example, roots `1e-3, 2e-3, 3e-3` are resolved just like
/// `1, 2, 3`.
///
/// # Returns
///
/// The real roots in ascending order: either one value, or three values (repeated roots are
/// included once per multiplicity).
///
/// # Errors
///
/// * if `|a| < 1e-12`;
/// * if `b/a`, `c/a`, or `d/a` is not finite (a NaN or infinite coefficient, or overflow).
pub fn solve_cubic(a: f64, b: f64, c: f64, d: f64) -> Result<Vec<f64>, StrError> {
    if a.abs() < EPS {
        return Err("The absolute value of the leading coefficient 'a' must be nonzero (>= 1e-12).");
    }

    // Monic form: x³ + p·x² + q·x + r = 0.
    let p = b / a;
    let q = c / a;
    let r = d / a;
    if !(p.is_finite() && q.is_finite() && r.is_finite()) {
        return Err("The ratios b/a, c/a, and d/a must be finite (check for NaN or infinite coefficients).");
    }

    // Rescale x = scale·z so the roots are O(1). A power of two makes the rescaling exact, and
    // it lets the absolute tolerance EPS act as a relative one. The exponent is capped at 1023
    // so that `scale` stays finite.
    let magnitude = p.abs().max(q.abs().sqrt()).max(r.abs().cbrt());
    let scale = if magnitude > 0.0 {
        magnitude.log2().floor().min(1023.0).exp2()
    } else {
        1.0
    };
    // Divide step by step to avoid overflowing scale² or scale³.
    let p = p / scale;
    let q = q / scale / scale;
    let r = r / scale / scale / scale;

    // Depressed cubic y³ + A·y + B = 0 with z = y − p/3.
    let p_sq = p * p;
    let aa = q - p_sq / 3.0;
    let bb = r - (p * q) / 3.0 + (2.0 * p * p_sq) / 27.0;
    let half_b = bb / 2.0;
    let third_a = aa / 3.0;
    let delta = half_b * half_b + third_a.powi(3);

    let mut ys = Vec::with_capacity(3);
    if delta > EPS {
        // One real root y = u + v with u·v = −A/3. Choose the sign so that the radicand adds
        // two same-signed magnitudes (no cancellation), then recover v from the product.
        // |radicand| >= sqrt(delta) > 0, so u != 0.
        let u = (-half_b - half_b.signum() * delta.sqrt()).cbrt();
        let v = -third_a / u;
        ys.push(u + v);
    } else if delta < -EPS {
        // Three distinct real roots. delta < 0 forces (A/3)³ < 0, so −A/3 > 0 and the sqrt is real.
        let sqrt_neg_third_a = (-third_a).sqrt();
        // Clamp guards acos against rounding that lands just outside [−1, 1].
        let cos_3phi = (-half_b / sqrt_neg_third_a.powi(3)).clamp(-1.0, 1.0);
        let theta = cos_3phi.acos();
        let factor = 2.0 * sqrt_neg_third_a;
        let y1 = factor * (theta / 3.0).cos();
        let y2 = factor * ((theta + 2.0 * PI) / 3.0).cos();
        let y3 = factor * ((theta + 4.0 * PI) / 3.0).cos();
        ys.extend_from_slice(&[y1, y2, y3]);
    } else if half_b.abs() < EPS {
        // A ≈ 0 and B ≈ 0: triple root.
        ys.extend_from_slice(&[0.0, 0.0, 0.0]);
    } else {
        // Δ ≈ 0 with B ≠ 0: a simple root 2u and a double root −u.
        let u = (-half_b).cbrt();
        ys.extend_from_slice(&[2.0 * u, -u, -u]);
    }

    // Undo the shift z = y − p/3 and the rescaling x = scale·z.
    let p3 = p / 3.0;
    let mut xs: Vec<f64> = ys.into_iter().map(|y| (y - p3) * scale).collect();
    xs.sort_by(f64::total_cmp);
    Ok(xs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Message returned by `solve_cubic` when the leading coefficient is (near) zero.
    const LEADING_COEFF_MSG: &str =
        "The absolute value of the leading coefficient 'a' must be nonzero (>= 1e-12).";

    /// Checks that `roots` has as many entries as `expected`, and that each entry lies strictly
    /// within `EPS` of the corresponding expected value.
    fn assert_roots(roots: &[f64], expected: &[f64]) {
        assert_eq!(roots.len(), expected.len());
        for (got, want) in roots.iter().zip(expected) {
            assert!((got - want).abs() < EPS);
        }
    }

    /// Solves the cubic with coefficients `[a, b, c, d]` and checks that three roots are
    /// returned. Also checks that plugging each root back into the polynomial leaves a
    /// residual smaller than `tol`.
    fn assert_three_residuals([a, b, c, d]: [f64; 4], tol: f64) {
        let roots = solve_cubic(a, b, c, d).unwrap();
        assert_eq!(roots.len(), 3);
        for root in &roots {
            let value = a * root.powi(3) + b * root.powi(2) + c * root + d;
            assert!(
                value.abs() < tol,
                "Root {} does not satisfy the equation: value = {}",
                root,
                value
            );
        }
    }

    #[test]
    fn test_three_distinct_real_roots() {
        assert_roots(&solve_cubic(1.0, -6.0, 11.0, -6.0).unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_triple_root() {
        assert_roots(&solve_cubic(1.0, 3.0, 3.0, 1.0).unwrap(), &[-1.0; 3]);
    }

    #[test]
    fn test_invalid_leading_coeff() {
        assert_eq!(solve_cubic(0.0, 1.0, 1.0, 1.0).unwrap_err(), LEADING_COEFF_MSG);
    }

    #[test]
    fn test_irreducible_case() {
        assert_roots(&solve_cubic(1.0, 0.0, -1.0, 0.0).unwrap(), &[-1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_single_real_root() {
        assert_roots(&solve_cubic(1.0, 1.0, 1.0, 1.0).unwrap(), &[-1.0]);
    }

    #[test]
    fn test_double_root() {
        assert_roots(&solve_cubic(1.0, -4.0, 5.0, -2.0).unwrap(), &[1.0, 1.0, 2.0]);
    }

    #[test]
    fn test_near_zero_roots() {
        assert_roots(&solve_cubic(1.0, -3.0, 2.0, 0.0).unwrap(), &[0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_large_coefficients() {
        assert_three_residuals([1e9, -6e6, 11e3, -6.0], 1e-6);
    }

    #[test]
    fn test_small_coefficients() {
        assert_three_residuals([1e-9, -6e-6, 11e-3, -6.0], 1e-6);
    }

    #[test]
    fn test_negative_coefficients() {
        assert_roots(&solve_cubic(-1.0, 6.0, -11.0, 6.0).unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_fractional_coefficients() {
        assert_three_residuals([0.125, -0.75, 1.375, -0.75], 1e-6);
    }

    #[test]
    fn test_floating_point_precision() {
        assert_three_residuals([1.0, -6e-15, 11e-30, -6e-45], 1e-40);
    }

    #[test]
    fn test_error_handling() {
        let err = solve_cubic(0.0, 1.0, 1.0, 1.0).unwrap_err();
        assert_eq!(err, LEADING_COEFF_MSG);
        assert!(err
            .to_string()
            .contains("leading coefficient 'a' must be nonzero"));
    }

    #[test]
    fn test_x_cubed_equals_zero() {
        assert_roots(&solve_cubic(1.0, 0.0, 0.0, 0.0).unwrap(), &[0.0; 3]);
    }

    #[test]
    fn test_x_cubed_equals_d() {
        assert_roots(&solve_cubic(1.0, 0.0, 0.0, -8.0).unwrap(), &[2.0]);
    }

    #[test]
    fn test_roots_zero_and_non_zero() {
        assert_roots(&solve_cubic(1.0, -2.0, 0.0, 0.0).unwrap(), &[0.0, 0.0, 2.0]);
    }
}