use crate::{num, NonZero};
use crate::traits::*;
pub mod linear;
/// Solves the cubic equation `a*x^3 + b*x^2 + c*x + d = 0` for its real roots.
///
/// Returns up to three real roots: the `Some` entries come first in ascending order,
/// followed by `None` padding. A repeated root is reported once (a double root gives two
/// `Some` entries in total, a triple root gives a single one).
///
/// The solving strategy depends on the coefficients:
///
/// - `b == 0`: the depressed cubic `x^3 + p*x + q = 0` (with `p = c/a`, `q = d/a`) is
///   solved with Cardano's formula, or with the trigonometric method when it has three
///   distinct real roots;
/// - `a == 1`: the monic cubic is solved with the `Q`/`R` form of Cardano's method;
/// - otherwise: the general discriminant `Δ` and the quantities `Δ0`, `Δ1` decide the
///   number of real roots and how they are computed.
///
/// # Panics
///
/// Debug builds assert that no two returned roots compare equal. Sorting the roots may
/// panic if two of them cannot be ordered (e.g. one of them is NaN).
pub fn solve_cubic_equation <S> (a : NonZero <S>, b : S, c : S, d : S)
  -> [Option <S>; 3]
where S : Real + Cbrt + num::NumCast + std::fmt::Debug {
  let a = *a;
  let eighteen = S::two() * S::nine();
  let twentyseven = S::three() * S::nine();
  let fiftyfour = S::two() * twentyseven;
  let one_third = S::one() / S::three();
  let two_third_pi = S::two() * S::frac_pi_3();
  // Rounding can push the argument of `acos` slightly outside `[-1, 1]` when the
  // discriminant is close to zero; clamp it so the roots do not become NaN.
  let clamp_unit = |x : S| -> S {
    if x > S::one() {
      S::one()
    } else if x < -S::one() {
      -S::one()
    } else {
      x
    }
  };
  let mut roots = if b == S::zero() {
    // depressed cubic: x^3 + c*x + d = 0
    let c = c / a;
    let d = d / a;
    if c == S::zero() {
      // x^3 = -d: a single (triple) root
      [Some (-d.cbrt()), None, None]
    } else if d == S::zero() {
      // x * (x^2 + c) = 0
      let quadratic_roots = solve_quadratic_equation (S::one(), S::zero(), c);
      match quadratic_roots {
        [Some (r0), Some (r1)] => [Some (r0), Some (r1), Some (S::zero())],
        [Some (r0), None]      => [Some (r0), Some (S::zero()), None],
        [None, None]           => [Some (S::zero()), None, None],
        [None, Some (_)]       => unreachable!()
      }
    } else {
      let dd = d * d / S::four() + c * c * c / twentyseven;
      if dd < S::zero() {
        // three distinct real roots: trigonometric method
        let aa = (-S::four() * c * one_third).sqrt();
        let phi = clamp_unit (-S::four() * d / (aa * aa * aa)).acos() * one_third;
        [ Some (aa * phi.cos()),
          Some (aa * (phi + two_third_pi).cos()),
          Some (aa * (phi - two_third_pi).cos())
        ]
      } else {
        // Cardano's formula
        let sqrt_dd = dd.sqrt();
        let d_div_2 = d * S::half();
        let x1 = (sqrt_dd - d_div_2).cbrt() - (sqrt_dd + d_div_2).cbrt();
        if dd == S::zero() {
          // the simple root is x1 = -2 * cbrt(d/2) and the double root is cbrt(d/2)
          [Some (x1), Some (d_div_2.cbrt()), None]
        } else {
          [Some (x1), None, None]
        }
      }
    }
  } else if a == S::one() {
    // monic cubic: x^3 + b*x^2 + c*x + d = 0
    let q = (S::three() * c - b * b) / S::nine();
    let r = (S::nine() * b * c - twentyseven * d - S::two() * b * b * b) / fiftyfour;
    let q3 = q * q * q;
    let dd = q3 + r * r;
    let b_div_3 = b * one_third;
    if dd < S::zero() {
      // three distinct real roots: trigonometric method
      let phi_3 = clamp_unit (r / (-q3).sqrt()).acos() * one_third;
      let sqrt_q_2 = S::two() * (-q).sqrt();
      [ Some (sqrt_q_2 * phi_3.cos() - b_div_3),
        Some (sqrt_q_2 * (phi_3 - two_third_pi).cos() - b_div_3),
        Some (sqrt_q_2 * (phi_3 + two_third_pi).cos() - b_div_3)
      ]
    } else {
      let sqrt_dd = dd.sqrt();
      let s = (r + sqrt_dd).cbrt();
      let t = (r - sqrt_dd).cbrt();
      let root0 = s + t - b_div_3;
      if s == t {
        if s + t == S::zero() {
          // triple root
          [Some (root0), None, None]
        } else {
          // root0 is simple, root1 is double
          let root1 = -(s + t) * S::half() - b_div_3;
          [Some (root0), Some (root1), None]
        }
      } else {
        [Some (root0), None, None]
      }
    }
  } else {
    // general cubic: discriminant `dd` and the intermediate quantities `d0`, `d1`
    let dd = eighteen * a*b*c*d - S::four()*b*b*b*d + b*b*c*c - S::four()*a*c*c*c
      - twentyseven*a*a*d*d;
    let d0 = b*b - S::three()*a*c;
    let d1 = S::two()*b*b*b - S::nine()*a*b*c + twentyseven*a*a*d;
    // The number of distinct real roots is decided by the sign of the discriminant
    // `dd` (not by the constant coefficient `d`).
    if dd < S::zero() {
      // one real root; d1^2 - 4*d0^3 == -27*a^2*dd > 0
      let sqrt = (-twentyseven * a * a * dd).sqrt();
      // pick the sign that adds same-signed terms, avoiding cancellation with d1
      let cc_cubed = (if d1 < S::zero() { d1 - sqrt } else { d1 + sqrt }) / S::two();
      let cc = S::cbrt (cc_cubed);
      [Some ( -(b + cc + d0 / cc) / (S::three() * a)), None, None]
    } else if dd == S::zero() {
      if d0 == S::zero() {
        // triple root
        [Some (-b / (a * S::three())), None, None]
      } else {
        let double_root = (S::nine() * a * d - b * c) / (d0 * S::two());
        let simple_root = (S::four() * a * b * c - S::nine() * a * a * d - b * b * b)
          / (a * d0);
        [Some (simple_root), Some (double_root), None]
      }
    } else {
      // Three distinct real roots. C^3 = (d1 + i*sqrt(27*a^2*dd)) / 2 is complex; take
      // its cube root C (phase divided by three) and the rotations C*e and C*e^2 with
      // e = -1/2 + i*sqrt(3)/2. Each root is the real part of -(b + C_k + d0/C_k) / (3a),
      // where Re(d0 / C_k) = d0 * Re(C_k) / |C_k|^2.
      let div_by_3a = S::one() / (S::three() * a);
      let c3_img = (twentyseven * a * a * dd).sqrt() / S::two();
      let c3_real = d1 * S::half();
      let c3_module = (c3_img * c3_img + c3_real * c3_real).sqrt();
      // half-angle form of atan2; the denominator is positive because c3_img > 0
      let c3_phase = S::two() * (c3_img / (c3_real + c3_module)).atan();
      let c_module = c3_module.cbrt();
      let c_phase = c3_phase * one_third;
      let c_real = c_module * c_phase.cos();
      let c_img = c_module * c_phase.sin();
      let x0_real = -(b + c_real + (d0 * c_real) / (c_module * c_module)) * div_by_3a;
      let e_real = -S::half();
      let e_img = S::sqrt_3() * S::half();
      let c1_real = c_real * e_real - c_img * e_img;
      let c1_img = c_real * e_img + c_img * e_real;
      let x1_real =
        -(b + c1_real + (d0 * c1_real) / (c1_real * c1_real + c1_img * c1_img))
        * div_by_3a;
      let c2_real = c1_real * e_real - c1_img * e_img;
      let c2_img = c1_real * e_img + c1_img * e_real;
      let x2_real =
        -(b + c2_real + (d0 * c2_real) / (c2_real * c2_real + c2_img * c2_img))
        * div_by_3a;
      [Some (x0_real), Some (x1_real), Some (x2_real)]
    }
  };
  // ascending order, `None` padding at the end
  roots.sort_by (|a, b| match (a, b) {
    (Some (a), Some (b)) => {
      debug_assert_ne!(a, b);
      a.partial_cmp (b).unwrap()
    }
    (None, Some (_)) => std::cmp::Ordering::Greater,
    (Some (_), None) => std::cmp::Ordering::Less,
    (None, None)     => std::cmp::Ordering::Equal
  });
  roots
}
/// Solves the quadratic equation `a*x^2 + b*x + c = 0` for its real roots.
///
/// Returns:
///
/// - `[None, None]` when there is no real root (negative discriminant);
/// - `[Some (r), None]` for a double root (zero discriminant);
/// - `[Some (r0), Some (r1)]` with `r0 < r1` for two distinct roots.
///
/// Distinct roots are computed with the cancellation-free form
/// `q = -(b + sign(b) * sqrt(b^2 - 4ac)) / 2`, `x0 = q / a`, `x1 = c / q`. This avoids
/// the loss of precision of the textbook formula `(-b ± sqrt(b^2 - 4ac)) / 2a` when
/// `b^2` is much larger than `|4ac|`.
///
/// `a` must be non-zero; this is checked with a debug assertion.
pub fn solve_quadratic_equation <S> (a : S, b : S, c : S) -> [Option <S>; 2] where
  S : OrderedField + Sqrt + std::fmt::Debug
{
  debug_assert_ne!(a, S::zero());
  let discriminant = b * b - S::four() * a * c;
  if discriminant < S::zero() {
    [None, None]
  } else if discriminant == S::zero() {
    [Some (S::half() * (-b / a)), None]
  } else {
    let discriminant_sqrt = discriminant.sqrt();
    // Add quantities of the same sign so no cancellation occurs. `q` is non-zero
    // because `discriminant_sqrt > 0`, so dividing by it below is safe.
    let q = if b < S::zero() {
      S::half() * (discriminant_sqrt - b)
    } else {
      -S::half() * (b + discriminant_sqrt)
    };
    let root0 = q / a;
    let root1 = c / q;
    debug_assert_ne!(root0, root1);
    if root0 < root1 {
      [Some (root0), Some (root1)]
    } else {
      [Some (root1), Some (root0)]
    }
  }
}
#[cfg(test)]
mod tests {
  #![expect(clippy::unreadable_literal)]
  use crate::approx;
  use super::*;

