la_stack/
matrix.rs

1//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::ldlt::Ldlt;
5use crate::lu::Lu;
6
/// A square `D×D` matrix of `f64` whose dimension is fixed at compile time.
///
/// The entries live directly inside the value as an array of `D` rows, each row
/// being an array of `D` entries, so entry `(r, c)` is `rows[r][c]`.
///
/// Equality is the derived element-wise `f64` comparison, which means a matrix
/// containing NaN never compares equal to anything (including itself).
#[derive(Clone, Copy, Debug, PartialEq)]
#[must_use]
pub struct Matrix<const D: usize> {
    /// Row-major entries: the outer index selects the row, the inner index the column.
    pub(crate) rows: [[f64; D]; D],
}
13
14impl<const D: usize> Matrix<D> {
15    /// Construct from row-major storage.
16    ///
17    /// # Examples
18    /// ```
19    /// use la_stack::prelude::*;
20    ///
21    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
22    /// assert_eq!(m.get(0, 1), Some(2.0));
23    /// ```
24    #[inline]
25    pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
26        Self { rows }
27    }
28
29    /// All-zeros matrix.
30    ///
31    /// # Examples
32    /// ```
33    /// use la_stack::prelude::*;
34    ///
35    /// let z = Matrix::<2>::zero();
36    /// assert_eq!(z.get(1, 1), Some(0.0));
37    /// ```
38    #[inline]
39    pub const fn zero() -> Self {
40        Self {
41            rows: [[0.0; D]; D],
42        }
43    }
44
45    /// Identity matrix.
46    ///
47    /// # Examples
48    /// ```
49    /// use la_stack::prelude::*;
50    ///
51    /// let i = Matrix::<3>::identity();
52    /// assert_eq!(i.get(0, 0), Some(1.0));
53    /// assert_eq!(i.get(0, 1), Some(0.0));
54    /// assert_eq!(i.get(2, 2), Some(1.0));
55    /// ```
56    #[inline]
57    pub const fn identity() -> Self {
58        let mut m = Self::zero();
59
60        let mut i = 0;
61        while i < D {
62            m.rows[i][i] = 1.0;
63            i += 1;
64        }
65
66        m
67    }
68
69    /// Get an element with bounds checking.
70    ///
71    /// # Examples
72    /// ```
73    /// use la_stack::prelude::*;
74    ///
75    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
76    /// assert_eq!(m.get(1, 0), Some(3.0));
77    /// assert_eq!(m.get(2, 0), None);
78    /// ```
79    #[inline]
80    #[must_use]
81    pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
82        if r < D && c < D {
83            Some(self.rows[r][c])
84        } else {
85            None
86        }
87    }
88
89    /// Set an element with bounds checking.
90    ///
91    /// Returns `true` if the index was in-bounds.
92    ///
93    /// # Examples
94    /// ```
95    /// use la_stack::prelude::*;
96    ///
97    /// let mut m = Matrix::<2>::zero();
98    /// assert!(m.set(0, 1, 2.5));
99    /// assert_eq!(m.get(0, 1), Some(2.5));
100    /// assert!(!m.set(10, 0, 1.0));
101    /// ```
102    #[inline]
103    pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
104        if r < D && c < D {
105            self.rows[r][c] = value;
106            true
107        } else {
108            false
109        }
110    }
111
112    /// Infinity norm (maximum absolute row sum).
113    ///
114    /// # Examples
115    /// ```
116    /// use la_stack::prelude::*;
117    ///
118    /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
119    /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
120    /// ```
121    #[inline]
122    #[must_use]
123    pub fn inf_norm(&self) -> f64 {
124        let mut max_row_sum: f64 = 0.0;
125
126        for row in &self.rows {
127            let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
128            if row_sum > max_row_sum {
129                max_row_sum = row_sum;
130            }
131        }
132
133        max_row_sum
134    }
135
136    /// Compute an LU decomposition with partial pivoting.
137    ///
138    /// # Examples
139    /// ```
140    /// use la_stack::prelude::*;
141    ///
142    /// # fn main() -> Result<(), LaError> {
143    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
144    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
145    ///
146    /// let b = Vector::<2>::new([5.0, 11.0]);
147    /// let x = lu.solve_vec(b)?.into_array();
148    ///
149    /// assert!((x[0] - 1.0).abs() <= 1e-12);
150    /// assert!((x[1] - 2.0).abs() <= 1e-12);
151    /// # Ok(())
152    /// # }
153    /// ```
154    ///
155    /// # Errors
156    /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot
157    /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists).
158    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
159    #[inline]
160    pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
161        Lu::factor(self, tol)
162    }
163
164    /// Compute an LDLT factorization (`A = L D Lᵀ`) without pivoting.
165    ///
166    /// This is intended for symmetric positive definite (SPD) and positive semi-definite (PSD)
167    /// matrices such as Gram matrices.
168    ///
169    /// # Examples
170    /// ```
171    /// use la_stack::prelude::*;
172    ///
173    /// # fn main() -> Result<(), LaError> {
174    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
175    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?;
176    ///
177    /// // det(A) = 8
178    /// assert!((ldlt.det() - 8.0).abs() <= 1e-12);
179    ///
180    /// // Solve A x = b
181    /// let b = Vector::<2>::new([1.0, 2.0]);
182    /// let x = ldlt.solve_vec(b)?.into_array();
183    /// assert!((x[0] - (-0.125)).abs() <= 1e-12);
184    /// assert!((x[1] - 0.75).abs() <= 1e-12);
185    /// # Ok(())
186    /// # }
187    /// ```
188    ///
189    /// # Errors
190    /// Returns [`LaError::Singular`] if, for some step `k`, the required diagonal entry `d = D[k,k]`
191    /// is `<= tol` (non-positive or too small). This treats PSD degeneracy (and indefinite inputs)
192    /// as singular/degenerate.
193    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
194    #[inline]
195    pub fn ldlt(self, tol: f64) -> Result<Ldlt<D>, LaError> {
196        Ldlt::factor(self, tol)
197    }
198
199    /// Determinant computed via LU decomposition.
200    ///
201    /// # Examples
202    /// ```
203    /// use la_stack::prelude::*;
204    ///
205    /// # fn main() -> Result<(), LaError> {
206    /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
207    /// assert!((det - 1.0).abs() <= 1e-12);
208    /// # Ok(())
209    /// # }
210    /// ```
211    ///
212    /// # Errors
213    /// Propagates LU factorization errors (e.g. singular matrices).
214    #[inline]
215    pub fn det(self, tol: f64) -> Result<f64, LaError> {
216        self.lu(tol).map(|lu| lu.det())
217    }
218}
219
220impl<const D: usize> Default for Matrix<D> {
221    #[inline]
222    fn default() -> Self {
223        Self::zero()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Expands to the public-API test suite for `Matrix<$d>`. `paste!` splices the
    // dimension into each test name so that every invocation yields distinct tests.
    // The generated tests index row 1, so `$d` must be at least 2.
    macro_rules! gen_public_api_matrix_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
                    const LAST: usize = $d - 1;

                    let mut rows = [[0.0f64; $d]; $d];
                    rows[0][0] = 1.0;
                    rows[LAST][LAST] = -2.0;

                    let mut m = Matrix::<$d>::from_rows(rows);

                    // Both corners read back what was stored.
                    assert_eq!(m.get(0, 0), Some(1.0));
                    assert_eq!(m.get(LAST, LAST), Some(-2.0));

                    // A row index one past the end is rejected by both accessors.
                    assert_eq!(m.get($d, 0), None);
                    assert!(!m.set($d, 0, 3.0));

                    // An in-bounds write succeeds and is observable.
                    assert!(m.set(0, LAST, 3.0));
                    assert_eq!(m.get(0, LAST), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
                    // `zero()` and `default()` must both have an exactly-zero norm.
                    for m in [Matrix::<$d>::zero(), Matrix::<$d>::default()] {
                        assert_abs_diff_eq!(m.inf_norm(), 0.0, epsilon = 0.0);
                    }
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];

                    // Row 0: absolute row sum equals D (signs must not matter).
                    rows[0] = [-1.0; $d];

                    // Row 1: a strictly smaller absolute row sum.
                    rows[1] = [0.5; $d];

                    let m = Matrix::<$d>::from_rows(rows);
                    assert_abs_diff_eq!(m.inf_norm(), f64::from($d), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
                    let m = Matrix::<$d>::identity();

                    // Ones on the diagonal, zeros elsewhere.
                    for (r, c) in (0..$d).flat_map(|r| (0..$d).map(move |c| (r, c))) {
                        let expected = if r == c { 1.0 } else { 0.0 };
                        assert_abs_diff_eq!(m.get(r, c).unwrap(), expected, epsilon = 0.0);
                    }

                    // det(I) = 1.
                    let det = m.det(DEFAULT_PIVOT_TOL).unwrap();
                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);

                    // Solving I x = b must hand back b.
                    let lu = m.lu(DEFAULT_PIVOT_TOL).unwrap();

                    // Right-hand side 1, 2, 3, ... truncated to the first D values.
                    let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                    let mut b_arr = [0.0f64; $d];
                    for (slot, &value) in b_arr.iter_mut().zip(&values) {
                        *slot = value;
                    }

                    let b = crate::Vector::<$d>::new(b_arr);
                    let x = lu.solve_vec(b).unwrap().into_array();

                    for (solved, wanted) in x.iter().zip(b_arr.iter()) {
                        assert_abs_diff_eq!(*solved, *wanted, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Mirror delaunay-style multi-dimension tests.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);
}