Skip to main content

differential_equations/linalg/matrix/
base.rs

1//! Core matrix type, storage enum, and constructors.
2
3use crate::{linalg::LinalgError, traits::Real};
4
5fn coalesce_triplets<T: Real>(
6    n: usize,
7    m: usize,
8    triplets: Vec<(usize, usize, T)>,
9) -> Vec<(usize, usize, T)> {
10    let mut coords: Vec<(usize, usize, T)> = Vec::new();
11    for (row, col, val) in triplets {
12        assert!(row < n && col < m, "Sparse triplet index out of bounds");
13        if let Some(entry) = coords.iter_mut().find(|(r, c, _)| *r == row && *c == col) {
14            entry.2 += val;
15        } else {
16            coords.push((row, col, val));
17        }
18    }
19    coords.retain(|(_, _, v)| *v != T::zero());
20    coords
21}
22
23/// Matrix storage layout.
24#[derive(PartialEq, Clone, Debug)]
25pub enum MatrixStorage<T: Real> {
26    /// Identity matrix (implicit). `data` stores [one, zero] to satisfy indexing by reference.
27    Identity,
28    /// Dense row-major matrix (nrows*ncols entries).
29    Full,
30    /// Banded matrix with lower (ml) and upper (mu) bandwidth.
31    /// Compact diagonal storage with shape (ml+mu+1, ncols), row-major per diagonal.
32    /// Off-band reads return `zero`.
33    Banded { ml: usize, mu: usize, zero: T },
34    /// Sparse matrix representation using coordinate format.
35    Sparse {
36        coords: Vec<(usize, usize, T)>,
37        zero: T,
38    },
39}
40
41/// Generic matrix for linear algebra (typically square in current use).
42#[derive(PartialEq, Clone, Debug)]
43pub struct Matrix<T: Real> {
44    pub n: usize,
45    pub m: usize,
46    pub data: Vec<T>,
47    pub storage: MatrixStorage<T>,
48}
49
50impl<T: Real> Matrix<T> {
51    /// Number of rows.
52    pub fn nrows(&self) -> usize {
53        self.n
54    }
55
56    /// Number of columns.
57    pub fn ncols(&self) -> usize {
58        self.m
59    }
60
61    /// Identity matrix of size n x n.
62    pub fn identity(n: usize) -> Self {
63        Matrix {
64            n,
65            m: n,
66            // Keep [one, zero] so indexing can return references.
67            data: vec![T::one(), T::zero()],
68            storage: MatrixStorage::Identity,
69        }
70    }
71
72    /// Creates a dense row-major matrix from a vector.
73    ///
74    /// # Errors
75    /// Returns [`LinalgError::BadInput`] when `data.len() != n * m`.
76    pub fn from_vec(n: usize, m: usize, data: Vec<T>) -> Result<Self, LinalgError> {
77        if data.len() != n * m {
78            return Err(LinalgError::BadInput {
79                message: format!(
80                    "Incompatible data length: expected {}, got {}",
81                    n * m,
82                    data.len()
83                ),
84            });
85        }
86        Ok(Matrix {
87            n,
88            m,
89            data,
90            storage: MatrixStorage::Full,
91        })
92    }
93
94    /// Empty sparse matrix of size n x m.
95    pub fn sparse(n: usize, m: usize) -> Self {
96        Matrix {
97            n,
98            m,
99            data: Vec::new(),
100            storage: MatrixStorage::Sparse {
101                coords: Vec::new(),
102                zero: T::zero(),
103            },
104        }
105    }
106
107    /// Sparse matrix from coordinate triplets.
108    ///
109    /// Duplicate entries for the same (row, col) are coalesced (summed) at
110    /// construction time, and any entries that sum to zero are dropped.
111    /// This guarantees that no two stored coordinates share the same position,
112    /// so `Index` / `IndexMut` work consistently without a separate `get`/`set`.
113    pub fn sparse_from_triplets(n: usize, m: usize, triplets: Vec<(usize, usize, T)>) -> Self {
114        let coords = coalesce_triplets(n, m, triplets);
115        Matrix {
116            n,
117            m,
118            data: Vec::new(),
119            storage: MatrixStorage::Sparse {
120                coords,
121                zero: T::zero(),
122            },
123        }
124    }
125
126    /// Full matrix from a row-major vector of length n*m.
127    pub fn full(n: usize, m: usize) -> Self {
128        let data = vec![T::zero(); n * m];
129        Matrix {
130            n,
131            m,
132            data,
133            storage: MatrixStorage::Full,
134        }
135    }
136
137    /// Square matrix of size n x n.
138    pub fn square(n: usize) -> Self {
139        Matrix {
140            n,
141            m: n,
142            data: Vec::with_capacity(n * n),
143            storage: MatrixStorage::Full,
144        }
145    }
146
147    /// Zero matrix of size n x m.
148    pub fn zeros(n: usize, m: usize) -> Self {
149        Matrix {
150            n,
151            m,
152            data: vec![T::zero(); n * m],
153            storage: MatrixStorage::Full,
154        }
155    }
156
157    /// Zero banded matrix with the given bandwidths.
158    /// For entry (i,j) within the band, index maps to data[i - j + mu, j].
159    pub fn banded(n: usize, ml: usize, mu: usize) -> Self {
160        let rows = ml + mu + 1;
161        let data = vec![T::zero(); rows * n];
162        Matrix {
163            n,
164            m: n,
165            data,
166            storage: MatrixStorage::Banded {
167                ml,
168                mu,
169                zero: T::zero(),
170            },
171        }
172    }
173
174    /// Diagonal matrix from the provided diagonal entries (ml=mu=0).
175    pub fn diagonal(diag: Vec<T>) -> Self {
176        let n = diag.len();
177        // With ml=mu=0, storage is (1,n), so `diag` maps directly to row 0.
178        Matrix {
179            n,
180            m: n,
181            data: diag,
182            storage: MatrixStorage::Banded {
183                ml: 0,
184                mu: 0,
185                zero: T::zero(),
186            },
187        }
188    }
189
190    /// Zero lower-triangular matrix (ml = n-1, mu = 0).
191    pub fn lower_triangular(n: usize) -> Self {
192        Matrix::banded(n, n.saturating_sub(1), 0)
193    }
194
195    /// Zero upper-triangular matrix (ml = 0, mu = n-1).
196    pub fn upper_triangular(n: usize) -> Self {
197        Matrix::banded(n, 0, n.saturating_sub(1))
198    }
199
200    /// Dimensions (nrows, ncols).
201    pub fn dims(&self) -> (usize, usize) {
202        (self.n, self.m)
203    }
204
205    /// Convert the matrix to dense row-major storage.
206    pub fn to_dense_vec(&self) -> Vec<T> {
207        match &self.storage {
208            MatrixStorage::Full => self.data.clone(),
209            MatrixStorage::Identity => {
210                let mut dense = vec![T::zero(); self.n * self.m];
211                for i in 0..self.n.min(self.m) {
212                    dense[i * self.m + i] = T::one();
213                }
214                dense
215            }
216            MatrixStorage::Banded { ml, mu, .. } => {
217                let mut dense = vec![T::zero(); self.n * self.m];
218                for col in 0..self.m {
219                    for band_row in 0..(*ml + *mu + 1) {
220                        let offset = band_row as isize - *mu as isize;
221                        let row_signed = col as isize + offset;
222                        if row_signed >= 0 && (row_signed as usize) < self.n {
223                            let row = row_signed as usize;
224                            dense[row * self.m + col] += self.data[band_row * self.m + col];
225                        }
226                    }
227                }
228                dense
229            }
230            MatrixStorage::Sparse { coords, .. } => {
231                let mut dense = vec![T::zero(); self.n * self.m];
232                for &(row, col, value) in coords {
233                    dense[row * self.m + col] += value;
234                }
235                dense
236            }
237        }
238    }
239
240    /// Convert the matrix to full storage in place.
241    pub fn make_full(&mut self) {
242        self.data = self.to_dense_vec();
243        self.storage = MatrixStorage::Full;
244    }
245
246    /// Checks if the matrix is an identity matrix.
247    pub fn is_identity(&self) -> bool {
248        if let MatrixStorage::Identity = self.storage {
249            return true;
250        } else if let MatrixStorage::Full = self.storage {
251            for i in 0..self.n {
252                for j in 0..self.m {
253                    let expected = if i == j { T::one() } else { T::zero() };
254                    if self.data[i * self.m + j] != expected {
255                        return false;
256                    }
257                }
258            }
259        } else if let MatrixStorage::Banded {
260            ml: _ml,
261            mu: _mu,
262            zero,
263        } = self.storage
264        {
265            for i in 0..self.n {
266                for j in 0..self.m {
267                    let expected = if i == j { T::one() } else { zero };
268                    if self.data[i * self.m + j] != expected {
269                        return false;
270                    }
271                }
272            }
273        } else if let MatrixStorage::Sparse { ref coords, .. } = self.storage {
274            let diag_count = self.n.min(self.m);
275            if coords.len() != diag_count {
276                return false;
277            }
278            for &(r, c, v) in coords {
279                if r != c || v != T::one() {
280                    return false;
281                }
282            }
283        }
284        true
285    }
286
287    /// Swap two rows in-place for Full storage. For Banded storage, performs a logical swap
288    /// of accessible entries within the band; for Identity, no-op unless swapping equal indices.
289    pub fn swap_rows(&mut self, r1: usize, r2: usize) {
290        assert!(r1 < self.n && r2 < self.n, "row index out of bounds");
291        if r1 == r2 {
292            return;
293        }
294        match &mut self.storage {
295            MatrixStorage::Full => {
296                for j in 0..self.m {
297                    self.data.swap(r1 * self.m + j, r2 * self.m + j);
298                }
299            }
300            MatrixStorage::Identity => {
301                // Identity is stored as [one, zero]; swapping has no effect on implicit structure.
302                // Clients should not attempt to permute Identity rows; we ignore to keep API simple.
303            }
304            MatrixStorage::Banded { ml, mu, .. } => {
305                // Only swap entries that are actually stored (within band).
306                // For each column j, if (r1,j) and/or (r2,j) are in band, swap.
307                let mlv = *ml as isize;
308                let muv = *mu as isize;
309                for j in 0..self.m {
310                    let k1 = r1 as isize - j as isize;
311                    let k2 = r2 as isize - j as isize;
312                    let in1 = k1 >= -muv && k1 <= mlv;
313                    let in2 = k2 >= -muv && k2 <= mlv;
314                    if in1 && in2 {
315                        let row1 = (k1 + *mu as isize) as usize;
316                        let row2 = (k2 + *mu as isize) as usize;
317                        self.data.swap(row1 * self.m + j, row2 * self.m + j);
318                    } else if in1 || in2 {
319                        // One entry is implicit zero; swapping sets stored one to zero and vice versa
320                        // This best-effort maintains logical swap within band footprint.
321                        if in1 {
322                            let row1 = (k1 + *mu as isize) as usize;
323                            let idx1 = row1 * self.m + j;
324                            self.data[idx1] = T::zero();
325                        } else {
326                            let row2 = (k2 + *mu as isize) as usize;
327                            let idx2 = row2 * self.m + j;
328                            self.data[idx2] = T::zero();
329                        }
330                    }
331                }
332            }
333            MatrixStorage::Sparse { coords, .. } => {
334                for item in coords.iter_mut() {
335                    if item.0 == r1 {
336                        item.0 = r2;
337                    } else if item.0 == r2 {
338                        item.0 = r1;
339                    }
340                }
341            }
342        }
343    }
344
345    /// Fill the matrix with a constant value.
346    pub fn fill(&mut self, value: T) {
347        match &mut self.storage {
348            MatrixStorage::Identity
349            | MatrixStorage::Banded { .. }
350            | MatrixStorage::Sparse { .. }
351                if value != T::zero() =>
352            {
353                self.data = vec![value; self.n * self.m];
354                self.storage = MatrixStorage::Full;
355            }
356            MatrixStorage::Sparse { coords, zero } => {
357                coords.clear();
358                *zero = T::zero();
359            }
360            _ => self.data.fill(value),
361        }
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::{LinalgError, Matrix, MatrixStorage};

    #[test]
    fn diagonal_constructor_sets_diagonal() {
        let diag = Matrix::diagonal(vec![1.0f64, 2.0, 3.0]);
        // On-diagonal reads come from the stored entries...
        for (k, &expected) in [1.0f64, 2.0, 3.0].iter().enumerate() {
            assert_eq!(diag[(k, k)], expected);
        }
        // ...while off-diagonal reads fall back to zero.
        for &pos in &[(0usize, 1usize), (2, 0)] {
            assert_eq!(diag[pos], 0.0);
        }
    }

    #[test]
    fn triangular_constructors_shape() {
        let lower: Matrix<f64> = Matrix::lower_triangular(4);
        let upper: Matrix<f64> = Matrix::upper_triangular(4);
        // Lower storage has no super-diagonals; upper storage has no sub-diagonals.
        assert_eq!(lower[(0, 3)], 0.0);
        assert_eq!(upper[(3, 0)], 0.0);
    }

    #[test]
    fn from_vec_rejects_incompatible_data_length() {
        let expected: Result<Matrix<f64>, LinalgError> = Err(LinalgError::BadInput {
            message: "Incompatible data length: expected 6, got 4".to_string(),
        });
        let actual = Matrix::<f64>::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(actual, expected);
    }

    #[test]
    fn sparse_triplets_coalesce_duplicates() {
        let sp = Matrix::sparse_from_triplets(2, 3, vec![(0, 1, 2.0), (0, 1, 3.0), (1, 2, 4.0)]);
        // Both (0, 1) triplets are summed into a single entry.
        assert_eq!(sp[(0, 0)], 0.0);
        assert_eq!(sp[(0, 1)], 5.0);
        assert_eq!(sp[(1, 2)], 4.0);
        let dense_expected = [0.0, 5.0, 0.0, 0.0, 0.0, 4.0].to_vec();
        assert_eq!(sp.to_dense_vec(), dense_expected);
    }

    #[test]
    fn sparse_index_mut_replaces_coalesced_entry() {
        let mut sp =
            Matrix::sparse_from_triplets(2, 2, vec![(0, 1, 2.0), (0, 1, 3.0), (1, 1, 4.0)]);
        // Overwriting a coalesced position replaces the summed value.
        sp[(0, 1)] = 7.0;
        assert_eq!(sp[(0, 1)], 7.0);
        // Writing zero to an absent position keeps it reading as zero.
        sp[(1, 0)] = 0.0;
        assert_eq!(sp[(1, 0)], 0.0);
    }

    #[test]
    fn sparse_fill_zero_preserves_sparse_storage() {
        let mut sp = Matrix::sparse_from_triplets(2, 2, vec![(0, 1, 2.0)]);
        sp.fill(0.0);
        assert_eq!(sp[(0, 1)], 0.0);
        let still_sparse = matches!(sp.storage, MatrixStorage::Sparse { .. });
        assert!(still_sparse);
    }

    #[test]
    fn sparse_storage_carries_zero_reference() {
        let sp = Matrix::<f64>::sparse(2, 2);
        if let MatrixStorage::Sparse { zero, .. } = &sp.storage {
            // An absent position reads back the storage's own zero value.
            assert_eq!(sp[(1, 1)], *zero);
        } else {
            panic!("expected sparse storage");
        }
    }
}