nsys-math-utils 1.0.0

Math types and traits
Documentation
//! Elementary and linear algebraic methods

use crate::{num, NonZero};
use crate::traits::*;

pub mod linear;

/// Solve the cubic equation `ax^3 + bx^2 + cx + d = 0` where `a != 0`. Returns one,
/// two, or three real roots.
///
/// Roots are returned in ascending order, with any `None` entries last. A repeated root
/// is reported once: a double root together with a simple root yields two `Some`
/// entries, and a triple root yields a single one.
///
/// Three code paths are used, following the `roots` crate:
///
/// - `b == 0`: the depressed cubic `x^3 + (c/a)x + (d/a) = 0`;
/// - `a == 1`: the normalized cubic `x^3 + bx^2 + cx + d = 0`;
/// - otherwise: the general cubic, classified by the sign of its discriminant.
///
/// Multiplicity is detected with exact floating-point comparisons against zero, so
/// nearly-repeated roots may be reported as distinct roots.
///
/// # Panics
///
/// May panic while sorting if a computed root is NaN (for example when a coefficient is
/// NaN or infinite).
pub fn solve_cubic_equation <S> (a : NonZero <S>, b : S, c : S, d : S)
  -> [Option <S>; 3]
where S : Real + Cbrt + num::NumCast + std::fmt::Debug {
  // algorithm from roots crate:
  // <https://github.com/vorot/roots/blob/2809d56ae1773b6b21a4204c0576af5263134f5c/src/analytical/cubic.rs>
  let a = *a;
  let eighteen     = S::two() * S::nine();
  let twentyseven  = S::three() * S::nine();
  let fiftyfour    = S::two() * twentyseven;
  let one_third    = S::one() / S::three();
  let two_third_pi = S::two() * S::frac_pi_3();
  // The arguments passed to `acos` below lie strictly inside [-1, 1] in exact
  // arithmetic, but rounding can push them marginally outside, which would yield NaN.
  let clamp_unit = |x : S| if x > S::one() {
    S::one()
  } else if x < -S::one() {
    -S::one()
  } else {
    x
  };
  // TODO: approx eq ?
  let mut roots = if b == S::zero() {
    // depressed cubic (after dividing through by `a`)
    // x^3 + cx + d = 0
    let c = c / a;
    let d = d / a;
    if c == S::zero() {
      // x^3 = -d
      [Some (-d.cbrt()), None, None]
    } else if d == S::zero() {
      // x (x^2 + c) = 0: zero plus the roots of the quadratic factor; since c != 0 the
      // quadratic has either two distinct roots or none
      let quadratic_roots = solve_quadratic_equation (S::one(), S::zero(), c);
      match quadratic_roots {
        [Some (r0), Some (r1)] => [Some (r0), Some (r1), Some (S::zero())],
        [Some (r0), None]      => [Some (r0), Some (S::zero()), None],
        [None,      None]      => [Some (S::zero()), None, None],
        [None,      Some (_)]  => unreachable!()
      }
    } else {
      // (d/2)^2 + (c/3)^3: negative means three distinct real roots
      let dd = d * d / S::four() + c * c * c / twentyseven;
      if dd < S::zero() {
        // trigonometric method
        // n*d^2 + m*c^3 < 0 => c < 0, so the square root below is real
        let aa = (-S::four() * c * one_third).sqrt();
        let phi = clamp_unit (-S::four() * d / (aa * aa * aa)).acos() * one_third;
        [ Some (aa * phi.cos()),
          Some (aa * (phi + two_third_pi).cos()),
          Some (aa * (phi - two_third_pi).cos())
        ]
      } else {
        // Cardano's formula
        let sqrt_dd = dd.sqrt();
        let d_div_2 = d * S::half();
        let x1 = (sqrt_dd - d_div_2).cbrt() - (sqrt_dd + d_div_2).cbrt();
        if dd == S::zero() {
          // one simple and one double root: here x1 = -2 cbrt(d/2), and since there is
          // no x^2 term the roots sum to zero, so the double root is cbrt(d/2)
          [Some (x1), Some (d_div_2.cbrt()), None]
        } else {
          // one real root
          [Some (x1), None, None]
        }
      }
    }
  } else if a == S::one() {
    // normalized cubic
    // x^3 + bx^2 + cx + d = 0
    let q = (S::three() * c - b * b) / S::nine();
    let r = (S::nine() * b * c - twentyseven * d - S::two() * b * b * b) / fiftyfour;
    let q3 = q * q * q;
    let dd = q3 + r * r;
    let b_div_3 = b * one_third;
    if dd < S::zero() {
      // three distinct real roots (trigonometric method); dd < 0 implies q < 0
      let phi_3 = clamp_unit (r / (-q3).sqrt()).acos() * one_third;
      let sqrt_q_2 = S::two() * (-q).sqrt();
      [ Some (sqrt_q_2 * phi_3.cos() - b_div_3),
        Some (sqrt_q_2 * (phi_3 - two_third_pi).cos() - b_div_3),
        Some (sqrt_q_2 * (phi_3 + two_third_pi).cos() - b_div_3)
      ]
    } else {
      let sqrt_dd = dd.sqrt();
      let s = (r + sqrt_dd).cbrt();
      let t = (r - sqrt_dd).cbrt();
      let root0 = s + t - b_div_3;
      // TODO: approx eq ?
      if s == t {
        if s + t == S::zero() {
          // triple root
          [Some (root0), None, None]
        } else {
          // one simple and one double root
          let root1 = -(s + t) * S::half() - b_div_3;
          [Some (root0), Some (root1), None]
        }
      } else {
        // one real root
        [Some (root0), None, None]
      }
    }
  } else {
    // general cubic, classified by the sign of its discriminant `dd`
    let dd = eighteen * a*b*c*d - S::four()*b*b*b*d + b*b*c*c - S::four()*a*c*c*c
      - twentyseven*a*a*d*d;
    let d0 = b*b - S::three()*a*c;
    let d1 = S::two()*b*b*b - S::nine()*a*b*c + twentyseven*a*a*d;
    if dd < S::zero() {
      // one real root
      let sqrt = (-twentyseven * a * a * dd).sqrt();
      // give `sqrt` the sign of `d1` so the sum cannot cancel (and is never zero)
      let cc_cubed = if d1 < S::zero() { d1 - sqrt } else { d1 + sqrt };
      let cc = S::cbrt (cc_cubed / S::two());
      [Some ( -(b + cc + d0 / cc) / (S::three() * a)), None, None]
    } else if dd == S::zero() {
      // multiple roots
      if d0 == S::zero() {
        // triple root
        [Some (-b / (a * S::three())), None, None]
      } else {
        // one double root and one simple root
        let double_root = (S::nine() * a * d - b * c) / (d0 * S::two());
        let simple_root = (S::four() * a * b * c - S::nine() * a * a * d - b * b * b)
          / (a * d0);
        [Some (double_root), Some (simple_root), None]
      }
    } else {
      // three distinct real roots: the three complex cube roots C, C*e, C*e^2 (with e a
      // primitive cube root of unity) each give a real root -(b + C + d0/C) / 3a
      // 1
      let div_by_3a = S::one() / (S::three() * a);
      let c3_img    = (twentyseven * a * a * dd).sqrt() / S::two();
      let c3_real   = d1 * S::half();
      let c3_module = (c3_img * c3_img + c3_real * c3_real).sqrt();
      let c3_phase  = S::two() * (c3_img / (c3_real + c3_module)).atan();
      let c_module  = c3_module.cbrt();
      let c_phase   = c3_phase * one_third;
      let c_real    = c_module * c_phase.cos();
      let c_img     = c_module * c_phase.sin();
      let x0_real   = -(b + c_real + (d0 * c_real) / (c_module * c_module)) * div_by_3a;
      // 2
      let e_real    = -S::half();
      let e_img     = S::sqrt_3() * S::half();
      let c1_real   = c_real * e_real - c_img * e_img;
      let c1_img    = c_real * e_img + c_img * e_real;
      let x1_real   =
        -(b + c1_real + (d0 * c1_real) / (c1_real * c1_real + c1_img * c1_img))
        * div_by_3a;
      // 3
      let c2_real = c1_real * e_real - c1_img * e_img;
      let c2_img  = c1_real * e_img + c1_img * e_real;
      let x2_real =
        -(b + c2_real + (d0 * c2_real) / (c2_real * c2_real + c2_img * c2_img))
        * div_by_3a;
      [Some (x0_real), Some (x1_real), Some (x2_real)]
    }
  };
  // ascending order with `None` last; two computed roots may legitimately round to the
  // same value, so equality is not treated as an error here
  roots.sort_by (|x, y| match (x, y) {
    (Some (x), Some (y)) => x.partial_cmp (y).expect ("cubic roots should not be NaN"),
    (None,     Some (_)) => std::cmp::Ordering::Greater,
    (Some (_), None)     => std::cmp::Ordering::Less,
    (None,     None)     => std::cmp::Ordering::Equal
  });
  roots
}

