math_utils/algebra/
mod.rs

1//! Elementary and linear algebraic methods
2
3use crate::num_traits as num;
4use crate::traits::*;
5
6pub mod linear;
7
8/// Solve the cubic equation `ax^3 + bx^2 + cx + d = 0` where `a != 0`. Returns one,
9/// two, or three real roots.
10///
11/// Roots are returned in sorted order.
12pub fn solve_cubic_equation <S> (a : S, b : S, c : S, d : S) -> [Option <S>; 3] where
13  S : Real + Cbrt + num::NumCast + std::fmt::Debug
14{
15  // algorithm from roots crate:
16  // <https://github.com/vorot/roots/blob/2809d56ae1773b6b21a4204c0576af5263134f5c/src/analytical/cubic.rs>
17  debug_assert_ne!(a, S::zero());
18  let eighteen     = S::two() * S::nine();
19  let twentyseven  = S::three() * S::nine();
20  let fiftyfour    = S::two() * twentyseven;
21  let one_third    = S::one() / S::three();
22  let two_third_pi = S::two() * S::frac_pi_3();
23  // TODO: approx eq ?
24  let mut roots = if b == S::zero() {
25    // depressed cubic
26    // x^3 + cx + d = 0
27    let c = c / a;
28    let d = d / a;
29    if c == S::zero() {
30      [Some (-d.cbrt()), None, None]
31    } else if d == S::zero() {
32      let quadratic_roots = solve_quadratic_equation (S::one(), S::zero(), c);
33      match quadratic_roots {
34        [Some (r0), Some (r1)] => [Some (r0), Some (r1), Some (S::zero())],
35        [Some (r0), None]      => [Some (r0), Some (S::zero()), None],
36        [None,      None]      => [Some (S::zero()), None, None],
37        [None,      Some (_)]  => unreachable!()
38      }
39    } else {
40      let dd = d * d / S::four() + c * c * c / twentyseven;
41      if dd < S::zero() {
42        // n*d^2 + m*c^3 < 0 => c < 0
43        let aa = (-S::four() * c * one_third).sqrt();
44        let phi = (-S::four() * d / (aa * aa * aa)).acos() * one_third;
45        [ Some (aa * phi.cos()),
46          Some (aa * (phi + two_third_pi).cos()),
47          Some (aa * (phi - two_third_pi).cos())
48        ]
49      } else {
50        let sqrt_dd = dd.sqrt();
51        let d_div_2 = d * S::half();
52        let x1 = (sqrt_dd - d_div_2).cbrt() - (sqrt_dd + d_div_2).cbrt();
53        if dd == S::zero() {
54          // one simple and one double root
55          [Some (x1), Some (d_div_2), None]
56        } else {
57          // one real root
58          [Some (x1), None, None]
59        }
60      }
61    }
62  } else if a == S::one() {
63    // normalized cubic
64    // x^3 + bx^2 + cx + d = 0
65    let q = (S::three() * c - b * b) / S::nine();
66    let r = (S::nine() * b * c - twentyseven * d - S::two() * b * b * b) / fiftyfour;
67    let q3 = q * q * q;
68    let dd = q3 + r * r;
69    let b_div_3 = b * one_third;
70    if dd < S::zero() {
71      let phi_3 = (r / (-q3).sqrt()).acos() * one_third;
72      let sqrt_q_2 = S::two() * (-q).sqrt();
73      [ Some (sqrt_q_2 * phi_3.cos() - b_div_3),
74        Some (sqrt_q_2 * (phi_3 - S::two() * S::frac_pi_3()).cos() - b_div_3),
75        Some (sqrt_q_2 * (phi_3 + S::two() * S::frac_pi_3()).cos() - b_div_3)
76      ]
77    } else {
78      let sqrt_dd = dd.sqrt();
79      let s = (r + sqrt_dd).cbrt();
80      let t = (r - sqrt_dd).cbrt();
81      let root0 = s + t - b_div_3;
82      // TODO: approx eq ?
83      if s == t {
84        if s + t == S::zero() {
85          [Some (root0), None, None]
86        } else {
87          let root1 = -(s + t) * S::half() - b_div_3;
88          [Some (root0), Some (root1), None]
89        }
90      } else {
91        [Some (root0), None, None]
92      }
93    }
94  } else {
95    let dd = eighteen * a*b*c*d - S::four()*b*b*b*d + b*b*c*c - S::four()*a*c*c*c
96      - twentyseven*a*a*d*d;
97    let d0 = b*b - S::three()*a*c;
98    let d1 = S::two()*b*b*b - S::nine()*a*b*c + twentyseven*a*a*d;
99    if d < S::zero() {
100      // one real root
101      let sqrt = (-twentyseven * a * a * dd).sqrt();
102      let cc   = S::cbrt (if d1 < S::zero() { d1 - sqrt } else { d1 + sqrt } / S::two());
103      [Some ( -(b + cc + d0 / cc) / (S::three() * a)), None, None]
104    } else if d == S::zero() {
105      // multiple roots
106      if d0 == S::zero() {
107        // triple root
108        [Some (-b / (a * S::three())), None, None]
109      } else {
110        // simple root and double root
111        let simple_root = (S::nine() * a * d - b * c) / (d0 * S::two());
112        let double_root = (S::four() * a * b * c - S::nine() * a * a * d - b * b * b)
113          / (a * d0);
114        [Some (simple_root), Some (double_root), None]
115      }
116    } else {
117      // three real roots
118      // 1
119      let div_by_3a = S::one() / (S::three() * a);
120      let c3_img    = (twentyseven * a * a * dd).sqrt() / S::two();
121      let c3_real   = d1 * S::half();
122      let c3_module = (c3_img * c3_img + c3_real * c3_real).sqrt();
123      let c3_phase  = S::two() * (c3_img / (c3_real + c3_module)).atan();
124      let c_module  = c3_module.cbrt();
125      let c_phase   = c3_phase * one_third;
126      let c_real    = c_module * c_phase.cos();
127      let c_img     = c_module * c_phase.sin();
128      let x0_real   = -(b + c_real + (d0 * c_real) / (c_module * c_module)) * div_by_3a;
129      // 2
130      let e_real    = -S::half();
131      let e_img     = S::sqrt_3() * S::half();
132      let c1_real   = c_real * e_real - c_img * e_img;
133      let c1_img    = c_real * e_img + c_img * e_real;
134      let x1_real   =
135        -(b + c1_real + (d0 * c1_real) / (c1_real * c1_real + c1_img * c1_img))
136        * div_by_3a;
137      // 3
138      let c2_real = c1_real * e_real - c1_img * e_img;
139      let c2_img  = c1_real * e_img + c1_img * e_real;
140      let x2_real =
141        -(b + c2_real + (d0 * c2_real) / (c2_real * c2_real + c2_img * c2_img))
142        * div_by_3a;
143      [Some (x0_real), Some (x1_real), Some (x2_real)]
144    }
145  };
146  roots.sort_by (|a, b| match (a, b) {
147    (Some (a), Some (b)) => {
148      debug_assert_ne!(a, b); // TODO: is this possible ?
149      a.partial_cmp (b).unwrap()
150    }
151    (None,     Some (_)) => std::cmp::Ordering::Greater,
152    (Some (_), None)     => std::cmp::Ordering::Less,
153    (None,     None)     => std::cmp::Ordering::Equal
154  });
155  roots
156}
157
158/// Solve the quadratic equation `ax^2 + bx + c = 0` where `a != 0`. Returns zero, one,
159/// or two real roots.
160///
161/// Roots are returned in sorted order;
162pub fn solve_quadratic_equation <S> (a : S, b : S, c : S) -> [Option <S>; 2] where
163  S : OrderedField + Sqrt + std::fmt::Debug
164{
165  debug_assert_ne!(a, S::zero());
166  let discriminant = b * b - S::four() * a * c;
167  if discriminant < S::zero() {
168    // no real roots
169    [None, None]
170  } else if discriminant == S::zero() {
171    // single root
172    [Some (S::half() * (-b / a)), None]
173  } else {
174    // two roots
175    let discriminant_sqrt = discriminant.sqrt();
176    let div_a = S::one() / a;
177    let root0 = S::half() * div_a * (-b + discriminant_sqrt);
178    let root1 = S::half() * div_a * (-b - discriminant_sqrt);
179    debug_assert_ne!(root0, root1); // TODO: is this possible ?
180    if root0 < root1 {
181      [Some (root0), Some (root1)]
182    } else {
183      [Some (root1), Some (root0)]
184    }
185  }
186}
187
#[cfg(test)]
mod tests {
  #![expect(clippy::unreadable_literal)]
  use crate::approx;
  use super::*;

