la_stack/
matrix.rs

1//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::lu::Lu;
5
/// A square `D×D` matrix of `f64` entries held inline in a fixed-size array.
///
/// The entries are a nested array of `D` rows with `D` columns each, so a
/// matrix is a plain `Copy` value. Equality is the derived element-wise `f64`
/// comparison; consequently a matrix containing NaN never compares equal,
/// not even to itself.
#[derive(Debug, Clone, Copy, PartialEq)]
#[must_use]
pub struct Matrix<const D: usize> {
    /// Row-major entries: `rows[r][c]` is the element at row `r`, column `c`.
    pub(crate) rows: [[f64; D]; D],
}
12
13impl<const D: usize> Matrix<D> {
14    /// Construct from row-major storage.
15    ///
16    /// # Examples
17    /// ```
18    /// #![allow(unused_imports)]
19    /// use la_stack::prelude::*;
20    ///
21    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
22    /// assert_eq!(m.get(0, 1), Some(2.0));
23    /// ```
24    #[inline]
25    pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
26        Self { rows }
27    }
28
29    /// All-zeros matrix.
30    ///
31    /// # Examples
32    /// ```
33    /// #![allow(unused_imports)]
34    /// use la_stack::prelude::*;
35    ///
36    /// let z = Matrix::<2>::zero();
37    /// assert_eq!(z.get(1, 1), Some(0.0));
38    /// ```
39    #[inline]
40    pub const fn zero() -> Self {
41        Self {
42            rows: [[0.0; D]; D],
43        }
44    }
45
46    /// Identity matrix.
47    ///
48    /// # Examples
49    /// ```
50    /// #![allow(unused_imports)]
51    /// use la_stack::prelude::*;
52    ///
53    /// let i = Matrix::<3>::identity();
54    /// assert_eq!(i.get(0, 0), Some(1.0));
55    /// assert_eq!(i.get(0, 1), Some(0.0));
56    /// assert_eq!(i.get(2, 2), Some(1.0));
57    /// ```
58    #[inline]
59    pub const fn identity() -> Self {
60        let mut m = Self::zero();
61
62        let mut i = 0;
63        while i < D {
64            m.rows[i][i] = 1.0;
65            i += 1;
66        }
67
68        m
69    }
70
71    /// Get an element with bounds checking.
72    ///
73    /// # Examples
74    /// ```
75    /// #![allow(unused_imports)]
76    /// use la_stack::prelude::*;
77    ///
78    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
79    /// assert_eq!(m.get(1, 0), Some(3.0));
80    /// assert_eq!(m.get(2, 0), None);
81    /// ```
82    #[inline]
83    #[must_use]
84    pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
85        if r < D && c < D {
86            Some(self.rows[r][c])
87        } else {
88            None
89        }
90    }
91
92    /// Set an element with bounds checking.
93    ///
94    /// Returns `true` if the index was in-bounds.
95    ///
96    /// # Examples
97    /// ```
98    /// #![allow(unused_imports)]
99    /// use la_stack::prelude::*;
100    ///
101    /// let mut m = Matrix::<2>::zero();
102    /// assert!(m.set(0, 1, 2.5));
103    /// assert_eq!(m.get(0, 1), Some(2.5));
104    /// assert!(!m.set(10, 0, 1.0));
105    /// ```
106    #[inline]
107    pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
108        if r < D && c < D {
109            self.rows[r][c] = value;
110            true
111        } else {
112            false
113        }
114    }
115
116    /// Infinity norm (maximum absolute row sum).
117    ///
118    /// # Examples
119    /// ```
120    /// #![allow(unused_imports)]
121    /// use la_stack::prelude::*;
122    ///
123    /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
124    /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
125    /// ```
126    #[inline]
127    #[must_use]
128    pub fn inf_norm(&self) -> f64 {
129        let mut max_row_sum: f64 = 0.0;
130
131        for row in &self.rows {
132            let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
133            if row_sum > max_row_sum {
134                max_row_sum = row_sum;
135            }
136        }
137
138        max_row_sum
139    }
140
141    /// Compute an LU decomposition with partial pivoting.
142    ///
143    /// # Examples
144    /// ```
145    /// #![allow(unused_imports)]
146    /// use la_stack::prelude::*;
147    ///
148    /// # fn main() -> Result<(), LaError> {
149    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
150    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
151    ///
152    /// let b = Vector::<2>::new([5.0, 11.0]);
153    /// let x = lu.solve_vec(b)?.into_array();
154    ///
155    /// assert!((x[0] - 1.0).abs() <= 1e-12);
156    /// assert!((x[1] - 2.0).abs() <= 1e-12);
157    /// # Ok(())
158    /// # }
159    /// ```
160    ///
161    /// # Errors
162    /// Returns [`LaError::Singular`] if no suitable pivot (|pivot| > `tol`) exists for a column.
163    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
164    #[inline]
165    pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
166        Lu::factor(self, tol)
167    }
168
169    /// Determinant computed via LU decomposition.
170    ///
171    /// # Examples
172    /// ```
173    /// #![allow(unused_imports)]
174    /// use la_stack::prelude::*;
175    ///
176    /// # fn main() -> Result<(), LaError> {
177    /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
178    /// assert!((det - 1.0).abs() <= 1e-12);
179    /// # Ok(())
180    /// # }
181    /// ```
182    ///
183    /// # Errors
184    /// Propagates LU factorization errors (e.g. singular matrices).
185    #[inline]
186    pub fn det(self, tol: f64) -> Result<f64, LaError> {
187        self.lu(tol).map(|lu| lu.det())
188    }
189}
190
191impl<const D: usize> Default for Matrix<D> {
192    #[inline]
193    fn default() -> Self {
194        Self::zero()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Expands to the public-API test suite for a `$d`×`$d` matrix; `paste!`
    /// splices the dimension into each generated test name.
    macro_rules! gen_public_api_matrix_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
                    const LAST: usize = $d - 1;

                    let mut entries = [[0.0f64; $d]; $d];
                    entries[0][0] = 1.0;
                    entries[LAST][LAST] = -2.0;

                    let mut m = Matrix::<$d>::from_rows(entries);

                    // The two corners written above read back unchanged.
                    assert_eq!(m.get(0, 0), Some(1.0));
                    assert_eq!(m.get(LAST, LAST), Some(-2.0));

                    // Reading one past the last row yields None.
                    assert_eq!(m.get($d, 0), None);

                    // Writing one past the last row is rejected.
                    assert!(!m.set($d, 0, 3.0));

                    // An in-bounds write succeeds and is observable.
                    assert!(m.set(0, LAST, 3.0));
                    assert_eq!(m.get(0, LAST), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
                    let from_zero = Matrix::<$d>::zero();
                    assert_abs_diff_eq!(from_zero.inf_norm(), 0.0, epsilon = 0.0);

                    let from_default = Matrix::<$d>::default();
                    assert_abs_diff_eq!(from_default.inf_norm(), 0.0, epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
                    let mut entries = [[0.0f64; $d]; $d];

                    // Row 0: every entry is -1, so its absolute row sum is D.
                    entries[0] = [-1.0; $d];

                    // Row 1: every entry is 0.5, a smaller absolute row sum.
                    entries[1] = [0.5; $d];

                    let m = Matrix::<$d>::from_rows(entries);
                    assert_abs_diff_eq!(m.inf_norm(), f64::from($d), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
                    let eye = Matrix::<$d>::identity();

                    // Ones on the diagonal, zeros everywhere else.
                    for r in 0..$d {
                        for c in 0..$d {
                            let expected = f64::from(u8::from(r == c));
                            assert_abs_diff_eq!(eye.get(r, c).unwrap(), expected, epsilon = 0.0);
                        }
                    }

                    // det(I) == 1.
                    let det = eye.det(DEFAULT_PIVOT_TOL).unwrap();
                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);

                    // Solving I·x = b must hand back b.
                    let lu = eye.lu(DEFAULT_PIVOT_TOL).unwrap();

                    // First D values of 1..=5 form the right-hand side.
                    let mut rhs = [0.0f64; $d];
                    for (slot, value) in rhs.iter_mut().zip([1.0f64, 2.0, 3.0, 4.0, 5.0]) {
                        *slot = value;
                    }

                    let x = lu
                        .solve_vec(crate::Vector::<$d>::new(rhs))
                        .unwrap()
                        .into_array();

                    for (got, want) in x.iter().zip(rhs.iter()) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Mirror delaunay-style multi-dimension tests.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);
}