/// Solve the quadratic equation `ax^2 + bx + c = 0` where `a != 0`. Returns zero, one,
/// or two real roots.
///
/// Roots are returned in ascending order, with any `None` entries last:
///
/// - `[None, None]` when the discriminant is negative;
/// - `[Some (x), None]` when the discriminant is exactly zero (a double root);
/// - `[Some (x0), Some (x1)]` with `x0 <= x1` otherwise.
///
/// The two-root case uses the numerically stable form
/// `q = -(b + sign(b) sqrt(b^2 - 4ac)) / 2`, `x0 = q / a`, `x1 = c / q`. This avoids the
/// catastrophic cancellation of the textbook formula when `b^2` dominates `|4ac|`.
///
/// # Panics
///
/// In debug builds, panics if `a == 0`.
pub fn solve_quadratic_equation <S> (a : S, b : S, c : S) -> [Option <S>; 2] where
  S : OrderedField + Sqrt + std::fmt::Debug
{
  debug_assert_ne!(a, S::zero());
  let discriminant = b * b - S::four() * a * c;
  if discriminant < S::zero() {
    // no real roots
    [None, None]
  } else if discriminant == S::zero() {
    // single (double) root
    [Some (S::half() * (-b / a)), None]
  } else {
    // two roots
    let discriminant_sqrt = discriminant.sqrt();
    // Give the square root the sign of `b` so the two terms add rather than cancel.
    // `q` is then non-zero because `discriminant_sqrt > 0`.
    let signed_sqrt = if b < S::zero() { -discriminant_sqrt } else { discriminant_sqrt };
    let q = -S::half() * (b + signed_sqrt);
    let root0 = q / a;
    let root1 = c / q;
    if root0 < root1 {
      [Some (root0), Some (root1)]
    } else {
      [Some (root1), Some (root0)]
    }
  }
}

