la_stack/matrix.rs
1//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::ldlt::Ldlt;
5use crate::lu::Lu;
6
/// Square `D×D` matrix of `f64` entries, held entirely inline (no heap allocation).
///
/// Entries are stored row-major: `rows[r][c]` is the value at row `r`, column `c`.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const D: usize> {
    /// Row-major entries; `rows[r][c]` is row `r`, column `c`.
    pub(crate) rows: [[f64; D]; D],
}
13
14impl<const D: usize> Matrix<D> {
15 /// Construct from row-major storage.
16 ///
17 /// # Examples
18 /// ```
19 /// use la_stack::prelude::*;
20 ///
21 /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
22 /// assert_eq!(m.get(0, 1), Some(2.0));
23 /// ```
24 #[inline]
25 pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
26 Self { rows }
27 }
28
29 /// All-zeros matrix.
30 ///
31 /// # Examples
32 /// ```
33 /// use la_stack::prelude::*;
34 ///
35 /// let z = Matrix::<2>::zero();
36 /// assert_eq!(z.get(1, 1), Some(0.0));
37 /// ```
38 #[inline]
39 pub const fn zero() -> Self {
40 Self {
41 rows: [[0.0; D]; D],
42 }
43 }
44
45 /// Identity matrix.
46 ///
47 /// # Examples
48 /// ```
49 /// use la_stack::prelude::*;
50 ///
51 /// let i = Matrix::<3>::identity();
52 /// assert_eq!(i.get(0, 0), Some(1.0));
53 /// assert_eq!(i.get(0, 1), Some(0.0));
54 /// assert_eq!(i.get(2, 2), Some(1.0));
55 /// ```
56 #[inline]
57 pub const fn identity() -> Self {
58 let mut m = Self::zero();
59
60 let mut i = 0;
61 while i < D {
62 m.rows[i][i] = 1.0;
63 i += 1;
64 }
65
66 m
67 }
68
69 /// Get an element with bounds checking.
70 ///
71 /// # Examples
72 /// ```
73 /// use la_stack::prelude::*;
74 ///
75 /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
76 /// assert_eq!(m.get(1, 0), Some(3.0));
77 /// assert_eq!(m.get(2, 0), None);
78 /// ```
79 #[inline]
80 #[must_use]
81 pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
82 if r < D && c < D {
83 Some(self.rows[r][c])
84 } else {
85 None
86 }
87 }
88
89 /// Set an element with bounds checking.
90 ///
91 /// Returns `true` if the index was in-bounds.
92 ///
93 /// # Examples
94 /// ```
95 /// use la_stack::prelude::*;
96 ///
97 /// let mut m = Matrix::<2>::zero();
98 /// assert!(m.set(0, 1, 2.5));
99 /// assert_eq!(m.get(0, 1), Some(2.5));
100 /// assert!(!m.set(10, 0, 1.0));
101 /// ```
102 #[inline]
103 pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
104 if r < D && c < D {
105 self.rows[r][c] = value;
106 true
107 } else {
108 false
109 }
110 }
111
112 /// Infinity norm (maximum absolute row sum).
113 ///
114 /// # Examples
115 /// ```
116 /// use la_stack::prelude::*;
117 ///
118 /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
119 /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
120 /// ```
121 #[inline]
122 #[must_use]
123 pub fn inf_norm(&self) -> f64 {
124 let mut max_row_sum: f64 = 0.0;
125
126 for row in &self.rows {
127 let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
128 if row_sum > max_row_sum {
129 max_row_sum = row_sum;
130 }
131 }
132
133 max_row_sum
134 }
135
136 /// Compute an LU decomposition with partial pivoting.
137 ///
138 /// # Examples
139 /// ```
140 /// use la_stack::prelude::*;
141 ///
142 /// # fn main() -> Result<(), LaError> {
143 /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
144 /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
145 ///
146 /// let b = Vector::<2>::new([5.0, 11.0]);
147 /// let x = lu.solve_vec(b)?.into_array();
148 ///
149 /// assert!((x[0] - 1.0).abs() <= 1e-12);
150 /// assert!((x[1] - 2.0).abs() <= 1e-12);
151 /// # Ok(())
152 /// # }
153 /// ```
154 ///
155 /// # Errors
156 /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot
157 /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists).
158 /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
159 #[inline]
160 pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
161 Lu::factor(self, tol)
162 }
163
164 /// Compute an LDLT factorization (`A = L D Lᵀ`) without pivoting.
165 ///
166 /// This is intended for symmetric positive definite (SPD) and positive semi-definite (PSD)
167 /// matrices such as Gram matrices.
168 ///
169 /// # Examples
170 /// ```
171 /// use la_stack::prelude::*;
172 ///
173 /// # fn main() -> Result<(), LaError> {
174 /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
175 /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?;
176 ///
177 /// // det(A) = 8
178 /// assert!((ldlt.det() - 8.0).abs() <= 1e-12);
179 ///
180 /// // Solve A x = b
181 /// let b = Vector::<2>::new([1.0, 2.0]);
182 /// let x = ldlt.solve_vec(b)?.into_array();
183 /// assert!((x[0] - (-0.125)).abs() <= 1e-12);
184 /// assert!((x[1] - 0.75).abs() <= 1e-12);
185 /// # Ok(())
186 /// # }
187 /// ```
188 ///
189 /// # Errors
190 /// Returns [`LaError::Singular`] if, for some step `k`, the required diagonal entry `d = D[k,k]`
191 /// is `<= tol` (non-positive or too small). This treats PSD degeneracy (and indefinite inputs)
192 /// as singular/degenerate.
193 /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
194 #[inline]
195 pub fn ldlt(self, tol: f64) -> Result<Ldlt<D>, LaError> {
196 Ldlt::factor(self, tol)
197 }
198
199 /// Determinant computed via LU decomposition.
200 ///
201 /// # Examples
202 /// ```
203 /// use la_stack::prelude::*;
204 ///
205 /// # fn main() -> Result<(), LaError> {
206 /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
207 /// assert!((det - 1.0).abs() <= 1e-12);
208 /// # Ok(())
209 /// # }
210 /// ```
211 ///
212 /// # Errors
213 /// Propagates LU factorization errors (e.g. singular matrices).
214 #[inline]
215 pub fn det(self, tol: f64) -> Result<f64, LaError> {
216 self.lu(tol).map(|lu| lu.det())
217 }
218}
219
220impl<const D: usize> Default for Matrix<D> {
221 #[inline]
222 fn default() -> Self {
223 Self::zero()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Expands to the public-API smoke tests for one fixed dimension `$d`.
    macro_rules! gen_public_api_matrix_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
                    const LAST: usize = $d - 1;

                    let mut storage = [[0.0f64; $d]; $d];
                    storage[0][0] = 1.0;
                    storage[LAST][LAST] = -2.0;
                    let mut mat = Matrix::<$d>::from_rows(storage);

                    // The two corners written above read back unchanged.
                    assert_eq!(mat.get(0, 0), Some(1.0));
                    assert_eq!(mat.get(LAST, LAST), Some(-2.0));

                    // Row index D is one past the end: reads give None, writes are rejected.
                    assert_eq!(mat.get($d, 0), None);
                    assert!(!mat.set($d, 0, 3.0));

                    // A valid write reports success and is visible afterwards.
                    assert!(mat.set(0, LAST, 3.0));
                    assert_eq!(mat.get(0, LAST), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
                    for mat in [Matrix::<$d>::zero(), Matrix::<$d>::default()] {
                        assert_abs_diff_eq!(mat.inf_norm(), 0.0, epsilon = 0.0);
                    }
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
                    let mut storage = [[0.0f64; $d]; $d];
                    // Row 0: every |entry| is 1, so its absolute sum is D (the maximum).
                    storage[0] = [-1.0; $d];
                    // Row 1: every |entry| is 0.5, a strictly smaller total.
                    storage[1] = [0.5; $d];

                    let mat = Matrix::<$d>::from_rows(storage);
                    assert_abs_diff_eq!(mat.inf_norm(), f64::from($d), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
                    let eye = Matrix::<$d>::identity();

                    // Ones on the diagonal, zeros everywhere else.
                    for (r, c) in (0..$d).flat_map(|r| (0..$d).map(move |c| (r, c))) {
                        let want = if r == c { 1.0 } else { 0.0 };
                        assert_abs_diff_eq!(eye.get(r, c).unwrap(), want, epsilon = 0.0);
                    }

                    // det(I) == 1.
                    let det = eye.det(DEFAULT_PIVOT_TOL).unwrap();
                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);

                    // Solving I x = b must hand b straight back.
                    let lu = eye.lu(DEFAULT_PIVOT_TOL).unwrap();

                    // RHS is [1, 2, ...]; only the first min(D, 5) slots are filled, the rest stay 0.
                    let mut rhs = [0.0f64; $d];
                    let seed = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                    let filled = rhs.len().min(seed.len());
                    rhs[..filled].copy_from_slice(&seed[..filled]);

                    let solution = lu
                        .solve_vec(crate::Vector::<$d>::new(rhs))
                        .unwrap()
                        .into_array();

                    for (got, want) in solution.into_iter().zip(rhs) {
                        assert_abs_diff_eq!(got, want, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Mirror delaunay-style multi-dimension tests.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);
}