differential_equations/linalg/matrix/base.rs

//! Core matrix type, storage enum, and constructors.
2
3use crate::traits::Real;
4
5/// Matrix storage layout.
6#[derive(PartialEq, Clone, Debug)]
7pub enum MatrixStorage<T: Real> {
8    /// Identity matrix (implicit). `data` stores [one, zero] to satisfy indexing by reference.
9    Identity,
10    /// Dense row-major matrix (nrows*ncols entries).
11    Full,
12    /// Banded matrix with lower (ml) and upper (mu) bandwidth.
13    /// Compact diagonal storage with shape (ml+mu+1, ncols), row-major per diagonal.
14    /// Off-band reads return `zero`.
15    Banded { ml: usize, mu: usize, zero: T },
16}
17
18/// Generic matrix for linear algebra (typically square in current use).
19#[derive(PartialEq, Clone, Debug)]
20pub struct Matrix<T: Real> {
21    pub n: usize,
22    pub m: usize,
23    pub data: Vec<T>,
24    pub storage: MatrixStorage<T>,
25}
26
27impl<T: Real> Matrix<T> {
28    /// Number of rows.
29    pub fn nrows(&self) -> usize {
30        self.n
31    }
32
33    /// Number of columns.
34    pub fn ncols(&self) -> usize {
35        self.m
36    }
37
38    /// Identity matrix of size n x n.
39    pub fn identity(n: usize) -> Self {
40        Matrix {
41            n,
42            m: n,
43            // Keep [one, zero] so indexing can return references.
44            data: vec![T::one(), T::zero()],
45            storage: MatrixStorage::Identity,
46        }
47    }
48
49    /// Creates a matrix from a vector.
50    pub fn from_vec(n: usize, m: usize, data: Vec<T>) -> Self {
51        assert_eq!(data.len(), n * m, "Incompatible data length");
52        Matrix {
53            n,
54            m,
55            data,
56            storage: MatrixStorage::Full,
57        }
58    }
59
60    /// Full matrix from a row-major vector of length n*m.
61    pub fn full(n: usize, m: usize) -> Self {
62        let data = vec![T::zero(); n * m];
63        Matrix {
64            n,
65            m,
66            data,
67            storage: MatrixStorage::Full,
68        }
69    }
70
71    /// Square matrix of size n x n.
72    pub fn square(n: usize) -> Self {
73        Matrix {
74            n,
75            m: n,
76            data: Vec::with_capacity(n * n),
77            storage: MatrixStorage::Full,
78        }
79    }
80
81    /// Zero matrix of size n x m.
82    pub fn zeros(n: usize, m: usize) -> Self {
83        Matrix {
84            n,
85            m,
86            data: vec![T::zero(); n * m],
87            storage: MatrixStorage::Full,
88        }
89    }
90
91    /// Zero banded matrix with the given bandwidths.
92    /// For entry (i,j) within the band, index maps to data[i - j + mu, j].
93    pub fn banded(n: usize, ml: usize, mu: usize) -> Self {
94        let rows = ml + mu + 1;
95        let data = vec![T::zero(); rows * n];
96        Matrix {
97            n,
98            m: n,
99            data,
100            storage: MatrixStorage::Banded {
101                ml,
102                mu,
103                zero: T::zero(),
104            },
105        }
106    }
107
108    /// Diagonal matrix from the provided diagonal entries (ml=mu=0).
109    pub fn diagonal(diag: Vec<T>) -> Self {
110        let n = diag.len();
111        // With ml=mu=0, storage is (1,n), so `diag` maps directly to row 0.
112        Matrix {
113            n,
114            m: n,
115            data: diag,
116            storage: MatrixStorage::Banded {
117                ml: 0,
118                mu: 0,
119                zero: T::zero(),
120            },
121        }
122    }
123
124    /// Zero lower-triangular matrix (ml = n-1, mu = 0).
125    pub fn lower_triangular(n: usize) -> Self {
126        Matrix::banded(n, n.saturating_sub(1), 0)
127    }
128
129    /// Zero upper-triangular matrix (ml = 0, mu = n-1).
130    pub fn upper_triangular(n: usize) -> Self {
131        Matrix::banded(n, 0, n.saturating_sub(1))
132    }
133
134    /// Dimensions (nrows, ncols).
135    pub fn dims(&self) -> (usize, usize) {
136        (self.n, self.m)
137    }
138
139    /// Checks if the matrix is an identity matrix.
140    pub fn is_identity(&self) -> bool {
141        if let MatrixStorage::Identity = self.storage {
142            return true;
143        } else if let MatrixStorage::Full = self.storage {
144            for i in 0..self.n {
145                for j in 0..self.m {
146                    if i == j && self.data[i * self.m + j] != T::one() {
147                        return false;
148                    } else if i != j && self.data[i * self.m + j] != T::zero() {
149                        return false;
150                    }
151                }
152            }
153        } else if let MatrixStorage::Banded {
154            ml: _ml,
155            mu: _mu,
156            zero,
157        } = self.storage
158        {
159            for i in 0..self.n {
160                for j in 0..self.m {
161                    if i == j && self.data[i * self.m + j] != T::one() {
162                        return false;
163                    } else if i != j && self.data[i * self.m + j] != zero {
164                        return false;
165                    }
166                }
167            }
168        }
169        true
170    }
171
172    /// Swap two rows in-place for Full storage. For Banded storage, performs a logical swap
173    /// of accessible entries within the band; for Identity, no-op unless swapping equal indices.
174    pub fn swap_rows(&mut self, r1: usize, r2: usize) {
175        assert!(r1 < self.n && r2 < self.n, "row index out of bounds");
176        if r1 == r2 {
177            return;
178        }
179        match &mut self.storage {
180            MatrixStorage::Full => {
181                for j in 0..self.m {
182                    self.data.swap(r1 * self.m + j, r2 * self.m + j);
183                }
184            }
185            MatrixStorage::Identity => {
186                // Identity is stored as [one, zero]; swapping has no effect on implicit structure.
187                // Clients should not attempt to permute Identity rows; we ignore to keep API simple.
188            }
189            MatrixStorage::Banded { ml, mu, .. } => {
190                // Only swap entries that are actually stored (within band).
191                // For each column j, if (r1,j) and/or (r2,j) are in band, swap.
192                let mlv = *ml as isize;
193                let muv = *mu as isize;
194                for j in 0..self.m {
195                    let k1 = r1 as isize - j as isize;
196                    let k2 = r2 as isize - j as isize;
197                    let in1 = k1 >= -muv && k1 <= mlv;
198                    let in2 = k2 >= -muv && k2 <= mlv;
199                    if in1 && in2 {
200                        let row1 = (k1 + *mu as isize) as usize;
201                        let row2 = (k2 + *mu as isize) as usize;
202                        self.data.swap(row1 * self.m + j, row2 * self.m + j);
203                    } else if in1 || in2 {
204                        // One entry is implicit zero; swapping sets stored one to zero and vice versa
205                        // This best-effort maintains logical swap within band footprint.
206                        if in1 {
207                            let row1 = (k1 + *mu as isize) as usize;
208                            let idx1 = row1 * self.m + j;
209                            self.data[idx1] = T::zero();
210                        } else {
211                            let row2 = (k2 + *mu as isize) as usize;
212                            let idx2 = row2 * self.m + j;
213                            self.data[idx2] = T::zero();
214                        }
215                    }
216                }
217            }
218        }
219    }
220
221    /// Fill the matrix with a constant value.
222    pub fn fill(&mut self, value: T) {
223        self.data.fill(value);
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::Matrix;

    #[test]
    fn diagonal_constructor_sets_diagonal() {
        let diag: Matrix<f64> = Matrix::diagonal(vec![1.0, 2.0, 3.0]);
        // Stored diagonal entries come back in order.
        for (k, &expected) in [1.0, 2.0, 3.0].iter().enumerate() {
            assert_eq!(diag[(k, k)], expected);
        }
        // Positions off the main diagonal lie outside the zero-width band.
        for &(row, col) in &[(0, 1), (2, 0)] {
            assert_eq!(diag[(row, col)], 0.0);
        }
    }

    #[test]
    fn triangular_constructors_shape() {
        let lower: Matrix<f64> = Matrix::lower_triangular(4);
        // Strictly above the diagonal is outside a lower-triangular band.
        assert_eq!(lower[(0, 3)], 0.0);

        let upper: Matrix<f64> = Matrix::upper_triangular(4);
        // Strictly below the diagonal is outside an upper-triangular band.
        assert_eq!(upper[(3, 0)], 0.0);
    }
}