#[cfg(test)]
mod tests {
  #![expect(clippy::unreadable_literal)]
  use crate::approx;
  use super::*;

  /// Solve `ax^3 + bx^2 + cx + d = 0` over `f64`; `a` must be non-zero.
  fn cubic (a : f64, b : f64, c : f64, d : f64) -> [Option <f64>; 3] {
    solve_cubic_equation (NonZero::noisy (a), b, c, d)
  }

  #[test]
  fn solve_cubic() {
    // (x - 1)(x - 2)(x - 3) = x^3 - 6x^2 + 11x - 6: three distinct real roots 1, 2, 3
    let [smallest, middle, largest] = cubic (1.0, -6.0, 11.0, -6.0);
    assert_eq!(smallest.unwrap(), 1.0);
    approx::assert_relative_eq!(middle.unwrap(), 2.0);
    assert_eq!(largest.unwrap(), 3.0);

    // x^3 + x + 1: a single real root near -0.68233
    assert_eq!(cubic (1.0, 0.0, 1.0, 1.0), [Some (-0.6823278038280194), None, None]);

    // (x - 1)^3 = x^3 - 3x^2 + 3x - 1: triple root 1, reported once
    assert_eq!(cubic (1.0, -3.0, 3.0, -1.0), [Some (1.0), None, None]);

    // 2(x + 1)(x - 1)(x - 2) = 2x^3 - 4x^2 - 2x + 4: three distinct real roots -1, 1, 2
    let [smallest, middle, largest] = cubic (2.0, -4.0, -2.0, 4.0);
    approx::assert_relative_eq!(smallest.unwrap(), -1.0);
    approx::assert_relative_eq!(middle.unwrap(),  1.0);
    assert_eq!(largest.unwrap(), 2.0);
  }
}