// File: numra_linalg/matrix.rs
//! Matrix trait and implementations.
//!
//! Author: Moussa Leblouba
//! Date: 8 February 2026
//! Modified: 2 May 2026

use crate::Scalar;
use faer::prelude::*;
use faer::{ComplexField, Conjugate, Entity, Mat, MatMut, MatRef, SimpleEntity};
use numra_core::LinalgError;

12/// Trait for matrix types.
13///
14/// Provides a backend-agnostic interface for matrix operations needed by ODE
15/// solvers.
16///
17/// **Sparse storage**: [`crate::SparseMatrix`] is not currently in this trait.
18/// Sparse-aware solvers (`crate::iterative`, `crate::preconditioner`) dispatch
19/// on `&SparseMatrix<S>` concretely rather than through `dyn Matrix<S>`.
20/// Whether sparse should join this trait is an open foundation question
21/// coupled to sparse-Jacobian-return, deferred per Foundation Spec §7 #5
22/// under the spec's named trigger ("when a sparse-aware solver path needs to
23/// dispatch across both dense and sparse via a single trait"). The current
24/// dense-only scope is the expedient state, not a deliberate design — see
25/// F-MATRIX-SHAPE in `docs/internal-followups.md` for the operational
26/// tracker, and Foundation Spec §3.2 + §7 #5 for the design context.
27pub trait Matrix<S: Scalar>: Clone + Sized {
28    /// Create a zero matrix with given dimensions.
29    fn zeros(rows: usize, cols: usize) -> Self;
30
31    /// Create an identity matrix.
32    fn identity(n: usize) -> Self;
33
34    /// Number of rows.
35    fn nrows(&self) -> usize;
36
37    /// Number of columns.
38    fn ncols(&self) -> usize;
39
40    /// Get element at (i, j).
41    fn get(&self, i: usize, j: usize) -> S;
42
43    /// Set element at (i, j).
44    fn set(&mut self, i: usize, j: usize, value: S);
45
46    /// Fill matrix with zeros.
47    fn fill_zero(&mut self);
48
49    /// Scale all elements by a constant.
50    fn scale(&mut self, alpha: S);
51
52    /// Compute y = self * x (matrix-vector product).
53    fn mul_vec(&self, x: &[S], y: &mut [S]);
54
55    /// Add another matrix: self += alpha * other.
56    fn add_scaled(&mut self, alpha: S, other: &Self);
57
58    /// Solve the linear system Ax = b, returning x.
59    fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError>;
60
61    /// Check if matrix is square.
62    fn is_square(&self) -> bool {
63        self.nrows() == self.ncols()
64    }
65}
66
67/// Dense matrix backed by faer.
68#[derive(Clone, Debug)]
69pub struct DenseMatrix<S: Scalar + Entity> {
70    data: Mat<S>,
71}
72
73impl<S: Scalar + Entity> DenseMatrix<S> {
74    /// Create from a faer Mat.
75    pub fn from_faer(mat: Mat<S>) -> Self {
76        Self { data: mat }
77    }
78
79    /// Get reference to underlying faer matrix.
80    pub fn as_faer(&self) -> MatRef<'_, S> {
81        self.data.as_ref()
82    }
83
84    /// Get mutable reference to underlying faer matrix.
85    pub fn as_faer_mut(&mut self) -> MatMut<'_, S> {
86        self.data.as_mut()
87    }
88
89    /// Create from row-major data.
90    pub fn from_row_major(rows: usize, cols: usize, data: &[S]) -> Self {
91        assert_eq!(data.len(), rows * cols);
92        let mut mat = Mat::zeros(rows, cols);
93        for i in 0..rows {
94            for j in 0..cols {
95                mat.write(i, j, data[i * cols + j]);
96            }
97        }
98        Self { data: mat }
99    }
100
101    /// Create from column-major data.
102    pub fn from_col_major(rows: usize, cols: usize, data: &[S]) -> Self {
103        assert_eq!(data.len(), rows * cols);
104        let mut mat = Mat::zeros(rows, cols);
105        for j in 0..cols {
106            for i in 0..rows {
107                mat.write(i, j, data[j * rows + i]);
108            }
109        }
110        Self { data: mat }
111    }
112
113    /// Convert to row-major vector.
114    pub fn to_row_major(&self) -> Vec<S> {
115        let (rows, cols) = (self.data.nrows(), self.data.ncols());
116        let mut data = Vec::with_capacity(rows * cols);
117        for i in 0..rows {
118            for j in 0..cols {
119                data.push(self.data.read(i, j));
120            }
121        }
122        data
123    }
124
125    /// Frobenius norm.
126    pub fn norm_frobenius(&self) -> S {
127        let mut sum = S::ZERO;
128        for i in 0..self.data.nrows() {
129            for j in 0..self.data.ncols() {
130                let v = self.data.read(i, j);
131                sum += v * v;
132            }
133        }
134        sum.sqrt()
135    }
136
137    /// Infinity norm (max row sum).
138    pub fn norm_inf(&self) -> S {
139        let mut max_sum = S::ZERO;
140        for i in 0..self.data.nrows() {
141            let mut row_sum = S::ZERO;
142            for j in 0..self.data.ncols() {
143                row_sum += self.data.read(i, j).abs();
144            }
145            max_sum = max_sum.max(row_sum);
146        }
147        max_sum
148    }
149
150    /// Number of rows (direct access without trait).
151    pub fn rows(&self) -> usize {
152        self.data.nrows()
153    }
154
155    /// Number of columns (direct access without trait).
156    pub fn cols(&self) -> usize {
157        self.data.ncols()
158    }
159
160    /// Check if square (direct access without trait).
161    pub fn is_square(&self) -> bool {
162        self.data.nrows() == self.data.ncols()
163    }
164}
165
166impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Matrix<S>
167    for DenseMatrix<S>
168{
169    fn zeros(rows: usize, cols: usize) -> Self {
170        Self {
171            data: Mat::zeros(rows, cols),
172        }
173    }
174
175    fn identity(n: usize) -> Self {
176        let mut mat = Mat::zeros(n, n);
177        for i in 0..n {
178            mat.write(i, i, S::ONE);
179        }
180        Self { data: mat }
181    }
182
183    fn nrows(&self) -> usize {
184        self.data.nrows()
185    }
186
187    fn ncols(&self) -> usize {
188        self.data.ncols()
189    }
190
191    fn get(&self, i: usize, j: usize) -> S {
192        self.data.read(i, j)
193    }
194
195    fn set(&mut self, i: usize, j: usize, value: S) {
196        self.data.write(i, j, value);
197    }
198
199    fn fill_zero(&mut self) {
200        for i in 0..self.nrows() {
201            for j in 0..self.ncols() {
202                self.data.write(i, j, S::ZERO);
203            }
204        }
205    }
206
207    fn scale(&mut self, alpha: S) {
208        for i in 0..self.nrows() {
209            for j in 0..self.ncols() {
210                let v = self.data.read(i, j);
211                self.data.write(i, j, alpha * v);
212            }
213        }
214    }
215
216    fn mul_vec(&self, x: &[S], y: &mut [S]) {
217        assert_eq!(x.len(), self.ncols());
218        assert_eq!(y.len(), self.nrows());
219
220        for (i, y_i) in y.iter_mut().enumerate().take(self.nrows()) {
221            let mut sum = S::ZERO;
222            for (j, &x_j) in x.iter().enumerate().take(self.ncols()) {
223                sum += self.data.read(i, j) * x_j;
224            }
225            *y_i = sum;
226        }
227    }
228
229    fn add_scaled(&mut self, alpha: S, other: &Self) {
230        assert_eq!(self.nrows(), other.nrows());
231        assert_eq!(self.ncols(), other.ncols());
232
233        for i in 0..self.nrows() {
234            for j in 0..self.ncols() {
235                let v = self.data.read(i, j) + alpha * other.data.read(i, j);
236                self.data.write(i, j, v);
237            }
238        }
239    }
240
241    fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
242        if !self.is_square() {
243            return Err(LinalgError::NotSquare {
244                nrows: self.nrows(),
245                ncols: self.ncols(),
246            });
247        }
248        if b.len() != self.nrows() {
249            return Err(LinalgError::DimensionMismatch {
250                expected: (self.nrows(), 1),
251                actual: (b.len(), 1),
252            });
253        }
254
255        // Use faer's LU solver
256        let lu = self.data.as_ref().partial_piv_lu();
257
258        // Create column vector from b
259        let mut b_mat = Mat::zeros(b.len(), 1);
260        for (i, &val) in b.iter().enumerate() {
261            b_mat.write(i, 0, val);
262        }
263
264        // Solve
265        let x_mat = lu.solve(&b_mat);
266
267        // Extract result
268        let mut x = Vec::with_capacity(b.len());
269        for i in 0..b.len() {
270            x.push(x_mat.read(i, 0));
271        }
272
273        Ok(x)
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    // ============================================================================
    // Helpers
    // ============================================================================

    /// Assert `|actual - expected| < tol`, printing both values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tol {tol})"
        );
    }

    /// The 2×2 matrix `[a b; c d]`, built from row-major data.
    fn mat2(a: f64, b: f64, c: f64, d: f64) -> DenseMatrix<f64> {
        DenseMatrix::from_row_major(2, 2, &[a, b, c, d])
    }

    // ============================================================================
    // Basic Operations
    // ============================================================================

    #[test]
    fn test_zeros() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(3, 4);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        for i in 0..3 {
            for j in 0..4 {
                assert_close(m.get(i, j), 0.0, 1e-15);
            }
        }
    }

    #[test]
    fn test_identity() {
        let m: DenseMatrix<f64> = DenseMatrix::identity(3);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_close(m.get(i, j), expected, 1e-15);
            }
        }
    }

    #[test]
    fn test_set_get() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        assert_close(m.get(0, 0), 1.0, 1e-15);
        assert_close(m.get(0, 1), 2.0, 1e-15);
        assert_close(m.get(1, 0), 3.0, 1e-15);
        assert_close(m.get(1, 1), 4.0, 1e-15);
    }

    #[test]
    fn test_mul_vec() {
        // [1 2] * [1]   [5]
        // [3 4]   [2] = [11]
        let m = mat2(1.0, 2.0, 3.0, 4.0);
        let x = [1.0, 2.0];
        let mut y = [0.0, 0.0];
        m.mul_vec(&x, &mut y);

        assert_close(y[0], 5.0, 1e-10);
        assert_close(y[1], 11.0, 1e-10);
    }

    #[test]
    fn test_mul_vec_rectangular() {
        // [1 2 3]   [1]   [14]
        // [4 5 6] * [2] = [32]
        //           [3]
        let m: DenseMatrix<f64> =
            DenseMatrix::from_row_major(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut y = [0.0; 2];
        m.mul_vec(&[1.0, 2.0, 3.0], &mut y);

        assert_close(y[0], 14.0, 1e-12);
        assert_close(y[1], 32.0, 1e-12);
    }

    #[test]
    #[should_panic]
    fn test_mul_vec_wrong_input_length_panics() {
        let m = mat2(1.0, 2.0, 3.0, 4.0);
        let mut y = [0.0; 2];
        m.mul_vec(&[1.0, 2.0, 3.0], &mut y);
    }

    #[test]
    fn test_scale() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(2);
        m.scale(3.0);
        assert_close(m.get(0, 0), 3.0, 1e-15);
        assert_close(m.get(1, 1), 3.0, 1e-15);
        assert_close(m.get(0, 1), 0.0, 1e-15);
        assert_close(m.get(1, 0), 0.0, 1e-15);
    }

    #[test]
    fn test_add_scaled() {
        let mut a: DenseMatrix<f64> = DenseMatrix::identity(2);
        let b: DenseMatrix<f64> = DenseMatrix::identity(2);
        a.add_scaled(2.0, &b);

        // a should now be 3*I
        assert_close(a.get(0, 0), 3.0, 1e-15);
        assert_close(a.get(1, 1), 3.0, 1e-15);
        assert_close(a.get(0, 1), 0.0, 1e-15);
        assert_close(a.get(1, 0), 0.0, 1e-15);
    }

    #[test]
    fn test_add_scaled_general() {
        // I + 2 * [1 2; 3 4] = [3 4; 6 9]
        let mut a: DenseMatrix<f64> = DenseMatrix::identity(2);
        a.add_scaled(2.0, &mat2(1.0, 2.0, 3.0, 4.0));
        assert_eq!(a.to_row_major(), vec![3.0, 4.0, 6.0, 9.0]);
    }

    #[test]
    #[should_panic]
    fn test_add_scaled_shape_mismatch_panics() {
        let mut a: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let b: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        a.add_scaled(1.0, &b);
    }

    #[test]
    fn test_solve_diagonal() {
        // Solve diag(2, 3, 4) * x = [1, 2, 3]
        // x = [0.5, 2/3, 0.75]
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(3, 3);
        m.set(0, 0, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 2, 4.0);

        let b = vec![1.0, 2.0, 3.0];
        let x = m.solve(&b).unwrap();

        assert_close(x[0], 0.5, 1e-10);
        assert_close(x[1], 2.0 / 3.0, 1e-10);
        assert_close(x[2], 0.75, 1e-10);
    }

    #[test]
    fn test_solve_general() {
        // Solve [1 2; 3 4] * x = [5; 11]
        // x = [1; 2]
        let m = mat2(1.0, 2.0, 3.0, 4.0);
        let b = vec![5.0, 11.0];
        let x = m.solve(&b).unwrap();

        assert_close(x[0], 1.0, 1e-10);
        assert_close(x[1], 2.0, 1e-10);
    }

    #[test]
    fn test_from_row_major() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m: DenseMatrix<f64> = DenseMatrix::from_row_major(2, 3, &data);

        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        assert_close(m.get(0, 0), 1.0, 1e-15);
        assert_close(m.get(0, 2), 3.0, 1e-15);
        assert_close(m.get(1, 0), 4.0, 1e-15);
        assert_close(m.get(1, 2), 6.0, 1e-15);
    }

    #[test]
    #[should_panic]
    fn test_from_row_major_length_mismatch_panics() {
        let _m: DenseMatrix<f64> = DenseMatrix::from_row_major(2, 2, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_row_major_round_trip() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m: DenseMatrix<f64> = DenseMatrix::from_row_major(3, 2, &data);
        assert_eq!(m.to_row_major(), data.to_vec());
    }

    #[test]
    fn test_norm_frobenius() {
        let m = mat2(1.0, 2.0, 3.0, 4.0);

        // sqrt(1 + 4 + 9 + 16) = sqrt(30)
        assert_close(m.norm_frobenius(), 30.0_f64.sqrt(), 1e-10);
    }

    // ============================================================================
    // Edge Case Tests
    // ============================================================================

    #[test]
    fn test_1x1_matrix() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(1, 1);
        m.set(0, 0, 5.0);
        assert!(m.is_square());
        assert_close(m.get(0, 0), 5.0, 1e-15);

        let b = vec![10.0];
        let x = m.solve(&b).unwrap();
        assert_close(x[0], 2.0, 1e-10);
    }

    #[test]
    fn test_identity_1x1() {
        let m: DenseMatrix<f64> = DenseMatrix::identity(1);
        assert_close(m.get(0, 0), 1.0, 1e-15);
    }

    #[test]
    fn test_rectangular_not_square() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(!m.is_square());
    }

    #[test]
    fn test_solve_non_square_error() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        let b = vec![1.0, 2.0];
        assert!(m.solve(&b).is_err());
    }

    #[test]
    fn test_solve_dimension_mismatch() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let b = vec![1.0, 2.0, 3.0]; // Wrong size
        assert!(m.solve(&b).is_err());
    }

    #[test]
    fn test_fill_zero() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(3);
        m.fill_zero();
        assert_eq!((m.nrows(), m.ncols()), (3, 3));
        for i in 0..3 {
            for j in 0..3 {
                assert_close(m.get(i, j), 0.0, 1e-15);
            }
        }
    }

    #[test]
    fn test_scale_by_zero() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(2);
        m.scale(0.0);
        for i in 0..2 {
            for j in 0..2 {
                assert_close(m.get(i, j), 0.0, 1e-15);
            }
        }
    }

    #[test]
    fn test_scale_by_negative() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(2);
        m.scale(-1.0);
        assert_close(m.get(0, 0), -1.0, 1e-15);
        assert_close(m.get(1, 1), -1.0, 1e-15);
    }

    #[test]
    fn test_mul_vec_with_zeros() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let x = [100.0, 200.0];
        let mut y = [999.0, 999.0];
        m.mul_vec(&x, &mut y);
        // The previous contents of y must be overwritten, not accumulated into.
        assert_close(y[0], 0.0, 1e-15);
        assert_close(y[1], 0.0, 1e-15);
    }

    #[test]
    fn test_norm_inf() {
        let m = mat2(-1.0, 2.0, 3.0, -4.0);

        // Row 0: |−1| + |2| = 3
        // Row 1: |3| + |−4| = 7
        // Max = 7
        assert_close(m.norm_inf(), 7.0, 1e-10);
    }

    #[test]
    fn test_norms_of_empty_matrix() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(0, 0);
        assert_close(m.norm_frobenius(), 0.0, 1e-15);
        assert_close(m.norm_inf(), 0.0, 1e-15);
    }

    #[test]
    fn test_zeros_large() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(100, 100);
        assert_eq!(m.nrows(), 100);
        assert_eq!(m.ncols(), 100);
    }

    #[test]
    fn test_from_col_major() {
        // Column-major: col0 = [1, 3], col1 = [2, 4]
        let data = [1.0, 3.0, 2.0, 4.0];
        let m: DenseMatrix<f64> = DenseMatrix::from_col_major(2, 2, &data);

        assert_close(m.get(0, 0), 1.0, 1e-15);
        assert_close(m.get(1, 0), 3.0, 1e-15);
        assert_close(m.get(0, 1), 2.0, 1e-15);
        assert_close(m.get(1, 1), 4.0, 1e-15);
    }

    #[test]
    fn test_to_row_major() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(0, 2, 3.0);
        m.set(1, 0, 4.0);
        m.set(1, 1, 5.0);
        m.set(1, 2, 6.0);

        assert_eq!(m.to_row_major(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_solve_ill_conditioned() {
        // Hilbert-like (mildly ill-conditioned) system. b was built as A · [1, 1],
        // so the solution must be close to [1, 1], not merely "some Ok value".
        let m = mat2(1.0, 0.5, 0.5, 0.333333333333);
        let b = vec![1.5, 0.833333333333];
        let x = m.solve(&b).expect("non-singular system must solve");

        assert_close(x[0], 1.0, 1e-8);
        assert_close(x[1], 1.0, 1e-8);

        // The residual A·x − b must also be tiny.
        let mut ax = [0.0; 2];
        m.mul_vec(&x, &mut ax);
        assert_close(ax[0], b[0], 1e-12);
        assert_close(ax[1], b[1], 1e-12);
    }

    // ============================================================================
    // f32 Scalar Tests
    // ============================================================================

    #[test]
    fn test_f32_solve() {
        let mut m: DenseMatrix<f32> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 2.0);
        m.set(0, 1, 0.0);
        m.set(1, 0, 0.0);
        m.set(1, 1, 3.0);

        let b = vec![4.0f32, 9.0f32];
        let x = m.solve(&b).unwrap();

        assert_close(f64::from(x[0]), 2.0, 1e-5);
        assert_close(f64::from(x[1]), 3.0, 1e-5);
    }

    #[test]
    fn test_f32_identity() {
        let m: DenseMatrix<f32> = DenseMatrix::identity(3);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_close(f64::from(m.get(i, j)), expected, 1e-7);
            }
        }
    }

    #[test]
    fn test_f32_mul_vec() {
        let mut m: DenseMatrix<f32> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let x = [1.0f32, 2.0f32];
        let mut y = [0.0f32, 0.0f32];
        m.mul_vec(&x, &mut y);

        assert_close(f64::from(y[0]), 5.0, 1e-5);
        assert_close(f64::from(y[1]), 11.0, 1e-5);
    }

    // ============================================================================
    // Error Type Tests
    // ============================================================================

    #[test]
    fn test_solve_non_square_returns_not_square_error() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        let b = vec![1.0, 2.0];
        match m.solve(&b) {
            Err(LinalgError::NotSquare { nrows: 2, ncols: 3 }) => {}
            other => panic!("Expected NotSquare error, got {:?}", other),
        }
    }

    #[test]
    fn test_solve_dimension_mismatch_returns_typed_error() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let b = vec![1.0, 2.0, 3.0];
        match m.solve(&b) {
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1),
            }) => {}
            other => panic!("Expected DimensionMismatch error, got {:?}", other),
        }
    }
}