math_utils/algebra/
mod.rs1use crate::num_traits as num;
4use crate::traits::*;
5
6pub mod linear;
7
8pub fn solve_cubic_equation <S> (a : S, b : S, c : S, d : S) -> [Option <S>; 3] where
13 S : Real + Cbrt + num::NumCast + std::fmt::Debug
14{
15 debug_assert_ne!(a, S::zero());
18 let eighteen = S::two() * S::nine();
19 let twentyseven = S::three() * S::nine();
20 let fiftyfour = S::two() * twentyseven;
21 let one_third = S::one() / S::three();
22 let two_third_pi = S::two() * S::frac_pi_3();
23 let mut roots = if b == S::zero() {
25 let c = c / a;
28 let d = d / a;
29 if c == S::zero() {
30 [Some (-d.cbrt()), None, None]
31 } else if d == S::zero() {
32 let quadratic_roots = solve_quadratic_equation (S::one(), S::zero(), c);
33 match quadratic_roots {
34 [Some (r0), Some (r1)] => [Some (r0), Some (r1), Some (S::zero())],
35 [Some (r0), None] => [Some (r0), Some (S::zero()), None],
36 [None, None] => [Some (S::zero()), None, None],
37 [None, Some (_)] => unreachable!()
38 }
39 } else {
40 let dd = d * d / S::four() + c * c * c / twentyseven;
41 if dd < S::zero() {
42 let aa = (-S::four() * c * one_third).sqrt();
44 let phi = (-S::four() * d / (aa * aa * aa)).acos() * one_third;
45 [ Some (aa * phi.cos()),
46 Some (aa * (phi + two_third_pi).cos()),
47 Some (aa * (phi - two_third_pi).cos())
48 ]
49 } else {
50 let sqrt_dd = dd.sqrt();
51 let d_div_2 = d * S::half();
52 let x1 = (sqrt_dd - d_div_2).cbrt() - (sqrt_dd + d_div_2).cbrt();
53 if dd == S::zero() {
54 [Some (x1), Some (d_div_2), None]
56 } else {
57 [Some (x1), None, None]
59 }
60 }
61 }
62 } else if a == S::one() {
63 let q = (S::three() * c - b * b) / S::nine();
66 let r = (S::nine() * b * c - twentyseven * d - S::two() * b * b * b) / fiftyfour;
67 let q3 = q * q * q;
68 let dd = q3 + r * r;
69 let b_div_3 = b * one_third;
70 if dd < S::zero() {
71 let phi_3 = (r / (-q3).sqrt()).acos() * one_third;
72 let sqrt_q_2 = S::two() * (-q).sqrt();
73 [ Some (sqrt_q_2 * phi_3.cos() - b_div_3),
74 Some (sqrt_q_2 * (phi_3 - S::two() * S::frac_pi_3()).cos() - b_div_3),
75 Some (sqrt_q_2 * (phi_3 + S::two() * S::frac_pi_3()).cos() - b_div_3)
76 ]
77 } else {
78 let sqrt_dd = dd.sqrt();
79 let s = (r + sqrt_dd).cbrt();
80 let t = (r - sqrt_dd).cbrt();
81 let root0 = s + t - b_div_3;
82 if s == t {
84 if s + t == S::zero() {
85 [Some (root0), None, None]
86 } else {
87 let root1 = -(s + t) * S::half() - b_div_3;
88 [Some (root0), Some (root1), None]
89 }
90 } else {
91 [Some (root0), None, None]
92 }
93 }
94 } else {
95 let dd = eighteen * a*b*c*d - S::four()*b*b*b*d + b*b*c*c - S::four()*a*c*c*c
96 - twentyseven*a*a*d*d;
97 let d0 = b*b - S::three()*a*c;
98 let d1 = S::two()*b*b*b - S::nine()*a*b*c + twentyseven*a*a*d;
99 if d < S::zero() {
100 let sqrt = (-twentyseven * a * a * dd).sqrt();
102 let cc = S::cbrt (if d1 < S::zero() { d1 - sqrt } else { d1 + sqrt } / S::two());
103 [Some ( -(b + cc + d0 / cc) / (S::three() * a)), None, None]
104 } else if d == S::zero() {
105 if d0 == S::zero() {
107 [Some (-b / (a * S::three())), None, None]
109 } else {
110 let simple_root = (S::nine() * a * d - b * c) / (d0 * S::two());
112 let double_root = (S::four() * a * b * c - S::nine() * a * a * d - b * b * b)
113 / (a * d0);
114 [Some (simple_root), Some (double_root), None]
115 }
116 } else {
117 let div_by_3a = S::one() / (S::three() * a);
120 let c3_img = (twentyseven * a * a * dd).sqrt() / S::two();
121 let c3_real = d1 * S::half();
122 let c3_module = (c3_img * c3_img + c3_real * c3_real).sqrt();
123 let c3_phase = S::two() * (c3_img / (c3_real + c3_module)).atan();
124 let c_module = c3_module.cbrt();
125 let c_phase = c3_phase * one_third;
126 let c_real = c_module * c_phase.cos();
127 let c_img = c_module * c_phase.sin();
128 let x0_real = -(b + c_real + (d0 * c_real) / (c_module * c_module)) * div_by_3a;
129 let e_real = -S::half();
131 let e_img = S::sqrt_3() * S::half();
132 let c1_real = c_real * e_real - c_img * e_img;
133 let c1_img = c_real * e_img + c_img * e_real;
134 let x1_real =
135 -(b + c1_real + (d0 * c1_real) / (c1_real * c1_real + c1_img * c1_img))
136 * div_by_3a;
137 let c2_real = c1_real * e_real - c1_img * e_img;
139 let c2_img = c1_real * e_img + c1_img * e_real;
140 let x2_real =
141 -(b + c2_real + (d0 * c2_real) / (c2_real * c2_real + c2_img * c2_img))
142 * div_by_3a;
143 [Some (x0_real), Some (x1_real), Some (x2_real)]
144 }
145 };
146 roots.sort_by (|a, b| match (a, b) {
147 (Some (a), Some (b)) => {
148 debug_assert_ne!(a, b); a.partial_cmp (b).unwrap()
150 }
151 (None, Some (_)) => std::cmp::Ordering::Greater,
152 (Some (_), None) => std::cmp::Ordering::Less,
153 (None, None) => std::cmp::Ordering::Equal
154 });
155 roots
156}
157
158pub fn solve_quadratic_equation <S> (a : S, b : S, c : S) -> [Option <S>; 2] where
163 S : OrderedField + Sqrt + std::fmt::Debug
164{
165 debug_assert_ne!(a, S::zero());
166 let discriminant = b * b - S::four() * a * c;
167 if discriminant < S::zero() {
168 [None, None]
170 } else if discriminant == S::zero() {
171 [Some (S::half() * (-b / a)), None]
173 } else {
174 let discriminant_sqrt = discriminant.sqrt();
176 let div_a = S::one() / a;
177 let root0 = S::half() * div_a * (-b + discriminant_sqrt);
178 let root1 = S::half() * div_a * (-b - discriminant_sqrt);
179 debug_assert_ne!(root0, root1); if root0 < root1 {
181 [Some (root0), Some (root1)]
182 } else {
183 [Some (root1), Some (root0)]
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
  #![expect(clippy::unreadable_literal)]
  use crate::approx;
  use super::*;

  /// Exercises each solution path of `solve_cubic_equation` on known polynomials.
  #[test]
  fn solve_cubic() {
    // (x - 1)(x - 2)(x - 3): three distinct roots, monic (`a == 1`) path.
    let [low, mid, high] = solve_cubic_equation (1.0, -6.0, 11.0, -6.0);
    assert_eq!(low.unwrap(), 1.0);
    approx::assert_relative_eq!(mid.unwrap(), 2.0);
    assert_eq!(high.unwrap(), 3.0);

    // x³ + x + 1: one real root, depressed-cubic (`b == 0`) path.
    assert_eq!(
      solve_cubic_equation (1.0, 0.0, 1.0, 1.0),
      [Some (-0.6823278038280194), None, None]
    );

    // (x - 1)³: triple root, monic path.
    assert_eq!(solve_cubic_equation (1.0, -3.0, 3.0, -1.0), [Some (1.0), None, None]);

    // 2(x + 1)(x - 1)(x - 2): three distinct roots, general path.
    let [low, mid, high] = solve_cubic_equation (2.0, -4.0, -2.0, 4.0);
    approx::assert_relative_eq!(low.unwrap(), -1.0);
    approx::assert_relative_eq!(mid.unwrap(), 1.0);
    assert_eq!(high.unwrap(), 2.0);
  }
}