  /// Exercises the solving strategies of `solve_cubic_equation`: a monic cubic with
  /// three distinct roots, a depressed cubic with one real root, a monic cubic with a
  /// triple root, and a non-monic cubic with three distinct roots.
  #[test]
  fn solve_cubic() {
    // x^3 - 6x^2 + 11x - 6 = (x - 1)(x - 2)(x - 3)
    let [first, second, third] =
      solve_cubic_equation (NonZero::noisy (1.0), -6.0, 11.0, -6.0);
    assert_eq!(first.unwrap(), 1.0);
    approx::assert_relative_eq!(second.unwrap(), 2.0);
    assert_eq!(third.unwrap(), 3.0);
    // x^3 + x + 1 has a single real root
    assert_eq!(
      solve_cubic_equation (NonZero::noisy (1.0), 0.0, 1.0, 1.0),
      [Some (-0.6823278038280194), None, None]);
    // x^3 - 3x^2 + 3x - 1 = (x - 1)^3
    assert_eq!(
      solve_cubic_equation (NonZero::noisy (1.0), -3.0, 3.0, -1.0),
      [Some (1.0), None, None]);
    // 2x^3 - 4x^2 - 2x + 4 = 2(x + 1)(x - 1)(x - 2)
    let [first, second, third] =
      solve_cubic_equation (NonZero::noisy (2.0), -4.0, -2.0, 4.0);
    approx::assert_relative_eq!(first.unwrap(), -1.0);
    approx::assert_relative_eq!(second.unwrap(), 1.0);
    assert_eq!(third.unwrap(), 2.0);
  }
}