Skip to main content

la_stack/
matrix.rs

1//! Fixed-size, stack-allocated square matrices.
2
3use core::hint::cold_path;
4
5use crate::LaError;
6use crate::ldlt::Ldlt;
7use crate::lu::Lu;
8use crate::{ERR_COEFF_2, ERR_COEFF_3, ERR_COEFF_4};
9
/// A `D×D` square matrix of `f64` values, held entirely by value (no heap).
///
/// Entries are laid out row-major as `rows[row][col]`. The type is `Copy`,
/// so matrices are cheap to pass and return by value.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const D: usize> {
    /// Row-major entries: `rows[r][c]` is the element in row `r`, column `c`.
    pub(crate) rows: [[f64; D]; D],
}
16
17impl<const D: usize> Matrix<D> {
18    /// Construct from row-major storage.
19    ///
20    /// # Examples
21    /// ```
22    /// use la_stack::prelude::*;
23    ///
24    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
25    /// assert_eq!(m.get(0, 1), Some(2.0));
26    /// ```
27    #[inline]
28    pub const fn from_rows(rows: [[f64; D]; D]) -> Self {
29        Self { rows }
30    }
31
32    /// All-zeros matrix.
33    ///
34    /// # Examples
35    /// ```
36    /// use la_stack::prelude::*;
37    ///
38    /// let z = Matrix::<2>::zero();
39    /// assert_eq!(z.get(1, 1), Some(0.0));
40    /// ```
41    #[inline]
42    pub const fn zero() -> Self {
43        Self {
44            rows: [[0.0; D]; D],
45        }
46    }
47
48    /// Identity matrix.
49    ///
50    /// # Examples
51    /// ```
52    /// use la_stack::prelude::*;
53    ///
54    /// let i = Matrix::<3>::identity();
55    /// assert_eq!(i.get(0, 0), Some(1.0));
56    /// assert_eq!(i.get(0, 1), Some(0.0));
57    /// assert_eq!(i.get(2, 2), Some(1.0));
58    /// ```
59    #[inline]
60    pub const fn identity() -> Self {
61        let mut m = Self::zero();
62
63        let mut i = 0;
64        while i < D {
65            m.rows[i][i] = 1.0;
66            i += 1;
67        }
68
69        m
70    }
71
72    /// Get an element with bounds checking.
73    ///
74    /// # Examples
75    /// ```
76    /// use la_stack::prelude::*;
77    ///
78    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
79    /// assert_eq!(m.get(1, 0), Some(3.0));
80    /// assert_eq!(m.get(2, 0), None);
81    /// ```
82    #[inline]
83    #[must_use]
84    pub const fn get(&self, r: usize, c: usize) -> Option<f64> {
85        if r < D && c < D {
86            Some(self.rows[r][c])
87        } else {
88            None
89        }
90    }
91
92    /// Set an element with bounds checking.
93    ///
94    /// Returns `true` if the index was in-bounds.
95    ///
96    /// # Examples
97    /// ```
98    /// use la_stack::prelude::*;
99    ///
100    /// let mut m = Matrix::<2>::zero();
101    /// assert!(m.set(0, 1, 2.5));
102    /// assert_eq!(m.get(0, 1), Some(2.5));
103    /// assert!(!m.set(10, 0, 1.0));
104    /// ```
105    #[inline]
106    pub const fn set(&mut self, r: usize, c: usize, value: f64) -> bool {
107        if r < D && c < D {
108            self.rows[r][c] = value;
109            true
110        } else {
111            false
112        }
113    }
114
115    /// Infinity norm (maximum absolute row sum).
116    ///
117    /// # Non-finite handling
118    /// If any entry is NaN, the result is NaN.  NaN is detected explicitly
119    /// because a naive `row_sum > max_row_sum` comparison silently skips NaN
120    /// rows (every ordered comparison against NaN is `false`).  If any entry
121    /// is infinite (and no entry is NaN), the result is `+∞`.
122    ///
123    /// # Examples
124    /// ```
125    /// use la_stack::prelude::*;
126    ///
127    /// let m = Matrix::<2>::from_rows([[1.0, -2.0], [3.0, 4.0]]);
128    /// assert!((m.inf_norm() - 7.0).abs() <= 1e-12);
129    ///
130    /// // NaN entries propagate to the norm.
131    /// let nan = Matrix::<2>::from_rows([[f64::NAN, 1.0], [2.0, 3.0]]);
132    /// assert!(nan.inf_norm().is_nan());
133    /// ```
134    #[inline]
135    #[must_use]
136    pub const fn inf_norm(&self) -> f64 {
137        let mut max_row_sum: f64 = 0.0;
138
139        let mut r = 0;
140        while r < D {
141            // Iterator chains like `row.iter().map(|x| x.abs()).sum()` are
142            // not yet const-stable, so accumulate the absolute row sum with
143            // a manual `while` loop.
144            let row = &self.rows[r];
145            let mut row_sum: f64 = 0.0;
146            let mut c = 0;
147            while c < D {
148                row_sum += row[c].abs();
149                c += 1;
150            }
151            // Propagate NaN explicitly: `f64::max` drops NaN (IEEE 754 `maxNum`)
152            // and `f64::maximum` (IEEE 754-2019 `maximum`) is still unstable,
153            // so we short-circuit on NaN instead.
154            if row_sum.is_nan() {
155                cold_path();
156                return f64::NAN;
157            }
158            if row_sum > max_row_sum {
159                max_row_sum = row_sum;
160            }
161            r += 1;
162        }
163
164        max_row_sum
165    }
166
167    /// Returns `true` if the matrix is symmetric within a relative tolerance.
168    ///
169    /// Two entries `self[r][c]` and `self[c][r]` are considered equal (for the
170    /// purposes of symmetry) when
171    /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, self.inf_norm())`.
172    /// This mirrors the predicate used internally by the debug-build symmetry
173    /// check inside [`ldlt`](Self::ldlt), so callers can pre-validate matrices
174    /// that may come from untrusted sources without relying on a debug-only
175    /// panic.
176    ///
177    /// Use [`first_asymmetry`](Self::first_asymmetry) to locate the first
178    /// offending pair when this returns `false`.
179    ///
180    /// # NaN / infinity handling
181    /// Any non-finite `|self[r][c] - self[c][r]|` (NaN or ±∞) causes this
182    /// predicate to return `false`.  This catches both NaN off-diagonals and
183    /// asymmetric pairs where one side is infinite and the other is finite
184    /// (which would otherwise slip through when `inf_norm()` blows `eps` up
185    /// to `+∞` and makes `diff > eps` trivially false).  A matrix whose
186    /// [`inf_norm`](Self::inf_norm) is `+∞` can still tolerate *finite*
187    /// asymmetries under an infinite `eps` — callers who need strict equality
188    /// on large-magnitude finite entries should validate finiteness
189    /// separately.
190    ///
191    /// # Panics
192    /// In debug builds, panics if `rel_tol` is negative or NaN; in release
193    /// builds these are silently treated as garbage-in garbage-out, matching
194    /// the convention of [`lu`](Self::lu) and [`ldlt`](Self::ldlt).
195    ///
196    /// # Examples
197    /// ```
198    /// use la_stack::prelude::*;
199    ///
200    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
201    /// assert!(a.is_symmetric(1e-12));
202    ///
203    /// let b = Matrix::<2>::from_rows([[4.0, 2.0], [3.0, 3.0]]);
204    /// assert!(!b.is_symmetric(1e-12));
205    /// ```
206    #[inline]
207    #[must_use]
208    pub fn is_symmetric(&self, rel_tol: f64) -> bool {
209        self.first_asymmetry(rel_tol).is_none()
210    }
211
212    /// Returns the indices `(r, c)` (with `r < c`) of the first off-diagonal
213    /// pair that violates symmetry, or `None` if the matrix is symmetric
214    /// within `rel_tol`.
215    ///
216    /// Iteration order is row-major over the strict upper triangle, so the
217    /// returned indices are the lexicographically smallest such pair.  The
218    /// predicate is the same as [`is_symmetric`](Self::is_symmetric):
219    /// `|self[r][c] - self[c][r]| <= rel_tol * max(1.0, self.inf_norm())`.
220    ///
221    /// # Panics
222    /// In debug builds, panics if `rel_tol` is negative or NaN.
223    ///
224    /// # Examples
225    /// ```
226    /// use la_stack::prelude::*;
227    ///
228    /// let a = Matrix::<3>::from_rows([
229    ///     [1.0, 2.0, 0.0],
230    ///     [2.0, 4.0, 5.0],
231    ///     [0.0, 6.0, 9.0], // 6.0 breaks symmetry with a[1][2] = 5.0
232    /// ]);
233    /// assert_eq!(a.first_asymmetry(1e-12), Some((1, 2)));
234    /// assert_eq!(Matrix::<3>::identity().first_asymmetry(1e-12), None);
235    /// ```
236    #[inline]
237    #[must_use]
238    pub fn first_asymmetry(&self, rel_tol: f64) -> Option<(usize, usize)> {
239        debug_assert!(
240            rel_tol >= 0.0,
241            "rel_tol must be non-negative (got {rel_tol})"
242        );
243        let eps = rel_tol * self.inf_norm().max(1.0);
244        for r in 0..D {
245            for c in (r + 1)..D {
246                let diff = (self.rows[r][c] - self.rows[c][r]).abs();
247                // Any non-finite `diff` is reported as asymmetric:
248                //  * NaN contaminates one side only, and `diff > eps` would
249                //    silently skip it because ordered comparisons against NaN
250                //    are always `false`.
251                //  * ±∞ arises when exactly one of `self[r][c]` / `self[c][r]`
252                //    is infinite; a naive `diff > eps` misses this when the
253                //    matrix's `inf_norm()` pushes `eps` to `+∞` (because
254                //    `∞ > ∞` is `false`).
255                if !diff.is_finite() || diff > eps {
256                    cold_path();
257                    return Some((r, c));
258                }
259            }
260        }
261        None
262    }
263
264    /// Compute an LU decomposition with partial pivoting.
265    ///
266    /// # Examples
267    /// ```
268    /// use la_stack::prelude::*;
269    ///
270    /// # fn main() -> Result<(), LaError> {
271    /// let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
272    /// let lu = a.lu(DEFAULT_PIVOT_TOL)?;
273    ///
274    /// let b = Vector::<2>::new([5.0, 11.0]);
275    /// let x = lu.solve_vec(b)?.into_array();
276    ///
277    /// assert!((x[0] - 1.0).abs() <= 1e-12);
278    /// assert!((x[1] - 2.0).abs() <= 1e-12);
279    /// # Ok(())
280    /// # }
281    /// ```
282    ///
283    /// # Errors
284    /// Returns [`LaError::Singular`] if, for some column `k`, the largest-magnitude candidate pivot
285    /// in that column satisfies `|pivot| <= tol` (so no numerically usable pivot exists).
286    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
287    #[inline]
288    pub fn lu(self, tol: f64) -> Result<Lu<D>, LaError> {
289        Lu::factor(self, tol)
290    }
291
292    /// Compute an LDLT factorization (`A = L D Lᵀ`) without pivoting.
293    ///
294    /// This is intended for symmetric positive definite (SPD) and positive semi-definite (PSD)
295    /// matrices such as Gram matrices.
296    ///
297    /// # Preconditions
298    /// **The input matrix `self` must be symmetric** — that is, `self[i][j] == self[j][i]`
299    /// (within rounding) for all `i`, `j`.  This is a *correctness* precondition, not merely
300    /// a performance hint.
301    ///
302    /// - In **debug builds** a `debug_assert!` verifies symmetry via
303    ///   [`is_symmetric`](Self::is_symmetric) (relative tolerance scaled by the matrix's
304    ///   infinity norm) and panics if it fails.
305    /// - In **release builds** the check is compiled out for performance.  An asymmetric
306    ///   input will be accepted silently and produce a mathematically meaningless
307    ///   factorization — subsequent calls to [`Ldlt::det`] and [`Ldlt::solve_vec`] will
308    ///   return wrong results with no error.
309    ///
310    /// Callers who cannot statically guarantee symmetry should pre-validate with
311    /// [`is_symmetric`](Self::is_symmetric) (or locate the offending pair with
312    /// [`first_asymmetry`](Self::first_asymmetry)) before calling `ldlt`.  If you need a
313    /// general-purpose factorization that tolerates non-symmetric inputs, use
314    /// [`lu`](Self::lu) instead.
315    ///
316    /// # Examples
317    /// ```
318    /// use la_stack::prelude::*;
319    ///
320    /// # fn main() -> Result<(), LaError> {
321    /// // Note the symmetric layout: a[0][1] == a[1][0] == 2.0.
322    /// let a = Matrix::<2>::from_rows([[4.0, 2.0], [2.0, 3.0]]);
323    /// let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL)?;
324    ///
325    /// // det(A) = 8
326    /// assert!((ldlt.det() - 8.0).abs() <= 1e-12);
327    ///
328    /// // Solve A x = b
329    /// let b = Vector::<2>::new([1.0, 2.0]);
330    /// let x = ldlt.solve_vec(b)?.into_array();
331    /// assert!((x[0] - (-0.125)).abs() <= 1e-12);
332    /// assert!((x[1] - 0.75).abs() <= 1e-12);
333    /// # Ok(())
334    /// # }
335    /// ```
336    ///
337    /// # Errors
338    /// Returns [`LaError::Singular`] if, for some step `k`, the required diagonal entry `d = D[k,k]`
339    /// is `<= tol` (non-positive or too small). This treats PSD degeneracy (and indefinite inputs)
340    /// as singular/degenerate.
341    /// Returns [`LaError::NonFinite`] if NaN/∞ is detected during factorization.
342    ///
343    /// Note that an *asymmetric* input is **not** reported as an error in release builds —
344    /// see the [Preconditions](#preconditions) section above.
345    #[inline]
346    pub fn ldlt(self, tol: f64) -> Result<Ldlt<D>, LaError> {
347        Ldlt::factor(self, tol)
348    }
349
350    /// Closed-form determinant for dimensions 0–4, bypassing LU factorization.
351    ///
352    /// Returns `Some(det)` for `D` ∈ {0, 1, 2, 3, 4}, `None` for D ≥ 5.
353    /// `D = 0` returns `Some(1.0)` (empty product).
354    /// This is a `const fn` (Rust 1.94+) and uses fused multiply-add (`mul_add`)
355    /// for improved accuracy and performance.
356    ///
357    /// For a determinant that works for any dimension (falling back to LU for D ≥ 5),
358    /// use [`det`](Self::det).
359    ///
360    /// # Examples
361    /// ```
362    /// use la_stack::prelude::*;
363    ///
364    /// let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
365    /// assert!((m.det_direct().unwrap() - (-2.0)).abs() <= 1e-12);
366    ///
367    /// // D = 0 is the empty product.
368    /// assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0));
369    ///
370    /// // D ≥ 5 returns None.
371    /// assert!(Matrix::<5>::identity().det_direct().is_none());
372    /// ```
373    #[inline]
374    #[must_use]
375    pub const fn det_direct(&self) -> Option<f64> {
376        match D {
377            0 => Some(1.0),
378            1 => Some(self.rows[0][0]),
379            2 => {
380                // ad - bc
381                Some(self.rows[0][0].mul_add(self.rows[1][1], -(self.rows[0][1] * self.rows[1][0])))
382            }
383            3 => {
384                // Cofactor expansion on first row.
385                let m00 =
386                    self.rows[1][1].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][1]));
387                let m01 =
388                    self.rows[1][0].mul_add(self.rows[2][2], -(self.rows[1][2] * self.rows[2][0]));
389                let m02 =
390                    self.rows[1][0].mul_add(self.rows[2][1], -(self.rows[1][1] * self.rows[2][0]));
391                Some(
392                    self.rows[0][0]
393                        .mul_add(m00, (-self.rows[0][1]).mul_add(m01, self.rows[0][2] * m02)),
394                )
395            }
396            4 => {
397                // Cofactor expansion on first row → four 3×3 sub-determinants.
398                // Hoist the 6 unique 2×2 minors from rows 2–3 (each used twice).
399                let r = &self.rows;
400
401                // 2×2 minors: s_ij = r[2][i]*r[3][j] - r[2][j]*r[3][i]
402                let s23 = r[2][2].mul_add(r[3][3], -(r[2][3] * r[3][2])); // cols 2,3
403                let s13 = r[2][1].mul_add(r[3][3], -(r[2][3] * r[3][1])); // cols 1,3
404                let s12 = r[2][1].mul_add(r[3][2], -(r[2][2] * r[3][1])); // cols 1,2
405                let s03 = r[2][0].mul_add(r[3][3], -(r[2][3] * r[3][0])); // cols 0,3
406                let s02 = r[2][0].mul_add(r[3][2], -(r[2][2] * r[3][0])); // cols 0,2
407                let s01 = r[2][0].mul_add(r[3][1], -(r[2][1] * r[3][0])); // cols 0,1
408
409                // 3×3 cofactors via row 1 expansion using hoisted minors.
410                let c00 = r[1][1].mul_add(s23, (-r[1][2]).mul_add(s13, r[1][3] * s12));
411                let c01 = r[1][0].mul_add(s23, (-r[1][2]).mul_add(s03, r[1][3] * s02));
412                let c02 = r[1][0].mul_add(s13, (-r[1][1]).mul_add(s03, r[1][3] * s01));
413                let c03 = r[1][0].mul_add(s12, (-r[1][1]).mul_add(s02, r[1][2] * s01));
414
415                Some(r[0][0].mul_add(
416                    c00,
417                    (-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))),
418                ))
419            }
420            _ => {
421                // Cold in the common D ≤ 4 case; callers fall back to LU for D ≥ 5.
422                cold_path();
423                None
424            }
425        }
426    }
427
428    /// Determinant, using closed-form formulas for D ≤ 4 and LU decomposition for D ≥ 5.
429    ///
430    /// For D ∈ {1, 2, 3, 4}, this bypasses LU factorization entirely for a significant
431    /// speedup (see [`det_direct`](Self::det_direct)). The `tol` parameter is only used
432    /// by the LU fallback path for D ≥ 5.
433    ///
434    /// # Examples
435    /// ```
436    /// use la_stack::prelude::*;
437    ///
438    /// # fn main() -> Result<(), LaError> {
439    /// let det = Matrix::<3>::identity().det(DEFAULT_PIVOT_TOL)?;
440    /// assert!((det - 1.0).abs() <= 1e-12);
441    /// # Ok(())
442    /// # }
443    /// ```
444    ///
445    /// # Errors
446    /// Returns [`LaError::NonFinite`] if the result contains NaN or infinity.
447    /// For D ≥ 5, propagates LU factorization errors (e.g. [`LaError::Singular`]).
448    #[inline]
449    pub fn det(self, tol: f64) -> Result<f64, LaError> {
450        if let Some(d) = self.det_direct() {
451            return if d.is_finite() {
452                Ok(d)
453            } else {
454                cold_path();
455                // Scan for the first non-finite entry to preserve coordinates.
456                for r in 0..D {
457                    for c in 0..D {
458                        if !self.rows[r][c].is_finite() {
459                            return Err(LaError::non_finite_cell(r, c));
460                        }
461                    }
462                }
463                // All entries are finite but the determinant overflowed.
464                Err(LaError::non_finite_at(0))
465            };
466        }
467        self.lu(tol).map(|lu| lu.det())
468    }
469
470    /// Conservative absolute error bound for `det_direct()`.
471    ///
472    /// Returns `Some(bound)` such that `|det_direct() - det_exact| ≤ bound`,
473    /// or `None` for D ≥ 5 where no fast bound is available.
474    ///
475    /// For D ≤ 4, the bound is derived from the absolute Leibniz sum using
476    /// Shewchuk-style error analysis (see `REFERENCES.md` \[8\] and the
477    /// per-constant docs on [`ERR_COEFF_2`](crate::ERR_COEFF_2),
478    /// [`ERR_COEFF_3`](crate::ERR_COEFF_3), and
479    /// [`ERR_COEFF_4`](crate::ERR_COEFF_4)). For D = 0 or 1, returns
480    /// `Some(0.0)` since the determinant computation is exact (no
481    /// arithmetic).
482    ///
483    /// This method does NOT require the `exact` feature — the bounds use
484    /// pure f64 arithmetic and are useful for custom adaptive-precision logic.
485    ///
486    /// # When to use
487    ///
488    /// Use this to build adaptive-precision logic: if `|det_direct()| > bound`,
489    /// the f64 sign is provably correct. Otherwise fall back to exact arithmetic.
490    ///
491    /// # Examples
492    /// ```
493    /// use la_stack::prelude::*;
494    ///
495    /// let m = Matrix::<3>::from_rows([
496    ///     [1.0, 2.0, 3.0],
497    ///     [4.0, 5.0, 6.0],
498    ///     [7.0, 8.0, 9.0],
499    /// ]);
500    /// let bound = m.det_errbound().unwrap();
501    /// let det_approx = m.det_direct().unwrap();
502    /// // If |det_approx| > bound, the sign is guaranteed correct.
503    /// ```
504    ///
505    /// # Adaptive precision pattern (requires `exact` feature)
506    /// ```ignore
507    /// use la_stack::prelude::*;
508    ///
509    /// let m = Matrix::<3>::identity();
510    /// if let Some(bound) = m.det_errbound() {
511    ///     let det = m.det_direct().unwrap();
512    ///     if det.abs() > bound {
513    ///         // f64 sign is guaranteed correct
514    ///         let sign = det.signum() as i8;
515    ///     } else {
516    ///         // Fall back to exact arithmetic (requires `exact` feature)
517    ///         let sign = m.det_sign_exact().unwrap();
518    ///     }
519    /// } else {
520    ///     // D ≥ 5: no fast filter, use exact directly
521    ///     let sign = m.det_sign_exact().unwrap();
522    /// }
523    /// ```
524    #[must_use]
525    #[inline]
526    pub const fn det_errbound(&self) -> Option<f64> {
527        match D {
528            0 | 1 => Some(0.0), // No arithmetic — result is exact.
529            2 => {
530                let r = &self.rows;
531                let permanent = (r[0][0] * r[1][1]).abs() + (r[0][1] * r[1][0]).abs();
532                Some(ERR_COEFF_2 * permanent)
533            }
534            3 => {
535                let r = &self.rows;
536                let pm00 = (r[1][1] * r[2][2]).abs() + (r[1][2] * r[2][1]).abs();
537                let pm01 = (r[1][0] * r[2][2]).abs() + (r[1][2] * r[2][0]).abs();
538                let pm02 = (r[1][0] * r[2][1]).abs() + (r[1][1] * r[2][0]).abs();
539                let permanent = r[0][2]
540                    .abs()
541                    .mul_add(pm02, r[0][1].abs().mul_add(pm01, r[0][0].abs() * pm00));
542                Some(ERR_COEFF_3 * permanent)
543            }
544            4 => {
545                let r = &self.rows;
546                // 2×2 minor permanents from rows 2–3.
547                let sp23 = (r[2][2] * r[3][3]).abs() + (r[2][3] * r[3][2]).abs();
548                let sp13 = (r[2][1] * r[3][3]).abs() + (r[2][3] * r[3][1]).abs();
549                let sp12 = (r[2][1] * r[3][2]).abs() + (r[2][2] * r[3][1]).abs();
550                let sp03 = (r[2][0] * r[3][3]).abs() + (r[2][3] * r[3][0]).abs();
551                let sp02 = (r[2][0] * r[3][2]).abs() + (r[2][2] * r[3][0]).abs();
552                let sp01 = (r[2][0] * r[3][1]).abs() + (r[2][1] * r[3][0]).abs();
553                // 3×3 cofactor permanents from row 1.
554                let pc0 = r[1][3]
555                    .abs()
556                    .mul_add(sp12, r[1][2].abs().mul_add(sp13, r[1][1].abs() * sp23));
557                let pc1 = r[1][3]
558                    .abs()
559                    .mul_add(sp02, r[1][2].abs().mul_add(sp03, r[1][0].abs() * sp23));
560                let pc2 = r[1][3]
561                    .abs()
562                    .mul_add(sp01, r[1][1].abs().mul_add(sp03, r[1][0].abs() * sp13));
563                let pc3 = r[1][2]
564                    .abs()
565                    .mul_add(sp01, r[1][1].abs().mul_add(sp02, r[1][0].abs() * sp12));
566                let permanent = r[0][3].abs().mul_add(
567                    pc3,
568                    r[0][2]
569                        .abs()
570                        .mul_add(pc2, r[0][1].abs().mul_add(pc1, r[0][0].abs() * pc0)),
571                );
572                Some(ERR_COEFF_4 * permanent)
573            }
574            _ => None,
575        }
576    }
577}
578
579impl<const D: usize> Default for Matrix<D> {
580    #[inline]
581    fn default() -> Self {
582        Self::zero()
583    }
584}
585
586#[cfg(test)]
587mod tests {
588    use super::*;
589    use crate::DEFAULT_PIVOT_TOL;
590
591    use approx::assert_abs_diff_eq;
592    use pastey::paste;
593    use std::hint::black_box;
594
595    macro_rules! gen_public_api_matrix_tests {
596        ($d:literal) => {
597            paste! {
598                #[test]
599                fn [<public_api_matrix_from_rows_get_set_bounds_checked_ $d d>]() {
600                    let mut rows = [[0.0f64; $d]; $d];
601                    rows[0][0] = 1.0;
602                    rows[$d - 1][$d - 1] = -2.0;
603
604                    let mut m = Matrix::<$d>::from_rows(rows);
605
606                    assert_eq!(m.get(0, 0), Some(1.0));
607                    assert_eq!(m.get($d - 1, $d - 1), Some(-2.0));
608
609                    // Out-of-bounds is None.
610                    assert_eq!(m.get($d, 0), None);
611
612                    // Out-of-bounds set fails.
613                    assert!(!m.set($d, 0, 3.0));
614
615                    // In-bounds set works.
616                    assert!(m.set(0, $d - 1, 3.0));
617                    assert_eq!(m.get(0, $d - 1), Some(3.0));
618                }
619
620                #[test]
621                fn [<public_api_matrix_zero_and_default_are_zero_ $d d>]() {
622                    let z = Matrix::<$d>::zero();
623                    assert_abs_diff_eq!(z.inf_norm(), 0.0, epsilon = 0.0);
624
625                    let d = Matrix::<$d>::default();
626                    assert_abs_diff_eq!(d.inf_norm(), 0.0, epsilon = 0.0);
627                }
628
629                #[test]
630                fn [<public_api_matrix_inf_norm_max_row_sum_ $d d>]() {
631                    let mut rows = [[0.0f64; $d]; $d];
632
633                    // Row 0 has absolute row sum = D.
634                    for c in 0..$d {
635                        rows[0][c] = -1.0;
636                    }
637
638                    // Row 1 has smaller absolute row sum.
639                    for c in 0..$d {
640                        rows[1][c] = 0.5;
641                    }
642
643                    let m = Matrix::<$d>::from_rows(rows);
644                    assert_abs_diff_eq!(m.inf_norm(), f64::from($d), epsilon = 0.0);
645                }
646
647                #[test]
648                fn [<public_api_matrix_identity_lu_det_solve_vec_ $d d>]() {
649                    let m = Matrix::<$d>::identity();
650
651                    // Identity has ones on diag and zeros off diag.
652                    for r in 0..$d {
653                        for c in 0..$d {
654                            let expected = if r == c { 1.0 } else { 0.0 };
655                            assert_abs_diff_eq!(m.get(r, c).unwrap(), expected, epsilon = 0.0);
656                        }
657                    }
658
659                    // Determinant is 1.
660                    let det = m.det(DEFAULT_PIVOT_TOL).unwrap();
661                    assert_abs_diff_eq!(det, 1.0, epsilon = 1e-12);
662
663                    // LU solve on identity returns the RHS.
664                    let lu = m.lu(DEFAULT_PIVOT_TOL).unwrap();
665
666                    let b_arr = {
667                        let mut arr = [0.0f64; $d];
668                        let values = [1.0f64, 2.0, 3.0, 4.0, 5.0];
669                        for (dst, src) in arr.iter_mut().zip(values.iter()) {
670                            *dst = *src;
671                        }
672                        arr
673                    };
674
675                    let b = crate::Vector::<$d>::new(b_arr);
676                    let x = lu.solve_vec(b).unwrap().into_array();
677
678                    for (x_i, b_i) in x.iter().zip(b_arr.iter()) {
679                        assert_abs_diff_eq!(*x_i, *b_i, epsilon = 1e-12);
680                    }
681                }
682            }
683        };
684    }
685
686    // Mirror delaunay-style multi-dimension tests.
687    gen_public_api_matrix_tests!(2);
688    gen_public_api_matrix_tests!(3);
689    gen_public_api_matrix_tests!(4);
690    gen_public_api_matrix_tests!(5);
691
692    // === det_direct tests ===
693
694    #[test]
695    fn det_direct_d0_is_one() {
696        assert_eq!(Matrix::<0>::zero().det_direct(), Some(1.0));
697    }
698
699    #[test]
700    fn det_direct_d1_returns_element() {
701        let m = Matrix::<1>::from_rows([[42.0]]);
702        assert_eq!(m.det_direct(), Some(42.0));
703    }
704
705    #[test]
706    fn det_direct_d2_known_value() {
707        // [[1,2],[3,4]] → det = 1*4 - 2*3 = -2
708        // black_box prevents compile-time constant folding of the const fn.
709        let m = black_box(Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]));
710        assert_abs_diff_eq!(m.det_direct().unwrap(), -2.0, epsilon = 1e-15);
711    }
712
713    #[test]
714    fn det_direct_d3_known_value() {
715        // Classic 3×3: det = 0
716        let m = black_box(Matrix::<3>::from_rows([
717            [1.0, 2.0, 3.0],
718            [4.0, 5.0, 6.0],
719            [7.0, 8.0, 9.0],
720        ]));
721        assert_abs_diff_eq!(m.det_direct().unwrap(), 0.0, epsilon = 1e-12);
722    }
723
724    #[test]
725    fn det_direct_d3_nonsingular() {
726        // [[2,1,0],[0,3,1],[1,0,2]] → det = 2*(6-0) - 1*(0-1) + 0 = 13
727        let m = black_box(Matrix::<3>::from_rows([
728            [2.0, 1.0, 0.0],
729            [0.0, 3.0, 1.0],
730            [1.0, 0.0, 2.0],
731        ]));
732        assert_abs_diff_eq!(m.det_direct().unwrap(), 13.0, epsilon = 1e-12);
733    }
734
735    #[test]
736    fn det_direct_d4_identity() {
737        let m = black_box(Matrix::<4>::identity());
738        assert_abs_diff_eq!(m.det_direct().unwrap(), 1.0, epsilon = 1e-15);
739    }
740
741    #[test]
742    fn det_direct_d4_known_value() {
743        // Diagonal matrix: det = product of diagonal entries.
744        let mut rows = [[0.0f64; 4]; 4];
745        rows[0][0] = 2.0;
746        rows[1][1] = 3.0;
747        rows[2][2] = 5.0;
748        rows[3][3] = 7.0;
749        let m = black_box(Matrix::<4>::from_rows(rows));
750        assert_abs_diff_eq!(m.det_direct().unwrap(), 210.0, epsilon = 1e-12);
751    }
752
753    #[test]
754    fn det_direct_d5_returns_none() {
755        assert_eq!(Matrix::<5>::identity().det_direct(), None);
756    }
757
758    #[test]
759    fn det_direct_d8_returns_none() {
760        assert_eq!(Matrix::<8>::zero().det_direct(), None);
761    }
762
763    macro_rules! gen_det_direct_agrees_with_lu {
764        ($d:literal) => {
765            paste! {
766                #[test]
767                #[allow(clippy::cast_precision_loss)] // r, c, D are tiny integers
768                fn [<det_direct_agrees_with_lu_ $d d>]() {
769                    // Well-conditioned matrix: diagonally dominant.
770                    let mut rows = [[0.0f64; $d]; $d];
771                    for r in 0..$d {
772                        for c in 0..$d {
773                            rows[r][c] = if r == c {
774                                (r as f64) + f64::from($d) + 1.0
775                            } else {
776                                0.1 / ((r + c + 1) as f64)
777                            };
778                        }
779                    }
780                    let m = Matrix::<$d>::from_rows(rows);
781                    let direct = m.det_direct().unwrap();
782                    let lu_det = m.lu(DEFAULT_PIVOT_TOL).unwrap().det();
783                    let eps = lu_det.abs().mul_add(1e-12, 1e-12);
784                    assert_abs_diff_eq!(direct, lu_det, epsilon = eps);
785                }
786            }
787        };
788    }
789
790    gen_det_direct_agrees_with_lu!(1);
791    gen_det_direct_agrees_with_lu!(2);
792    gen_det_direct_agrees_with_lu!(3);
793    gen_det_direct_agrees_with_lu!(4);
794
795    #[test]
796    fn det_direct_identity_all_dims() {
797        assert_abs_diff_eq!(
798            Matrix::<1>::identity().det_direct().unwrap(),
799            1.0,
800            epsilon = 0.0
801        );
802        assert_abs_diff_eq!(
803            Matrix::<2>::identity().det_direct().unwrap(),
804            1.0,
805            epsilon = 0.0
806        );
807        assert_abs_diff_eq!(
808            Matrix::<3>::identity().det_direct().unwrap(),
809            1.0,
810            epsilon = 0.0
811        );
812        assert_abs_diff_eq!(
813            Matrix::<4>::identity().det_direct().unwrap(),
814            1.0,
815            epsilon = 0.0
816        );
817    }
818
819    #[test]
820    fn det_direct_zero_matrix() {
821        assert_abs_diff_eq!(
822            Matrix::<2>::zero().det_direct().unwrap(),
823            0.0,
824            epsilon = 0.0
825        );
826        assert_abs_diff_eq!(
827            Matrix::<3>::zero().det_direct().unwrap(),
828            0.0,
829            epsilon = 0.0
830        );
831        assert_abs_diff_eq!(
832            Matrix::<4>::zero().det_direct().unwrap(),
833            0.0,
834            epsilon = 0.0
835        );
836    }
837
838    #[test]
839    fn det_returns_nonfinite_error_for_nan_d2() {
840        let m = Matrix::<2>::from_rows([[f64::NAN, 1.0], [1.0, 1.0]]);
841        assert_eq!(
842            m.det(DEFAULT_PIVOT_TOL),
843            Err(LaError::NonFinite {
844                row: Some(0),
845                col: 0
846            })
847        );
848    }
849
850    #[test]
851    fn det_returns_nonfinite_error_for_inf_d3() {
852        let m =
853            Matrix::<3>::from_rows([[f64::INFINITY, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
854        assert_eq!(
855            m.det(DEFAULT_PIVOT_TOL),
856            Err(LaError::NonFinite {
857                row: Some(0),
858                col: 0
859            })
860        );
861    }
862
863    #[test]
864    fn det_returns_nonfinite_error_for_overflow_with_finite_entries() {
865        // det_direct produces an overflowing f64 (1e300 * 1e300 = ∞) even
866        // though every matrix entry is finite.  The entry scan in `det`
867        // falls through and returns NonFinite { row: None, col: 0 } to signal
868        // a computed overflow rather than a NaN/∞ input.
869        let m = Matrix::<2>::from_rows([[1e300, 0.0], [0.0, 1e300]]);
870        assert_eq!(
871            m.det(DEFAULT_PIVOT_TOL),
872            Err(LaError::NonFinite { row: None, col: 0 })
873        );
874    }
875
876    // === det_direct const-evaluability tests (D = 2..=5) ===
877    //
878    // Every dimension hits a distinct arm of the `match D { … }` body inside
879    // `det_direct`, so exercising each at compile time is the tightest
880    // const-fn proof available.
881
882    macro_rules! gen_det_direct_const_eval_tests {
883        ($d:literal) => {
884            paste! {
885                /// `Matrix::<D>::det_direct()` on the identity must const-evaluate
886                /// to `Some(1.0)` for every closed-form dimension `D ∈ {1, 2, 3, 4}`.
887                #[test]
888                fn [<det_direct_const_eval_ $d d>]() {
889                    const DET: Option<f64> = Matrix::<$d>::identity().det_direct();
890                    assert_eq!(DET, Some(1.0));
891                }
892            }
893        };
894    }
895
896    gen_det_direct_const_eval_tests!(2);
897    gen_det_direct_const_eval_tests!(3);
898    gen_det_direct_const_eval_tests!(4);
899
900    #[test]
901    fn det_direct_const_eval_d5_is_none() {
902        // D ≥ 5 has no closed-form arm; `det_direct` returns `None`.  Verify
903        // that the wildcard arm is reachable in a `const { … }` context.
904        const DET: Option<f64> = Matrix::<5>::identity().det_direct();
905        assert_eq!(DET, None);
906    }
907
908    // === det_errbound tests (no `exact` feature required) ===
909
910    #[test]
911    fn det_errbound_available_without_exact_feature() {
912        // Verify det_errbound is accessible without exact feature
913        let m = Matrix::<3>::identity();
914        let bound = m.det_errbound();
915        assert!(bound.is_some());
916        assert!(bound.unwrap() > 0.0);
917    }
918
919    #[test]
920    fn det_errbound_d5_returns_none() {
921        // D=5 has no fast filter
922        assert_eq!(Matrix::<5>::identity().det_errbound(), None);
923    }
924
925    // === det_errbound const-evaluability tests (D = 2..=5) ===
926
927    macro_rules! gen_det_errbound_const_eval_tests {
928        ($d:literal) => {
929            paste! {
930                /// `Matrix::<D>::det_errbound()` on the identity must const-evaluate
931                /// to `Some(bound)` with `bound > 0` for every closed-form dimension
932                /// `D ∈ {2, 3, 4}`.  Each dimension hits a distinct arm of
933                /// `det_errbound` with a dimension-specific permanent computation.
934                #[test]
935                fn [<det_errbound_const_eval_ $d d>]() {
936                    const BOUND: Option<f64> = Matrix::<$d>::identity().det_errbound();
937                    assert!(BOUND.is_some());
938                    assert!(BOUND.unwrap() > 0.0);
939                }
940            }
941        };
942    }
943
944    gen_det_errbound_const_eval_tests!(2);
945    gen_det_errbound_const_eval_tests!(3);
946    gen_det_errbound_const_eval_tests!(4);
947
948    #[test]
949    fn det_errbound_const_eval_d5_is_none() {
950        // D ≥ 5 has no fast-filter bound; `det_errbound` returns `None`.
951        const BOUND: Option<f64> = Matrix::<5>::identity().det_errbound();
952        assert_eq!(BOUND, None);
953    }
954
955    // === inf_norm const-evaluability tests (D = 2..=5) ===
956
957    macro_rules! gen_inf_norm_const_eval_tests {
958        ($d:literal) => {
959            paste! {
960                /// `Matrix::<D>::inf_norm()` on the identity must const-evaluate
961                /// to `1.0` for every `D ≥ 1` — each row has a single `1.0`
962                /// entry, so the max absolute row sum is exactly `1.0`.
963                #[test]
964                fn [<inf_norm_const_eval_ $d d>]() {
965                    const NORM: f64 = Matrix::<$d>::identity().inf_norm();
966                    assert!((NORM - 1.0).abs() <= 1e-12);
967                }
968            }
969        };
970    }
971
972    gen_inf_norm_const_eval_tests!(2);
973    gen_inf_norm_const_eval_tests!(3);
974    gen_inf_norm_const_eval_tests!(4);
975    gen_inf_norm_const_eval_tests!(5);
976
977    // === inf_norm NaN / Inf propagation (regression tests for #85) ===
978
979    macro_rules! gen_inf_norm_nonfinite_tests {
980        ($d:literal) => {
981            paste! {
982                #[test]
983                fn [<inf_norm_all_nan_returns_nan_ $d d>]() {
984                    // Before the fix, `NaN > max_row_sum` was always false, so a
985                    // matrix full of NaN silently produced inf_norm == 0.0.
986                    let m = Matrix::<$d>::from_rows([[f64::NAN; $d]; $d]);
987                    assert!(m.inf_norm().is_nan());
988                }
989
990                #[test]
991                fn [<inf_norm_single_nan_entry_returns_nan_ $d d>]() {
992                    // A single NaN entry must contaminate its row sum and
993                    // propagate through `f64::maximum` to the final result.
994                    let mut rows = [[0.0f64; $d]; $d];
995                    rows[0][0] = f64::NAN;
996                    rows[$d - 1][$d - 1] = 1.0;
997                    let m = Matrix::<$d>::from_rows(rows);
998                    assert!(m.inf_norm().is_nan());
999                }
1000
1001                #[test]
1002                fn [<inf_norm_infinity_entry_propagates_ $d d>]() {
1003                    // Infinity entries should propagate to +∞ via the row sum,
1004                    // not be silently dropped.  The norm is a sum of absolute
1005                    // values, so any infinite result is necessarily +∞.
1006                    let mut rows = [[0.0f64; $d]; $d];
1007                    rows[0][0] = f64::INFINITY;
1008                    let m = Matrix::<$d>::from_rows(rows);
1009                    let norm = m.inf_norm();
1010                    assert!(norm.is_infinite() && norm.is_sign_positive());
1011                }
1012            }
1013        };
1014    }
1015
1016    gen_inf_norm_nonfinite_tests!(2);
1017    gen_inf_norm_nonfinite_tests!(3);
1018    gen_inf_norm_nonfinite_tests!(4);
1019    gen_inf_norm_nonfinite_tests!(5);
1020
    // === is_symmetric / first_asymmetry (public LDLT precondition helpers) ===
1022
1023    macro_rules! gen_is_symmetric_tests {
1024        ($d:literal) => {
1025            paste! {
1026                #[test]
1027                fn [<is_symmetric_true_for_identity_ $d d>]() {
1028                    let m = Matrix::<$d>::identity();
1029                    assert!(m.is_symmetric(1e-12));
1030                    assert_eq!(m.first_asymmetry(1e-12), None);
1031                }
1032
1033                #[test]
1034                fn [<is_symmetric_true_for_zero_ $d d>]() {
1035                    let m = Matrix::<$d>::zero();
1036                    assert!(m.is_symmetric(1e-12));
1037                    assert_eq!(m.first_asymmetry(1e-12), None);
1038                }
1039
1040                #[test]
1041                fn [<is_symmetric_true_for_constructed_symmetric_ $d d>]() {
1042                    // Construct A = M + Mᵀ so A is provably symmetric.
1043                    let mut m = [[0.0f64; $d]; $d];
1044                    for r in 0..$d {
1045                        for c in 0..$d {
1046                            #[allow(clippy::cast_precision_loss)]
1047                            {
1048                                m[r][c] = (r * $d + c) as f64;
1049                            }
1050                        }
1051                    }
1052                    let mut sym = [[0.0f64; $d]; $d];
1053                    for r in 0..$d {
1054                        for c in 0..$d {
1055                            sym[r][c] = m[r][c] + m[c][r];
1056                        }
1057                    }
1058                    let a = Matrix::<$d>::from_rows(sym);
1059                    assert!(a.is_symmetric(1e-12));
1060                    assert_eq!(a.first_asymmetry(1e-12), None);
1061                }
1062
1063                #[test]
1064                fn [<is_symmetric_false_for_asymmetric_offdiagonal_ $d d>]() {
1065                    // Perturb a single off-diagonal entry so symmetry fails.
1066                    let mut rows = [[0.0f64; $d]; $d];
1067                    for i in 0..$d {
1068                        rows[i][i] = 1.0;
1069                    }
1070                    rows[0][$d - 1] = 1.0;
1071                    rows[$d - 1][0] = -1.0; // breaks symmetry
1072                    let a = Matrix::<$d>::from_rows(rows);
1073                    assert!(!a.is_symmetric(1e-12));
1074                    assert_eq!(a.first_asymmetry(1e-12), Some((0, $d - 1)));
1075                }
1076
1077                #[test]
1078                fn [<is_symmetric_false_for_nan_offdiagonal_ $d d>]() {
1079                    // A NaN off-diagonal must be detected as asymmetric.
1080                    let mut rows = [[0.0f64; $d]; $d];
1081                    for i in 0..$d {
1082                        rows[i][i] = 1.0;
1083                    }
1084                    rows[0][1] = f64::NAN;
1085                    rows[1][0] = f64::NAN;
1086                    let a = Matrix::<$d>::from_rows(rows);
1087                    assert!(!a.is_symmetric(1e-12));
1088                    // (0, 1) is the first upper-triangular pair involving the NaN.
1089                    assert_eq!(a.first_asymmetry(1e-12), Some((0, 1)));
1090                }
1091            }
1092        };
1093    }
1094
1095    gen_is_symmetric_tests!(2);
1096    gen_is_symmetric_tests!(3);
1097    gen_is_symmetric_tests!(4);
1098    gen_is_symmetric_tests!(5);
1099
1100    #[test]
1101    fn is_symmetric_tolerance_scales_with_inf_norm() {
1102        // Off-diagonal entries differ by 1e-6.  With inf_norm ≈ 2e6, the
1103        // relative tolerance 1e-12 yields eps ≈ 2e-6, which accepts the gap;
1104        // a stricter tol of 1e-15 rejects it.
1105        let a = Matrix::<2>::from_rows([[1.0e6, 1.0e6 + 1.0e-6], [1.0e6, 1.0e6]]);
1106        assert!(a.is_symmetric(1e-12));
1107        assert!(!a.is_symmetric(1e-15));
1108    }
1109
1110    #[test]
1111    fn first_asymmetry_returns_lexicographically_first_pair() {
1112        // Two asymmetric pairs: (0, 2) and (1, 2).  We must get (0, 2) first.
1113        let a = Matrix::<3>::from_rows([[1.0, 0.0, 2.0], [0.0, 1.0, 3.0], [-2.0, -3.0, 1.0]]);
1114        assert_eq!(a.first_asymmetry(1e-12), Some((0, 2)));
1115    }
1116
1117    /// Regression: a single infinite off-diagonal paired with a finite entry
1118    /// used to slip through as "symmetric" because `inf_norm()` blew `eps` up
1119    /// to `+∞` and `∞ > ∞` evaluates to `false`.  After the fix, any
1120    /// non-finite `|a[r][c] - a[c][r]|` is reported as an asymmetry regardless
1121    /// of `eps`.
1122    #[test]
1123    fn first_asymmetry_flags_infinite_offdiagonal_against_finite() {
1124        let a = Matrix::<2>::from_rows([[1.0, f64::INFINITY], [0.0, 1.0]]);
1125        assert_eq!(a.first_asymmetry(1e-12), Some((0, 1)));
1126        assert!(!a.is_symmetric(1e-12));
1127    }
1128
1129    #[cfg(debug_assertions)]
1130    #[test]
1131    #[should_panic(expected = "rel_tol must be non-negative")]
1132    fn first_asymmetry_debug_panics_on_negative_tol() {
1133        // Mirrors the `debug_assert!(tol >= 0.0)` convention used by
1134        // `Matrix::lu` / `Matrix::ldlt`.
1135        let _ = Matrix::<2>::identity().first_asymmetry(-1.0);
1136    }
1137}