la_stack/
matrix.rs

//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::lu::Lu;
5
/// A `D×D` square matrix of `f64` whose entries are stored inline in the value itself.
///
/// Storage is row-major: `rows[r][c]` holds the element at row `r`, column `c`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[must_use]
pub struct Matrix<const D: usize> {
    /// Row-major element storage (`rows[r][c]` is row `r`, column `c`).
    pub(crate) rows: [[f64; D]; D],
}
12
13impl<const D: usize> Matrix<D> {
14    /// Construct from row-major storage.
15    ///
16    /// # Examples
17    /// ```
18    /// use la_stack::prelude::*;
19    ///
20    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
21    /// assert_eq!(m.get(0, 1), Some(2.0));
22    /// ```
23    #[inline]
24    pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
25        Self { rows }
26    }
27
28    /// All-zeros matrix.
29    ///
30    /// # Examples
31    /// ```
32    /// use la_stack::prelude::*;
33    ///
34    /// let z = Matrix::<2>::zero();
35    /// assert_eq!(z.get(1, 1), Some(0.0));
36    /// ```
37    #[inline]
38    pub const fn zero() -> Self {
39        Self {
40            rows: [[0.0; D]; D],
41        }
42    }
43
44    /// Identity matrix.
45    ///
46    /// # Examples
47    /// ```
48    /// use la_stack::prelude::*;
49    ///
50    /// let i = Matrix::<3>::identity();
51    /// assert_eq!(i.get(0, 0), Some(1.0));
52    /// assert_eq!(i.get(0, 1), Some(0.0));
53    /// assert_eq!(i.get(2, 2), Some(1.0));
54    /// ```
55    #[inline]
56    pub const fn identity() -> Self {
57        let mut m = Self::zero();
58
59        let mut i = 0;
60        while i < D {
61            m.rows[i][i] = 1.0;
62            i += 1;
63        }
64
65        m
66    }
67
68    /// Get an element with bounds checking.
69    ///
70    /// # Examples
71    /// ```
72    /// use la_stack::prelude::*;
73    ///
74    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
75    /// assert_eq!(m.get(1, 0), Some(3.0));
76    /// assert_eq!(m.get(2, 0), None);
77    /// ```
78    #[inline]
79    #[must_use]
80    pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
81        if r < D && c < D {
82            Some(self.rows[r][c])
83        } else {
84            None
85        }
86    }
87
88    /// Set an element with bounds checking.
89    ///
90    /// Returns `true` if the index was in-bounds.
91    ///
92    /// # Examples
93    /// ```
94    /// use la_stack::prelude::*;
95    ///
96    /// let mut m = Matrix::<2>::zero();
97    /// assert!(m.set(0, 1, 2.5));
98    /// assert_eq!(m.get(0, 1), Some(2.5));
99    /// assert!(!m.set(10, 0, 1.0));
100    /// ```
101    #[inline]
102    pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
103        if r < D && c < D {
104            self.rows[r][c] = value;
105            true
106        } else {
107            false
108        }
109    }
110
111    /// Infinity norm (maximum absolute row sum).
112    ///
113    /// # Examples
114    /// ```
115    /// use la_stack::prelude::*;
116    ///
117    /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
118    /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
119    /// ```
120    #[inline]
121    #[must_use]
122    pub fn inf_norm(&self) -> f64 {
123        let mut max_row_sum: f64 = 0.0;
124
125        for row in &self.rows {
126            let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
127            if row_sum > max_row_sum {
128                max_row_sum = row_sum;
129            }
130        }
131
132        max_row_sum
133    }
134
135    /// Compute an LU decomposition with partial pivoting.
136    ///
137    /// # Examples
138    /// ```
139    /// use la_stack::prelude::*;
140    ///
141    /// # fn main() -> Result<(), LaError> {
142    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
143    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
144    ///
145    /// let b = Vector::<2>::new([5.0, 11.0]);
146    /// let x = lu.solve_vec(b)?.into_array();
147    ///
148    /// assert!((x[0] - 1.0).abs() <= 1e-12);
149    /// assert!((x[1] - 2.0).abs() <= 1e-12);
150    /// # Ok(())
151    /// # }
152    /// ```
153    ///
154    /// # Errors
155    /// Returns [`LaError::Singular`] if no suitable pivot (|pivot| > `tol`) exists for a column.
156    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
157    #[inline]
158    pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
159        Lu::factor(self, tol)
160    }
161
162    /// Determinant computed via LU decomposition.
163    ///
164    /// # Examples
165    /// ```
166    /// use la_stack::prelude::*;
167    ///
168    /// # fn main() -> Result<(), LaError> {
169    /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
170    /// assert!((det - 1.0).abs() <= 1e-12);
171    /// # Ok(())
172    /// # }
173    /// ```
174    ///
175    /// # Errors
176    /// Propagates LU factorization errors (e.g. singular matrices).
177    #[inline]
178    pub fn det(self, tol: f64) -> Result<f64, LaError> {
179        self.lu(tol).map(|lu| lu.det())
180    }
181}
182
183impl<const D: usize> Default for Matrix<D> {
184    #[inline]
185    fn default() -> Self {
186        Self::zero()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Generates the public-API matrix test suite for a single dimension `$d`.
    ///
    /// `$d` must be in `2..=5`: the tests write to row 1, and the identity-solve test draws
    /// its right-hand side from a fixed five-element list.
    macro_rules! gen_public_api_matrix_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];
                    rows[0][0] = 1.0;
                    rows[$d - 1][$d - 1] = -2.0;

                    let mut m = Matrix::<$d>::from_rows(rows);

                    assert_eq!(m.get(0, 0), Some(1.0));
                    assert_eq!(m.get($d - 1, $d - 1), Some(-2.0));

                    // Out-of-bounds reads are None, whichever index is out of range.
                    assert_eq!(m.get($d, 0), None);
                    assert_eq!(m.get(0, $d), None);
                    assert_eq!(m.get($d, $d), None);

                    // Out-of-bounds writes fail (row or column) and leave the matrix untouched.
                    let before = m;
                    assert!(!m.set($d, 0, 3.0));
                    assert!(!m.set(0, $d, 3.0));
                    assert_eq!(m, before);

                    // In-bounds set works.
                    assert!(m.set(0, $d - 1, 3.0));
                    assert_eq!(m.get(0, $d - 1), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
                    let z = Matrix::<$d>::zero();
                    assert_abs_diff_eq!(z.inf_norm(), 0.0, epsilon = 0.0);

                    let d = Matrix::<$d>::default();
                    assert_abs_diff_eq!(d.inf_norm(), 0.0, epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];

                    // Row 0 has absolute row sum = D.
                    rows[0] = [-1.0; $d];

                    // Row 1 has a smaller absolute row sum (D / 2).
                    rows[1] = [0.5; $d];

                    let m = Matrix::<$d>::from_rows(rows);
                    assert_abs_diff_eq!(m.inf_norm(), f64::from($d), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
                    let m = Matrix::<$d>::identity();

                    // Identity has ones on diag and zeros off diag.
                    for r in 0..$d {
                        for c in 0..$d {
                            let expected = if r == c { 1.0 } else { 0.0 };
                            assert_abs_diff_eq!(m.get(r, c).unwrap(), expected, epsilon = 0.0);
                        }
                    }

                    // Determinant is 1.
                    let det = m.det(DEFAULT_PIVOT_TOL).unwrap();
                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);

                    // LU solve on identity returns the RHS.
                    let lu = m.lu(DEFAULT_PIVOT_TOL).unwrap();

                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };

                    let b = crate::Vector::<$d>::new(b_arr);
                    let x = lu.solve_vec(b).unwrap().into_array();

                    for (x_i, b_i) in x.iter().zip(b_arr.iter()) {
                        assert_abs_diff_eq!(*x_i, *b_i, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Mirror delaunay-style multi-dimension tests.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);
}