differential_equations/linalg/matrix/
base.rs

1//! Core matrix type, storage enum, and constructors.
2
3use crate::traits::Real;
4
5/// Matrix storage layout.
6#[derive(Clone, Debug)]
7pub enum MatrixStorage<T: Real> {
8    /// Identity matrix (implicit). `data` stores [one, zero] to satisfy indexing by reference.
9    Identity,
10    /// Dense row-major matrix (nrows*ncols entries).
11    Full,
12    /// Banded matrix with lower (ml) and upper (mu) bandwidth.
13    /// Compact diagonal storage with shape (ml+mu+1, ncols), row-major per diagonal.
14    /// Off-band reads return `zero`.
15    Banded { ml: usize, mu: usize, zero: T },
16}
17
18/// Generic matrix for linear algebra (typically square in current use).
19#[derive(Clone, Debug)]
20pub struct Matrix<T: Real> {
21    pub nrows: usize,
22    pub ncols: usize,
23    pub data: Vec<T>,
24    pub storage: MatrixStorage<T>,
25}
26
27impl<T: Real> Matrix<T> {
28    /// Identity matrix of size n x n.
29    pub fn identity(n: usize) -> Self {
30        Matrix {
31            nrows: n,
32            ncols: n,
33            // Keep [one, zero] so indexing can return references.
34            data: vec![T::one(), T::zero()],
35            storage: MatrixStorage::Identity,
36        }
37    }
38
39    /// Full matrix from a row-major vector of length n*n.
40    pub fn full(n: usize, data: Vec<T>) -> Self {
41        assert_eq!(data.len(), n * n, "Matrix::full expects data of length n*n");
42        Matrix {
43            nrows: n,
44            ncols: n,
45            data,
46            storage: MatrixStorage::Full,
47        }
48    }
49
50    /// Zero matrix of size n x n.
51    pub fn zeros(n: usize) -> Self {
52        Matrix {
53            nrows: n,
54            ncols: n,
55            data: vec![T::zero(); n * n],
56            storage: MatrixStorage::Full,
57        }
58    }
59
60    /// Zero banded matrix with the given bandwidths.
61    /// For entry (i,j) within the band, index maps to data[i - j + mu, j].
62    pub fn banded(n: usize, ml: usize, mu: usize) -> Self {
63        let rows = ml + mu + 1;
64        let data = vec![T::zero(); rows * n];
65        Matrix {
66            nrows: n,
67            ncols: n,
68            data,
69            storage: MatrixStorage::Banded {
70                ml,
71                mu,
72                zero: T::zero(),
73            },
74        }
75    }
76
77    /// Diagonal matrix from the provided diagonal entries (ml=mu=0).
78    pub fn diagonal(diag: Vec<T>) -> Self {
79        let n = diag.len();
80        // With ml=mu=0, storage is (1,n), so `diag` maps directly to row 0.
81        Matrix {
82            nrows: n,
83            ncols: n,
84            data: diag,
85            storage: MatrixStorage::Banded {
86                ml: 0,
87                mu: 0,
88                zero: T::zero(),
89            },
90        }
91    }
92
93    /// Zero lower-triangular matrix (ml = n-1, mu = 0).
94    pub fn lower_triangular(n: usize) -> Self {
95        Matrix::banded(n, n.saturating_sub(1), 0)
96    }
97
98    /// Zero upper-triangular matrix (ml = 0, mu = n-1).
99    pub fn upper_triangular(n: usize) -> Self {
100        Matrix::banded(n, 0, n.saturating_sub(1))
101    }
102
103    /// Dimensions (nrows, ncols).
104    pub fn dims(&self) -> (usize, usize) {
105        (self.nrows, self.ncols)
106    }
107
108    /// Convenience: n for an n x n matrix.
109    pub fn n(&self) -> usize {
110        self.dims().0
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::Matrix;

    /// `Matrix::diagonal` puts its input on the main diagonal and reads zero elsewhere.
    #[test]
    fn diagonal_constructor_sets_diagonal() {
        let entries = [1.0f64, 2.0, 3.0];
        let mat = Matrix::diagonal(entries.to_vec());

        // Main diagonal carries the supplied values, in order.
        for (k, &expected) in entries.iter().enumerate() {
            assert_eq!(mat[(k, k)], expected);
        }

        // Positions off the diagonal read as zero.
        let off_diagonal: [(usize, usize); 2] = [(0, 1), (2, 0)];
        for &position in off_diagonal.iter() {
            assert_eq!(mat[position], 0.0);
        }
    }

    /// Triangular constructors read zero on the side they do not store.
    #[test]
    fn triangular_constructors_shape() {
        // Lower-triangular: an entry above the main diagonal reads zero.
        let lower: Matrix<f64> = Matrix::lower_triangular(4);
        assert_eq!(lower[(0, 3)], 0.0);

        // Upper-triangular: an entry below the main diagonal reads zero.
        let upper: Matrix<f64> = Matrix::upper_triangular(4);
        assert_eq!(upper[(3, 0)], 0.0);
    }
}