Skip to main content

la_stack/
matrix.rs

1//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::ldlt::Ldlt;
5use crate::lu::Lu;
6
/// A `D×D` matrix of `f64` entries held inline, with no heap allocation.
///
/// Entries are stored row-major: `rows[r][c]` is the element in row `r`, column `c`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[must_use]
pub struct Matrix<const D: usize> {
    pub(crate) rows: [[f64; D]; D],
}
13
14impl<const D: usize> Matrix<D> {
15    /// Construct from row-major storage.
16    ///
17    /// # Examples
18    /// ```
19    /// use la_stack::prelude::*;
20    ///
21    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
22    /// assert_eq!(m.get(0, 1), Some(2.0));
23    /// ```
24    #[inline]
25    pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
26        Self { rows }
27    }
28
29    /// All-zeros matrix.
30    ///
31    /// # Examples
32    /// ```
33    /// use la_stack::prelude::*;
34    ///
35    /// let z = Matrix::<2>::zero();
36    /// assert_eq!(z.get(1, 1), Some(0.0));
37    /// ```
38    #[inline]
39    pub const fn zero() -> Self {
40        Self {
41            rows: [[0.0; D]; D],
42        }
43    }
44
45    /// Identity matrix.
46    ///
47    /// # Examples
48    /// ```
49    /// use la_stack::prelude::*;
50    ///
51    /// let i = Matrix::<3>::identity();
52    /// assert_eq!(i.get(0, 0), Some(1.0));
53    /// assert_eq!(i.get(0, 1), Some(0.0));
54    /// assert_eq!(i.get(2, 2), Some(1.0));
55    /// ```
56    #[inline]
57    pub const fn identity() -> Self {
58        let mut m = Self::zero();
59
60        let mut i = 0;
61        while i < D {
62            m.rows[i][i] = 1.0;
63            i += 1;
64        }
65
66        m
67    }
68
69    /// Get an element with bounds checking.
70    ///
71    /// # Examples
72    /// ```
73    /// use la_stack::prelude::*;
74    ///
75    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
76    /// assert_eq!(m.get(1, 0), Some(3.0));
77    /// assert_eq!(m.get(2, 0), None);
78    /// ```
79    #[inline]
80    #[must_use]
81    pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
82        if r < D && c < D {
83            Some(self.rows[r][c])
84        } else {
85            None
86        }
87    }
88
89    /// Set an element with bounds checking.
90    ///
91    /// Returns `true` if the index was in-bounds.
92    ///
93    /// # Examples
94    /// ```
95    /// use la_stack::prelude::*;
96    ///
97    /// let mut m = Matrix::<2>::zero();
98    /// assert!(m.set(0, 1, 2.5));
99    /// assert_eq!(m.get(0, 1), Some(2.5));
100    /// assert!(!m.set(10, 0, 1.0));
101    /// ```
102    #[inline]
103    pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
104        if r < D && c < D {
105            self.rows[r][c] = value;
106            true
107        } else {
108            false
109        }
110    }
111
112    /// Infinity norm (maximum absolute row sum).
113    ///
114    /// # Examples
115    /// ```
116    /// use la_stack::prelude::*;
117    ///
118    /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
119    /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
120    /// ```
121    #[inline]
122    #[must_use]
123    pub fn inf_norm(&self) -> f64 {
124        let mut max_row_sum: f64 = 0.0;
125
126        for row in &self.rows {
127            let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
128            if row_sum > max_row_sum {
129                max_row_sum = row_sum;
130            }
131        }
132
133        max_row_sum
134    }
135
136    /// Compute an LU decomposition with partial pivoting.
137    ///
138    /// # Examples
139    /// ```
140    /// use la_stack::prelude::*;
141    ///
142    /// # fn main() -> Result<(), LaError> {
143    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
144    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
145    ///
146    /// let b = Vector::<2>::new([5.0, 11.0]);
147    /// let x = lu.solve_vec(b)?.into_array();
148    ///
149    /// assert!((x[0] - 1.0).abs() <= 1e-12);
150    /// assert!((x[1] - 2.0).abs() <= 1e-12);
151    /// # Ok(())
152    /// # }
153    /// ```
154    ///
155    /// # Errors
156    /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot
157    /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists).
158    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
159    #[inline]
160    pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
161        Lu::factor(self, tol)
162    }
163
164    /// Compute an LDLT factorization (`A = L D Lᵀ`) without pivoting.
165    ///
166    /// This is intended for symmetric positive definite (SPD) and positive semi-definite (PSD)
167    /// matrices such as Gram matrices.
168    ///
169    /// # Examples
170    /// ```
171    /// use la_stack::prelude::*;
172    ///
173    /// # fn main() -> Result<(), LaError> {
174    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
175    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?;
176    ///
177    /// // det(A) = 8
178    /// assert!((ldlt.det() - 8.0).abs() <= 1e-12);
179    ///
180    /// // Solve A x = b
181    /// let b = Vector::<2>::new([1.0, 2.0]);
182    /// let x = ldlt.solve_vec(b)?.into_array();
183    /// assert!((x[0] - (-0.125)).abs() <= 1e-12);
184    /// assert!((x[1] - 0.75).abs() <= 1e-12);
185    /// # Ok(())
186    /// # }
187    /// ```
188    ///
189    /// # Errors
190    /// Returns [`LaError::Singular`] if, for some step `k`, the required diagonal entry `d = D[k,k]`
191    /// is `<= tol` (non-positive or too small). This treats PSD degeneracy (and indefinite inputs)
192    /// as singular/degenerate.
193    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
194    #[inline]
195    pub fn ldlt(self, tol: f64) -> Result<Ldlt<D>, LaError> {
196        Ldlt::factor(self, tol)
197    }
198
199    /// Closed-form determinant for dimensions 0–4, bypassing LU factorization.
200    ///
201    /// Returns `Some(det)` for `D` ∈ {0, 1, 2, 3, 4}, `None` for D ≥ 5.
202    /// `D = 0` returns `Some(1.0)` (empty product).
203    /// This is a `const fn` (Rust 1.94+) and uses fused multiply-add (`mul_add`)
204    /// for improved accuracy and performance.
205    ///
206    /// For a determinant that works for any dimension (falling back to LU for D ≥ 5),
207    /// use [`det`](Self::det).
208    ///
209    /// # Examples
210    /// ```
211    /// use la_stack::prelude::*;
212    ///
213    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
214    /// assert!((m.det_direct().unwrap() - (-2.0)).abs() <= 1e-12);
215    ///
216    /// // D = 0 is the empty product.
217    /// assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0));
218    ///
219    /// // D ≥ 5 returns None.
220    /// assert!(Matrix::<5>::identity().det_direct().is_none());
221    /// ```
222    #[inline]
223    #[must_use]
224    pub const fn det_direct(&self) -> Option<f64> {
225        match D {
226            0 => Some(1.0),
227            1 => Some(self.rows[0][0]),
228            2 => {
229                // ad - bc
230                Some(self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0])))
231            }
232            3 => {
233                // Cofactor expansion on first row.
234                let m00 =
235                    self.rows[1][1].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][1]));
236                let m01 =
237                    self.rows[1][0].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][0]));
238                let m02 =
239                    self.rows[1][0].mul_add(self.rows[2][1], -(self.rows[1][1] * self.rows[2][0]));
240                Some(
241                    self.rows[0][0]
242                        .mul_add(m00, (-self.rows[0][1]).mul_add(m01, self.rows[0][2] * m02)),
243                )
244            }
245            4 => {
246                // Cofactor expansion on first row → four 3×3 sub-determinants.
247                // Hoist the 6 unique 2×2 minors from rows 2–3 (each used twice).
248                let r = &self.rows;
249
250                // 2×2 minors: s_ij = r[2][i]*r[3][j] - r[2][j]*r[3][i]
251                let s23 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2])); // cols 2,3
252                let s13 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1])); // cols 1,3
253                let s12 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1])); // cols 1,2
254                let s03 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0])); // cols 0,3
255                let s02 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0])); // cols 0,2
256                let s01 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0])); // cols 0,1
257
258                // 3×3 cofactors via row 1 expansion using hoisted minors.
259                let c00 = r[1][1].mul_add(s23, (-r[1][2]).mul_add(s13, r[1][3] * s12));
260                let c01 = r[1][0].mul_add(s23, (-r[1][2]).mul_add(s03, r[1][3] * s02));
261                let c02 = r[1][0].mul_add(s13, (-r[1][1]).mul_add(s03, r[1][3] * s01));
262                let c03 = r[1][0].mul_add(s12, (-r[1][1]).mul_add(s02, r[1][2] * s01));
263
264                Some(r[0][0].mul_add(
265                    c00,
266                    (-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))),
267                ))
268            }
269            _ => None,
270        }
271    }
272
273    /// Determinant, using closed-form formulas for D ≤ 4 and LU decomposition for D ≥ 5.
274    ///
275    /// For D ∈ {1, 2, 3, 4}, this bypasses LU factorization entirely for a significant
276    /// speedup (see [`det_direct`](Self::det_direct)). The `tol` parameter is only used
277    /// by the LU fallback path for D ≥ 5.
278    ///
279    /// # Examples
280    /// ```
281    /// use la_stack::prelude::*;
282    ///
283    /// # fn main() -> Result<(), LaError> {
284    /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
285    /// assert!((det - 1.0).abs() <= 1e-12);
286    /// # Ok(())
287    /// # }
288    /// ```
289    ///
290    /// # Errors
291    /// Returns [`LaError::NonFinite`] if the result contains NaN or infinity.
292    /// For D ≥ 5, propagates LU factorization errors (e.g. [`LaError::Singular`]).
293    #[inline]
294    pub fn det(self, tol: f64) -> Result<f64, LaError> {
295        if let Some(d) = self.det_direct() {
296            return if d.is_finite() {
297                Ok(d)
298            } else {
299                Err(LaError::NonFinite { pivot_col: 0 })
300            };
301        }
302        self.lu(tol).map(|lu| lu.det())
303    }
304}
305
306impl<const D: usize> Default for Matrix<D> {
307    #[inline]
308    fn default() -> Self {
309        Self::zero()
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;
    use std::hint::black_box;

    /// Generates the public-API smoke tests for an `$n × $n` matrix.
    ///
    /// Every generated test name ends in `_{$n}d`, e.g. `public_api_matrix_inf_norm_max_row_sum_3d`.
    macro_rules! gen_public_api_matrix_tests {
        ($n:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $n d>]() {
                    let last = $n - 1;

                    let mut data = [[0.0_f64; $n]; $n];
                    data[0][0] = 1.0;
                    data[last][last] = -2.0;
                    let mut mat = Matrix::<$n>::from_rows(data);

                    // Corner reads inside the matrix succeed.
                    assert_eq!(mat.get(0, 0), Some(1.0));
                    assert_eq!(mat.get(last, last), Some(-2.0));

                    // One past the last row: reads yield None, writes are rejected.
                    assert_eq!(mat.get($n, 0), None);
                    assert!(!mat.set($n, 0, 3.0));

                    // Writing the top-right corner succeeds and is visible to `get`.
                    assert!(mat.set(0, last, 3.0));
                    assert_eq!(mat.get(0, last), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $n d>]() {
                    for mat in [Matrix::<$n>::zero(), Matrix::<$n>::default()] {
                        assert_abs_diff_eq!(mat.inf_norm(), 0.0, epsilon = 0.0);
                    }
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $n d>]() {
                    let mut data = [[0.0_f64; $n]; $n];
                    // Row 0: absolute row sum is exactly D, the maximum.
                    data[0] = [-1.0; $n];
                    // Row 1: absolute row sum is D / 2, strictly smaller.
                    data[1] = [0.5; $n];

                    let mat = Matrix::<$n>::from_rows(data);
                    assert_abs_diff_eq!(mat.inf_norm(), f64::from($n), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $n d>]() {
                    let eye = Matrix::<$n>::identity();

                    // Ones on the diagonal, zeros everywhere else.
                    let cells = (0_usize..$n).flat_map(|r| (0_usize..$n).map(move |c| (r, c)));
                    for (r, c) in cells {
                        let want = if r == c { 1.0 } else { 0.0 };
                        assert_abs_diff_eq!(eye.get(r, c).unwrap(), want, epsilon = 0.0);
                    }

                    // det(I) = 1.
                    let det = eye.det(DEFAULT_PIVOT_TOL).unwrap();
                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);

                    // Solving I x = b through LU must hand back b itself.
                    let lu = eye.lu(DEFAULT_PIVOT_TOL).unwrap();

                    // RHS = [1, 2, ..., D] for D ≤ 5, zero-padded for any larger D.
                    let mut rhs = [0.0_f64; $n];
                    for (slot, &v) in rhs.iter_mut().zip(&[1.0_f64, 2.0, 3.0, 4.0, 5.0]) {
                        *slot = v;
                    }

                    let x = lu
                        .solve_vec(crate::Vector::<$n>::new(rhs))
                        .unwrap()
                        .into_array();
                    for (got, want) in x.iter().zip(&rhs) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Mirror delaunay-style multi-dimension tests.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);

    // === det_direct tests ===

    #[test]
    fn det_direct_d0_is_one() {
        // The determinant of the empty matrix is the empty product.
        let empty = Matrix::<0>::zero();
        assert_eq!(empty.det_direct(), Some(1.0));
    }

    #[test]
    fn det_direct_d1_returns_element() {
        let single = Matrix::<1>::from_rows([[42.0]]);
        assert_eq!(single.det_direct(), Some(42.0));
    }

    #[test]
    fn det_direct_d2_known_value() {
        // det [[1, 2], [3, 4]] = 1·4 − 2·3 = −2.
        // `black_box` keeps the optimizer from constant-folding the `const fn` call.
        let mat = black_box(Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]));
        let det = mat.det_direct().unwrap();
        assert_abs_diff_eq!(det, -2.0, epsilon = 1e-15);
    }

    #[test]
    fn det_direct_d3_known_value() {
        // Consecutive rows differ by the same vector, so the rows are linearly
        // dependent and det = 0.
        let mat = black_box(Matrix::<3>::from_rows([
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
        ]));
        let det = mat.det_direct().unwrap();
        assert_abs_diff_eq!(det, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn det_direct_d3_nonsingular() {
        // First-row expansion: 2·(3·2 − 1·0) − 1·(0·2 − 1·1) + 0 = 12 + 1 = 13.
        let mat = black_box(Matrix::<3>::from_rows([
            [2.0, 1.0, 0.0],
            [0.0, 3.0, 1.0],
            [1.0, 0.0, 2.0],
        ]));
        let det = mat.det_direct().unwrap();
        assert_abs_diff_eq!(det, 13.0, epsilon = 1e-12);
    }

    #[test]
    fn det_direct_d4_identity() {
        let mat = black_box(Matrix::<4>::identity());
        let det = mat.det_direct().unwrap();
        assert_abs_diff_eq!(det, 1.0, epsilon = 1e-15);
    }

    #[test]
    fn det_direct_d4_known_value() {
        // Diagonal matrix: the determinant is the diagonal product 2·3·5·7 = 210.
        let mut data = [[0.0_f64; 4]; 4];
        for (i, &v) in [2.0, 3.0, 5.0, 7.0].iter().enumerate() {
            data[i][i] = v;
        }
        let mat = black_box(Matrix::<4>::from_rows(data));
        assert_abs_diff_eq!(mat.det_direct().unwrap(), 210.0, epsilon = 1e-12);
    }

    #[test]
    fn det_direct_d5_returns_none() {
        assert!(Matrix::<5>::identity().det_direct().is_none());
    }

    #[test]
    fn det_direct_d8_returns_none() {
        assert!(Matrix::<8>::zero().det_direct().is_none());
    }

    /// Cross-checks the closed-form determinant against the LU determinant on a
    /// well-conditioned (strictly diagonally dominant) `$n × $n` matrix.
    macro_rules! gen_det_direct_agrees_with_lu {
        ($n:literal) => {
            paste! {
                #[test]
                #[allow(clippy::cast_precision_loss)] // r, c, D are tiny integers
                fn [<det_direct_agrees_with_lu_ $n d>]() {
                    let mut data = [[0.0_f64; $n]; $n];
                    for (r, row) in data.iter_mut().enumerate() {
                        for (c, entry) in row.iter_mut().enumerate() {
                            *entry = if r == c {
                                (r as f64) + f64::from($n) + 1.0
                            } else {
                                0.1 / ((r + c + 1) as f64)
                            };
                        }
                    }

                    let mat = Matrix::<$n>::from_rows(data);
                    let direct = mat.det_direct().unwrap();
                    let via_lu = mat.lu(DEFAULT_PIVOT_TOL).unwrap().det();

                    // Mixed absolute/relative tolerance: 1e-12 · (|det| + 1).
                    let eps = via_lu.abs().mul_add(1e-12, 1e-12);
                    assert_abs_diff_eq!(direct, via_lu, epsilon = eps);
                }
            }
        };
    }

    gen_det_direct_agrees_with_lu!(1);
    gen_det_direct_agrees_with_lu!(2);
    gen_det_direct_agrees_with_lu!(3);
    gen_det_direct_agrees_with_lu!(4);

    #[test]
    fn det_direct_identity_all_dims() {
        let dets = [
            Matrix::<1>::identity().det_direct().unwrap(),
            Matrix::<2>::identity().det_direct().unwrap(),
            Matrix::<3>::identity().det_direct().unwrap(),
            Matrix::<4>::identity().det_direct().unwrap(),
        ];
        for det in dets {
            assert_abs_diff_eq!(det, 1.0, epsilon = 0.0);
        }
    }

    #[test]
    fn det_direct_zero_matrix() {
        let dets = [
            Matrix::<2>::zero().det_direct().unwrap(),
            Matrix::<3>::zero().det_direct().unwrap(),
            Matrix::<4>::zero().det_direct().unwrap(),
        ];
        for det in dets {
            assert_abs_diff_eq!(det, 0.0, epsilon = 0.0);
        }
    }

    #[test]
    fn det_returns_nonfinite_error_for_nan_d2() {
        let mat = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]);
        let expected: Result<f64, LaError> = Err(LaError::NonFinite { pivot_col: 0 });
        assert_eq!(mat.det(DEFAULT_PIVOT_TOL), expected);
    }

    #[test]
    fn det_returns_nonfinite_error_for_inf_d3() {
        let mut data = [[0.0_f64; 3]; 3];
        data[0][0] = f64::INFINITY;
        data[1][1] = 1.0;
        data[2][2] = 1.0;
        let mat = Matrix::<3>::from_rows(data);
        let expected: Result<f64, LaError> = Err(LaError::NonFinite { pivot_col: 0 });
        assert_eq!(mat.det(DEFAULT_PIVOT_TOL), expected);
    }

    #[test]
    fn det_direct_is_const_evaluable_d2() {
        // Evaluating inside a `const` item proves `det_direct` is a genuine `const fn`.
        const DET: Option<f64> = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0]]).det_direct();
        assert_eq!(DET, Some(1.0));
    }

    #[test]
    fn det_direct_is_const_evaluable_d3() {
        const DET: Option<f64> =
            Matrix::<3>::from_rows([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 5.0]])
                .det_direct();
        assert_eq!(DET, Some(30.0));
    }
}