la_stack/matrix.rs
1//! Fixed-size, stack-allocated square matrices.
2
3use crate::LaError;
4use crate::lu::Lu;
5
/// Fixed-size square matrix with `D` rows and `D` columns, stored inline
/// as a nested array (no heap allocation).
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const D: usize> {
    /// Row-major entries: `rows[r][c]` is the element at row `r`, column `c`.
    pub(crate) rows: [[f64; D]; D],
}
12
13impl<const D: usize> Matrix<D> {
14 /// Construct from row-major storage.
15 ///
16 /// # Examples
17 /// ```
18 /// #![allow(unused_imports)]
19 /// use la_stack::prelude::*;
20 ///
21 /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
22 /// assert_eq!(m.get(0, 1), Some(2.0));
23 /// ```
24 #[inline]
25 pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
26 Self { rows }
27 }
28
29 /// All-zeros matrix.
30 ///
31 /// # Examples
32 /// ```
33 /// #![allow(unused_imports)]
34 /// use la_stack::prelude::*;
35 ///
36 /// let z = Matrix::<2>::zero();
37 /// assert_eq!(z.get(1, 1), Some(0.0));
38 /// ```
39 #[inline]
40 pub const fn zero() -> Self {
41 Self {
42 rows: [[0.0; D]; D],
43 }
44 }
45
46 /// Identity matrix.
47 ///
48 /// # Examples
49 /// ```
50 /// #![allow(unused_imports)]
51 /// use la_stack::prelude::*;
52 ///
53 /// let i = Matrix::<3>::identity();
54 /// assert_eq!(i.get(0, 0), Some(1.0));
55 /// assert_eq!(i.get(0, 1), Some(0.0));
56 /// assert_eq!(i.get(2, 2), Some(1.0));
57 /// ```
58 #[inline]
59 pub const fn identity() -> Self {
60 let mut m = Self::zero();
61
62 let mut i = 0;
63 while i < D {
64 m.rows[i][i] = 1.0;
65 i += 1;
66 }
67
68 m
69 }
70
71 /// Get an element with bounds checking.
72 ///
73 /// # Examples
74 /// ```
75 /// #![allow(unused_imports)]
76 /// use la_stack::prelude::*;
77 ///
78 /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
79 /// assert_eq!(m.get(1, 0), Some(3.0));
80 /// assert_eq!(m.get(2, 0), None);
81 /// ```
82 #[inline]
83 #[must_use]
84 pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
85 if r < D && c < D {
86 Some(self.rows[r][c])
87 } else {
88 None
89 }
90 }
91
92 /// Set an element with bounds checking.
93 ///
94 /// Returns `true` if the index was in-bounds.
95 ///
96 /// # Examples
97 /// ```
98 /// #![allow(unused_imports)]
99 /// use la_stack::prelude::*;
100 ///
101 /// let mut m = Matrix::<2>::zero();
102 /// assert!(m.set(0, 1, 2.5));
103 /// assert_eq!(m.get(0, 1), Some(2.5));
104 /// assert!(!m.set(10, 0, 1.0));
105 /// ```
106 #[inline]
107 pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
108 if r < D && c < D {
109 self.rows[r][c] = value;
110 true
111 } else {
112 false
113 }
114 }
115
116 /// Infinity norm (maximum absolute row sum).
117 ///
118 /// # Examples
119 /// ```
120 /// #![allow(unused_imports)]
121 /// use la_stack::prelude::*;
122 ///
123 /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
124 /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
125 /// ```
126 #[inline]
127 #[must_use]
128 pub fn inf_norm(&self) -> f64 {
129 let mut max_row_sum: f64 = 0.0;
130
131 for row in &self.rows {
132 let row_sum: f64 = row.iter().map(|&x| x.abs()).sum();
133 if row_sum > max_row_sum {
134 max_row_sum = row_sum;
135 }
136 }
137
138 max_row_sum
139 }
140
141 /// Compute an LU decomposition with partial pivoting.
142 ///
143 /// # Examples
144 /// ```
145 /// #![allow(unused_imports)]
146 /// use la_stack::prelude::*;
147 ///
148 /// # fn main() -> Result<(), LaError> {
149 /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
150 /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
151 ///
152 /// let b = Vector::<2>::new([5.0, 11.0]);
153 /// let x = lu.solve_vec(b)?.into_array();
154 ///
155 /// assert!((x[0] - 1.0).abs() <= 1e-12);
156 /// assert!((x[1] - 2.0).abs() <= 1e-12);
157 /// # Ok(())
158 /// # }
159 /// ```
160 ///
161 /// # Errors
162 /// Returns [`LaError::Singular`] if no suitable pivot (|pivot| > `tol`) exists for a column.
163 /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
164 #[inline]
165 pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
166 Lu::factor(self, tol)
167 }
168
169 /// Determinant computed via LU decomposition.
170 ///
171 /// # Examples
172 /// ```
173 /// #![allow(unused_imports)]
174 /// use la_stack::prelude::*;
175 ///
176 /// # fn main() -> Result<(), LaError> {
177 /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
178 /// assert!((det - 1.0).abs() <= 1e-12);
179 /// # Ok(())
180 /// # }
181 /// ```
182 ///
183 /// # Errors
184 /// Propagates LU factorization errors (e.g. singular matrices).
185 #[inline]
186 pub fn det(self, tol: f64) -> Result<f64, LaError> {
187 self.lu(tol).map(|lu| lu.det())
188 }
189}
190
191impl<const D: usize> Default for Matrix<D> {
192 #[inline]
193 fn default() -> Self {
194 Self::zero()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use approx::assert_abs_diff_eq;
    use pastey::paste;

    // Generates one set of public-API smoke tests for a `$d×$d` matrix.
    macro_rules! gen_public_api_matrix_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
                    // Mark both diagonal corners so each end of the index range is covered.
                    let mut storage = [[0.0f64; $d]; $d];
                    storage[0][0] = 1.0;
                    storage[$d - 1][$d - 1] = -2.0;

                    let mut mat = Matrix::<$d>::from_rows(storage);

                    assert_eq!(mat.get(0, 0), Some(1.0));
                    assert_eq!(mat.get($d - 1, $d - 1), Some(-2.0));

                    // Row index D is one past the end: reads give None, writes are rejected.
                    assert_eq!(mat.get($d, 0), None);
                    assert!(!mat.set($d, 0, 3.0));

                    // An in-bounds write is observable through `get`.
                    assert!(mat.set(0, $d - 1, 3.0));
                    assert_eq!(mat.get(0, $d - 1), Some(3.0));
                }

                #[test]
                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
                    // `zero()` first, then `default()`: both must have a zero norm.
                    for mat in [Matrix::<$d>::zero(), Matrix::<$d>::default()] {
                        assert_abs_diff_eq!(mat.inf_norm(), 0.0, epsilon = 0.0);
                    }
                }

                #[test]
                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
                    let mut storage = [[0.0f64; $d]; $d];

                    // Row 0: absolute row sum equals D, the expected maximum.
                    storage[0] = [-1.0; $d];

                    // Row 1: absolute row sum is D / 2, strictly smaller.
                    storage[1] = [0.5; $d];

                    let mat = Matrix::<$d>::from_rows(storage);
                    assert_abs_diff_eq!(mat.inf_norm(), f64::from($d), epsilon = 0.0);
                }

                #[test]
                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
                    let eye = Matrix::<$d>::identity();

                    // Ones on the diagonal, zeros everywhere else.
                    for r in 0..$d {
                        for c in 0..$d {
                            let want = if r == c { 1.0 } else { 0.0 };
                            assert_abs_diff_eq!(eye.get(r, c).unwrap(), want, epsilon = 0.0);
                        }
                    }

                    // det(I) = 1.
                    assert_abs_diff_eq!(eye.det(DEFAULT_PIVOT_TOL).unwrap(), 1.0, epsilon = 1e-12);

                    let lu = eye.lu(DEFAULT_PIVOT_TOL).unwrap();

                    // RHS = [1, 2, ..., D]; the table covers D <= 5 and any extra slots stay 0.
                    let mut rhs = [0.0f64; $d];
                    for (slot, value) in rhs.iter_mut().zip([1.0f64, 2.0, 3.0, 4.0, 5.0]) {
                        *slot = value;
                    }

                    let x = lu
                        .solve_vec(crate::Vector::<$d>::new(rhs))
                        .unwrap()
                        .into_array();

                    // Solving I·x = b must hand back b unchanged.
                    for (got, want) in x.iter().zip(rhs.iter()) {
                        assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
                    }
                }
            }
        };
    }

    // Instantiate for several dimensions, mirroring delaunay's multi-dimension test layout.
    gen_public_api_matrix_tests!(2);
    gen_public_api_matrix_tests!(3);
    gen_public_api_matrix_tests!(4);
    gen_public_api_matrix_tests!(5);
}