  // Roots derived from trig/cbrt evaluations are compared with a relative tolerance
  // rather than bit-exact equality.

  #[test]
  fn solve_cubic() {
    // three distinct real roots
    // x^3 - 6x^2 + 11x - 6 = 0
    // roots: 1, 2, 3
    let roots = solve_cubic_equation (1.0, -6.0, 11.0, -6.0);
    approx::assert_relative_eq!(roots[0].unwrap(), 1.0);
    approx::assert_relative_eq!(roots[1].unwrap(), 2.0);
    approx::assert_relative_eq!(roots[2].unwrap(), 3.0);
    // one real root
    // x^3 + x + 1 = 0
    // root: -0.68233
    let roots = solve_cubic_equation (1.0, 0.0, 1.0, 1.0);
    approx::assert_relative_eq!(roots[0].unwrap(), -0.6823278038280194);
    assert!(roots[1..].iter().all (Option::is_none));
    // triple root
    // x^3 - 3x^2 + 3x - 1 = 0
    // root: 1.0
    let roots = solve_cubic_equation (1.0, -3.0, 3.0, -1.0);
    approx::assert_relative_eq!(roots[0].unwrap(), 1.0);
    assert!(roots[1..].iter().all (Option::is_none));
    // three distinct real roots
    // 2x^3 - 4x^2 - 2x + 4 = 0
    // roots: -1, 1, 2
    let roots = solve_cubic_equation (2.0, -4.0, -2.0, 4.0);
    approx::assert_relative_eq!(roots[0].unwrap(), -1.0);
    approx::assert_relative_eq!(roots[1].unwrap(),  1.0);
    approx::assert_relative_eq!(roots[2].unwrap(),  2.0);
  }

  #[test]
  fn solve_quadratic() {
    // two distinct real roots
    // x^2 - 3x + 2 = 0
    // roots: 1, 2
    let roots = solve_quadratic_equation (1.0, -3.0, 2.0);
    approx::assert_relative_eq!(roots[0].unwrap(), 1.0);
    approx::assert_relative_eq!(roots[1].unwrap(), 2.0);
    // double root
    // x^2 + 2x + 1 = 0
    // root: -1
    let roots = solve_quadratic_equation (1.0, 2.0, 1.0);
    approx::assert_relative_eq!(roots[0].unwrap(), -1.0);
    assert!(roots[1].is_none());
    // no real roots
    // x^2 + 1 = 0
    let roots : [Option <f64>; 2] = solve_quadratic_equation (1.0, 0.0, 1.0);
    assert!(roots.iter().all (Option::is_none));
  }
}