subsphere/
math.rs

1//! This module contains a minimal set of linear algebra types and functions used by this crate. It
2//! has a very specific grab bag of functionality and is not intended to be publicly exposed for
3//! general use.
4pub use std::f64::consts::PI;
5
/// Free functions that treat fixed-size `f64` arrays as mathematical vectors.
pub(crate) mod vec {
    /// Returns `a` with the sign of every component flipped.
    #[inline]
    pub fn neg<const N: usize>(a: [f64; N]) -> [f64; N] {
        a.map(|x| -x)
    }

    /// Returns the component-wise sum `a + b`.
    #[inline]
    pub const fn add<const N: usize>(a: [f64; N], b: [f64; N]) -> [f64; N] {
        // Iterators cannot be used in a `const fn`, so the components are walked by index.
        let mut sum = a;
        let mut k = 0;
        while k < N {
            sum[k] += b[k];
            k += 1;
        }
        sum
    }

    /// Returns the component-wise difference `a - b`.
    #[inline]
    pub fn sub<const N: usize>(a: [f64; N], b: [f64; N]) -> [f64; N] {
        let mut diff = a;
        for (lhs, rhs) in diff.iter_mut().zip(b.iter()) {
            *lhs -= rhs;
        }
        diff
    }

    /// Returns `a` with every component multiplied by the scalar `b`.
    #[inline]
    pub const fn mul<const N: usize>(a: [f64; N], b: f64) -> [f64; N] {
        // Iterators cannot be used in a `const fn`, so the components are walked by index.
        let mut scaled = a;
        let mut k = 0;
        while k < N {
            scaled[k] *= b;
            k += 1;
        }
        scaled
    }

    /// Returns `a` with every component divided by the scalar `b`.
    #[inline]
    pub fn div<const N: usize>(a: [f64; N], b: f64) -> [f64; N] {
        a.map(|x| x / b)
    }

    /// Returns the dot product of `a` and `b`, accumulated from the first component to the last.
    #[inline]
    pub const fn dot<const N: usize>(a: [f64; N], b: [f64; N]) -> f64 {
        let mut acc = 0.0;
        let mut k = 0;
        while k < N {
            acc += a[k] * b[k];
            k += 1;
        }
        acc
    }

