la_stack/matrix.rs
1//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::lu::Lu;
5
/// Fixed-size square `D×D` matrix of `f64`, stored inline (no heap allocation).
///
/// Elements live in `rows` in row-major order: `rows[r][c]` is row `r`,
/// column `c`. Equality is element-wise IEEE comparison, so a matrix that
/// contains NaN never compares equal to itself.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const D: usize> {
    /// Row-major element storage.
    pub(crate) rows: [[f64; D]; D],
}
12
13impl<const D: usize> Matrix<D> {
14 /// Construct from row-major storage.
15 ///
16 /// # Examples
17 /// ```
18 /// use la_stack::prelude::*;
19 ///
20 /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
21 /// assert_eq!(m.get(0, 1), Some(2.0));
22 /// ```
23 #[inline]
24 pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
25 Self { rows }
26 }
27
28 /// All-zeros matrix.
29 ///
30 /// # Examples
31 /// ```
32 /// use la_stack::prelude::*;
33 ///
34 /// let z = Matrix::<2>::zero();
35 /// assert_eq!(z.get(1, 1), Some(0.0));
36 /// ```
37 #[inline]
38 pub const fn zero() -> Self {
39 Self {
40 rows: [[0.0; D]; D],
41 }
42 }
43
44 /// Identity matrix.
45 ///
46 /// # Examples
47 /// ```
48 /// use la_stack::prelude::*;
49 ///
50 /// let i = Matrix::<3>::identity();
51 /// assert_eq!(i.get(0, 0), Some(1.0));
52 /// assert_eq!(i.get(0, 1), Some(0.0));
53 /// assert_eq!(i.get(2, 2), Some(1.0));
54 /// ```
55 #[inline]
56 pub const fn identity() -> Self {
57 let mut m = Self::zero();
58
59 let mut i = 0;
60 while i < D {
61 m.rows[i][i] = 1.0;
62 i += 1;
63 }
64
65 m
66 }
67
68 /// Get an element with bounds checking.
69 ///
70 /// # Examples
71 /// ```
72 /// use la_stack::prelude::*;
73 ///
74 /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
75 /// assert_eq!(m.get(1, 0), Some(3.0));
76 /// assert_eq!(m.get(2, 0), None);
77 /// ```
78 #[inline]
79 #[must_use]
80 pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
81 if r < D && c < D {
82 Some(self.rows[r][c])
83 } else {
84 None
85 }
86 }
87
88 /// Set an element with bounds checking.
89 ///
90 /// Returns `true` if the index was in-bounds.
91 ///
92 /// # Examples
93 /// ```
94 /// use la_stack::prelude::*;
95 ///
96 /// let mut m = Matrix::<2>::zero();
97 /// assert!(m.set(0, 1, 2.5));
98 /// assert_eq!(m.get(0, 1), Some(2.5));
99 /// assert!(!m.set(10, 0, 1.0));
100 /// ```
101 #[inline]
102 pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
103 if r < D && c < D {
104 self.rows[r][c] = value;
105 true
106 } else {
107 false
108 }
109 }
110
111 /// Infinity norm (maximum absolute row sum).
112 ///
113 /// # Examples
114 /// ```
115 /// use la_stack::prelude::*;
116 ///
117 /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
118 /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
119 /// ```
120 #[inline]
121 #[must_use]
122 pub fn inf_norm(&self) -> f64 {
123 let mut max_row_sum: f64 = 0.0;
124
125 for row in &self.rows {
126 let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
127 if row_sum > max_row_sum {
128 max_row_sum = row_sum;
129 }
130 }
131
132 max_row_sum
133 }
134
135 /// Compute an LU decomposition with partial pivoting.
136 ///
137 /// # Examples
138 /// ```
139 /// use la_stack::prelude::*;
140 ///
141 /// # fn main() -> Result<(), LaError> {
142 /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
143 /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
144 ///
145 /// let b = Vector::<2>::new([5.0, 11.0]);
146 /// let x = lu.solve_vec(b)?.into_array();
147 ///
148 /// assert!((x[0] - 1.0).abs() <= 1e-12);
149 /// assert!((x[1] - 2.0).abs() <= 1e-12);
150 /// # Ok(())
151 /// # }
152 /// ```
153 ///
154 /// # Errors
155 /// Returns [`LaError::Singular`] if no suitable pivot (|pivot| > `tol`) exists for a column.
156 /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
157 #[inline]
158 pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
159 Lu::factor(self, tol)
160 }
161
162 /// Determinant computed via LU decomposition.
163 ///
164 /// # Examples
165 /// ```
166 /// use la_stack::prelude::*;
167 ///
168 /// # fn main() -> Result<(), LaError> {
169 /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
170 /// assert!((det - 1.0).abs() <= 1e-12);
171 /// # Ok(())
172 /// # }
173 /// ```
174 ///
175 /// # Errors
176 /// Propagates LU factorization errors (e.g. singular matrices).
177 #[inline]
178 pub fn det(self, tol: f64) -> Result<f64, LaError> {
179 self.lu(tol).map(|lu| lu.det())
180 }
181}
182
183impl<const D: usize> Default for Matrix<D> {
184 #[inline]
185 fn default() -> Self {
186 Self::zero()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    /// Generates the public-API matrix tests for one dimension `$d`.
    ///
    /// `$d` must be a literal in `2..=5`. The inf-norm test writes to row 1,
    /// and the identity-solve test fills its right-hand side from a
    /// five-element table.
    macro_rules! gen_public_api_matrix_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];
                    rows[0][0] = 1.0;
                    rows[$d - 1][$d - 1] = -2.0;

                    let mut m = Matrix::<$d>::from_rows(rows);

                    assert_eq!(m.get(0, 0), Some(1.0));
                    assert_eq!(m.get($d - 1, $d - 1), Some(-2.0));

                    // Out-of-bounds reads are None, whichever index is out of range.
                    assert_eq!(m.get($d, 0), None);
                    assert_eq!(m.get(0, $d), None);
                    assert_eq!(m.get($d, $d), None);

                    // Out-of-bounds writes fail and leave the matrix untouched.
                    let before = m;
                    assert!(!m.set($d, 0, 3.0));
                    assert!(!m.set(0, $d, 3.0));
                    assert_eq!(m, before);

                    // In-bounds set works.
                    assert!(m.set(0, $d - 1, 3.0));
                    assert_eq!(m.get(0, $d - 1), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
                    let z = Matrix::<$d>::zero();
                    assert_abs_diff_eq!(z.inf_norm(), 0.0, epsilon = 0.0);

                    let d = Matrix::<$d>::default();
                    assert_abs_diff_eq!(d.inf_norm(), 0.0, epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];

                    // Row 0 has absolute row sum = D.
                    for c in 0..$d {
                        rows[0][c] = -1.0;
                    }

                    // Row 1 has smaller absolute row sum.
                    for c in 0..$d {
                        rows[1][c] = 0.5;
                    }

                    let m = Matrix::<$d>::from_rows(rows);
                    assert_abs_diff_eq!(m.inf_norm(), f64::from($d), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
                    let m = Matrix::<$d>::identity();

                    // Identity has ones on diag and zeros off diag.
                    for r in 0..$d {
                        for c in 0..$d {
                            let expected = if r == c { 1.0 } else { 0.0 };
                            assert_abs_diff_eq!(m.get(r, c).unwrap(), expected, epsilon = 0.0);
                        }
                    }

                    // Determinant is 1.
                    let det = m.det(DEFAULT_PIVOT_TOL).unwrap();
                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);

                    // LU solve on identity returns the RHS.
                    let lu = m.lu(DEFAULT_PIVOT_TOL).unwrap();

                    // RHS is [1, 2, ..., D]; the table covers D <= 5.
                    let b_arr = {
                        let mut arr = [0.0f64; $d];
                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
                            *dst = *src;
                        }
                        arr
                    };

                    let b = crate::Vector::<$d>::new(b_arr);
                    let x = lu.solve_vec(b).unwrap().into_array();

                    for (x_i, b_i) in x.iter().zip(b_arr.iter()) {
                        assert_abs_diff_eq!(*x_i, *b_i, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Mirror delaunay-style multi-dimension tests.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);
}