differential_equations/linalg/matrix/
index.rs

1//! Indexing and display for `Matrix`.
2
3use core::fmt::{self, Display, Write as _};
4use core::ops::{Index, IndexMut};
5
6use crate::traits::Real;
7
8use super::base::{Matrix, MatrixStorage};
9
10/// 2D indexing by (i, j), read-only.
11impl<T: Real> Index<(usize, usize)> for Matrix<T> {
12    type Output = T;
13
14    fn index(&self, index: (usize, usize)) -> &Self::Output {
15        let (i, j) = index;
16        assert!(i < self.n && j < self.m, "Index out of bounds");
17        match &self.storage {
18            MatrixStorage::Identity => {
19                if i == j {
20                    &self.data[0]
21                } else {
22                    &self.data[1]
23                }
24            }
25            MatrixStorage::Full => &self.data[i * self.m + j],
26            MatrixStorage::Banded { ml, mu, zero } => {
27                let k = i as isize - j as isize;
28                if k < -(*mu as isize) || k > *ml as isize {
29                    zero
30                } else {
31                    let row = (k + *mu as isize) as usize;
32                    &self.data[row * self.m + j]
33                }
34            }
35        }
36    }
37}
38
39/// 2D indexing by (i, j), mutable (where supported).
40impl<T: Real> IndexMut<(usize, usize)> for Matrix<T> {
41    fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
42        assert!(i < self.n && j < self.m, "Index out of bounds");
43        match &mut self.storage {
44            MatrixStorage::Full => &mut self.data[i * self.m + j],
45            MatrixStorage::Identity => {
46                panic!(
47                    "cannot mutate Identity matrix via indexing; convert explicitly to Full first"
48                )
49            }
50            MatrixStorage::Banded { ml, mu, .. } => {
51                let k = i as isize - j as isize;
52                if k >= -(*mu as isize) && k <= *ml as isize {
53                    let row = (k + *mu as isize) as usize;
54                    &mut self.data[row * self.m + j]
55                } else {
56                    panic!(
57                        "attempted to write outside band of Banded matrix: i-j={} not in [-mu, ml] = [-{}, {}]",
58                        k, mu, ml
59                    )
60                }
61            }
62        }
63    }
64}
65
66impl<T> Display for Matrix<T>
67where
68    T: Real + Display,
69{
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        let (nr, nc) = (self.n, self.m);
72        for i in 0..nr {
73            f.write_str("[")?;
74            for j in 0..nc {
75                if j > 0 {
76                    f.write_str(" ")?;
77                }
78                write!(f, "{}", self[(i, j)])?;
79            }
80            f.write_str("]")?;
81            if i + 1 < nr {
82                f.write_char('\n')?;
83            }
84        }
85        Ok(())
86    }
87}
88
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;

    #[test]
    #[should_panic(expected = "cannot mutate Identity matrix via indexing")]
    fn identity_panics_on_write() {
        let mut m: Matrix<f64> = Matrix::identity(3);
        m[(2, 0)] = 5.0; // should panic
    }

    #[test]
    fn identity_reads_one_on_diagonal_zero_elsewhere() {
        let m: Matrix<f64> = Matrix::identity(3);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(m[(i, j)], expected, "entry ({i}, {j})");
            }
        }
    }

    #[test]
    fn banded_inband_write_keeps_banded() {
        let mut b: Matrix<f64> = Matrix::banded(5, 1, 2);
        b[(2, 2)] = 10.0; // main diag
        b[(1, 2)] = 5.0; // upper-1
        b[(3, 2)] = 7.0; // lower-1
        // Stays valid and values are accessible.
        assert_eq!(b[(2, 2)], 10.0);
        assert_eq!(b[(1, 2)], 5.0);
        assert_eq!(b[(3, 2)], 7.0);
        // Outside the band reads as zero, both above and below the diagonal.
        assert_eq!(b[(0, 4)], 0.0);
        assert_eq!(b[(4, 0)], 0.0);
    }

    #[test]
    #[should_panic(expected = "attempted to write outside band of Banded matrix")]
    fn banded_out_of_band_write_panics() {
        let mut b: Matrix<f64> = Matrix::banded(3, 0, 0); // only main diagonal
        b[(0, 2)] = 7.0; // out-of-band -> panic
    }

    #[test]
    fn full_index_read_write() {
        let mut m: Matrix<f64> = Matrix::zeros(2, 2);
        m[(0, 0)] = 1.0;
        m[(0, 1)] = 2.0;
        m[(1, 0)] = 3.0;
        m[(1, 1)] = 4.0;
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(0, 1)], 2.0);
        assert_eq!(m[(1, 0)], 3.0);
        assert_eq!(m[(1, 1)], 4.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn full_row_out_of_bounds_read_panics() {
        let m: Matrix<f64> = Matrix::zeros(2, 2);
        let _value: f64 = m[(2, 0)];
    }

    // Without the `j < m` check, row-major addressing would silently alias
    // (0, 2) onto (1, 0) instead of panicking.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn full_column_out_of_bounds_read_panics() {
        let m: Matrix<f64> = Matrix::zeros(2, 2);
        let _value: f64 = m[(0, 2)];
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn full_column_out_of_bounds_write_panics() {
        let mut m: Matrix<f64> = Matrix::zeros(2, 2);
        m[(0, 2)] = 1.0;
    }

    #[test]
    fn display_prints_dense_matrix() {
        let m: Matrix<f64> = Matrix::identity(3);
        let s = format!("{}", m);
        // Exact match pins row order, separators and the absence of a
        // trailing newline.
        assert_eq!(s, "[1 0 0]\n[0 1 0]\n[0 0 1]");
    }
}