    /// Returns the cross product `a × b` of two three-dimensional vectors.
    #[inline]
    pub fn cross(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
        let [ax, ay, az] = a;
        let [bx, by, bz] = b;
        [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
    }

    /// Scales `a` to unit length by dividing it by the square root of `dot(a, a)`.
    ///
    /// The zero vector is not special-cased, so its components come out as NaN.
    #[inline]
    pub fn normalize<const N: usize>(a: [f64; N]) -> [f64; N] {
        let length = dot(a, a).sqrt();
        div(a, length)
    }
}
90
91/// Contains functions related to matrices.
92pub(crate) mod mat {
93    use super::*;
94
95    /// Multiplies a matrix with a vector.
96    #[inline]
97    pub fn apply<const N: usize, const M: usize>(mat: [[f64; N]; M], vec: [f64; M]) -> [f64; N] {
98        let mut res = [0.0; N];
99        for i in 0..M {
100            res = vec::add(res, vec::mul(mat[i], vec[i]));
101        }
102        res
103    }
104
105    /// Computes the determinant of a 3x3 matrix.
106    #[inline]
107    pub fn det_3(m: [[f64; 3]; 3]) -> f64 {
108        vec::dot(m[0], vec::cross(m[1], m[2]))
109    }
110
111    /// Computes the determinant of a 2x2 matrix.
112    #[inline]
113    pub fn det_2(m: [[f64; 2]; 2]) -> f64 {
114        m[0][0] * m[1][1] - m[0][1] * m[1][0]
115    }
116
117    /// Computes the adjoint of a 2x2 matrix.
118    #[inline]
119    pub fn adjoint_2(m: [[f64; 2]; 2]) -> [[f64; 2]; 2] {
120        let [[a, b], [c, d]] = m;
121        [[d, -b], [-c, a]]
122    }
123}
124
/// Determines the real roots of a cubic polynomial of the form `a x³ + b x² + c x + d`.
///
/// The returned iterator yields between zero and three finite roots:
///
/// * If `|a|` is below a small threshold the polynomial is solved as the quadratic
///   `b x² + c x + d`, and if `|b|` is below that threshold too, as the linear polynomial
///   `c x + d`. A quadratic with a negative discriminant yields nothing, and a double root of
///   the quadratic is yielded twice.
/// * Otherwise the cubic formula is used. With a non-negative discriminant exactly one root is
///   yielded (a repeated root that exists when the discriminant is exactly zero is not
///   reported separately). With a negative discriminant the three real roots are yielded in
///   ascending order.
pub fn solve_cubic(a: f64, b: f64, c: f64, d: f64) -> impl Iterator<Item = f64> {
    // Coefficients with a magnitude below this are treated as zero when picking the degree.
    const SMALL: f64 = 1.0e-6;

    // Every branch fills a prefix of a fixed-size array and reports how many entries are valid,
    // so that all branches share a single concrete iterator type.
    let (roots, count) = if a.abs() < SMALL {
        // Handle the quadratic (and linear) case explicitly.
        // TODO: This is pretty hacky. Is there a numerically stable continuation of the cubic
        // formula for small leading coefficients?
        let disc = c * c - 4.0 * b * d;
        if disc >= 0.0 {
            // Giving the square root the sign of `c` avoids cancellation in `u`. The two roots
            // are then `u / 2b` and `2d / u`.
            let u = -c - c.signum() * disc.sqrt();
            if b.abs() < SMALL {
                // Linear case: only `2d / u` is meaningful. `u` vanishes only when `c` is zero
                // as well, in which case there is no root to isolate (dividing would produce an
                // infinity or NaN).
                if u == 0.0 {
                    ([0.0; 3], 0)
                } else {
                    ([2.0 * d / u, 0.0, 0.0], 1)
                }
            } else {
                let x_0 = u / (2.0 * b);
                // `u == 0` means `c` and the discriminant both vanish: a double root at `x_0`.
                // Evaluating `2d / u` would yield `0 / 0`.
                let x_1 = if u == 0.0 { x_0 } else { 2.0 * d / u };
                ([x_0, x_1, 0.0], 2)
            }
        } else {
            ([0.0; 3], 0)
        }
    } else {
        // See https://mathworld.wolfram.com/CubicFormula.html for derivation
        let b_i_3a = (b / 3.0) / a;
        let c_i_a = c / a;
        let d_i_a = d / a;

        // Let `x = t - b / 3 a`. Then the equation becomes `t³ + 3 q t - 2 r = 0` where:
        let q = c_i_a / 3.0 - b_i_3a * b_i_3a;
        let r = (b_i_3a * c_i_a - d_i_a) / 2.0 - b_i_3a * b_i_3a * b_i_3a;
        let disc = q * q * q + r * r;
        if disc >= 0.0 {
            // Equation has one real root, `t = w - q / w` with `w³ = r ± √disc`. Both signs give
            // the same `t`; choosing the sign of `r` keeps the sum free of cancellation, which
            // would otherwise drive `w` to zero (and `t` to NaN) for `r < 0` and small `q`.
            let w = (r + r.signum() * disc.sqrt()).cbrt();
            // `w` can only be zero when `q` and `r` are both zero, i.e. for `t³ = 0`.
            let t = if w == 0.0 { 0.0 } else { w - q / w };
            ([t - b_i_3a, 0.0, 0.0], 1)
        } else {
            // Equation has three real roots, `t = 2 √(-q) cos(θ + 2 π k / 3)` with
            // `cos(3 θ) = r / √(-q³)`.
            let h = (r * r - disc).sqrt();
            let s_r = h.cbrt();
            // TODO: Evaluation of `acos` and `cos` shouldn't actually be necessary here.
            let s_theta_0 = (r / h).acos() / 3.0;
            let t_0 = 2.0 * s_r * (s_theta_0 + 2.0 * PI / 3.0).cos();
            let t_1 = 2.0 * s_r * (s_theta_0 - 2.0 * PI / 3.0).cos();
            let t_2 = 2.0 * s_r * s_theta_0.cos();
            ([t_0 - b_i_3a, t_1 - b_i_3a, t_2 - b_i_3a], 3)
        }
    };
    IntoIterator::into_iter(roots).take(count)
}
182
#[test]
fn test_solve_cubic() {
    let sqrt_5 = 5.0f64.sqrt();

    // Cubics with a single real root.
    let roots: Vec<f64> = solve_cubic(1.0, -1.0, 1.0, -1.0).collect();
    assert_similar(roots, [1.0]);
    let roots: Vec<f64> = solve_cubic(1.0, 0.0, 0.0, -27.0).collect();
    assert_similar(roots, [3.0]);
    let roots: Vec<f64> = solve_cubic(8.0, 8.0, 0.0, -3.0).collect();
    assert_similar(roots, [0.5]);

    // Cubics with three real roots, expected in ascending order.
    let roots: Vec<f64> = solve_cubic(1.0, 4.0, 0.0, -5.0).collect();
    assert_similar(roots, [-(sqrt_5 + 5.0) / 2.0, (sqrt_5 - 5.0) / 2.0, 1.0]);
    let roots: Vec<f64> = solve_cubic(1.0, -2.0, 0.0, 1.0).collect();
    assert_similar(roots, [(1.0 - sqrt_5) / 2.0, 1.0, (1.0 + sqrt_5) / 2.0]);

    // Vanishing cubic coefficient: quadratics with two and with no real roots.
    let roots: Vec<f64> = solve_cubic(0.0, 1.0, 0.0, -1.0).collect();
    assert_similar(roots, [-1.0, 1.0]);
    let roots: Vec<f64> = solve_cubic(0.0, 1.0, 0.0, 1.0).collect();
    assert_similar(roots, []);

    // Vanishing cubic and quadratic coefficients: a linear equation.
    let roots: Vec<f64> = solve_cubic(0.0, 0.0, 1.0, -1.0).collect();
    assert_similar(roots, [1.0]);
}
213
/// Test helper asserting that `a` and `b` have the same length and that corresponding
/// components are equal or differ by at most `1e-6`.
///
/// # Panics
///
/// Panics if the lengths differ or if any pair of components is not within the tolerance. A
/// NaN component never counts as similar to anything.
#[cfg(test)]
fn assert_similar<const N: usize>(a: Vec<f64>, b: [f64; N]) {
    // Phrased as "every pair is within tolerance" rather than "no pair exceeds the tolerance":
    // every comparison involving NaN is false, so the latter would silently accept NaN results.
    // The `x == y` test keeps equal infinities similar, whose difference is NaN.
    let similar = a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .all(|(x, y)| x == y || (x - y).abs() <= 1e-6);
    if !similar {
        panic!("vectors are not similar, a = {:?}, b = {:?}", a, b);
    }
}