differential_equations/linalg/matrix/
base.rs1use crate::traits::Real;
4
5#[derive(PartialEq, Clone, Debug)]
7pub enum MatrixStorage<T: Real> {
8 Identity,
10 Full,
12 Banded { ml: usize, mu: usize, zero: T },
16}
17
18#[derive(PartialEq, Clone, Debug)]
20pub struct Matrix<T: Real> {
21 pub n: usize,
22 pub m: usize,
23 pub data: Vec<T>,
24 pub storage: MatrixStorage<T>,
25}
26
27impl<T: Real> Matrix<T> {
28 pub fn nrows(&self) -> usize {
30 self.n
31 }
32
33 pub fn ncols(&self) -> usize {
35 self.m
36 }
37
38 pub fn identity(n: usize) -> Self {
40 Matrix {
41 n,
42 m: n,
43 data: vec![T::one(), T::zero()],
45 storage: MatrixStorage::Identity,
46 }
47 }
48
49 pub fn from_vec(n: usize, m: usize, data: Vec<T>) -> Self {
51 assert_eq!(data.len(), n * m, "Incompatible data length");
52 Matrix {
53 n,
54 m,
55 data,
56 storage: MatrixStorage::Full,
57 }
58 }
59
60 pub fn full(n: usize, m: usize) -> Self {
62 let data = vec![T::zero(); n * m];
63 Matrix {
64 n,
65 m,
66 data,
67 storage: MatrixStorage::Full,
68 }
69 }
70
71 pub fn square(n: usize) -> Self {
73 Matrix {
74 n,
75 m: n,
76 data: Vec::with_capacity(n * n),
77 storage: MatrixStorage::Full,
78 }
79 }
80
81 pub fn zeros(n: usize, m: usize) -> Self {
83 Matrix {
84 n,
85 m,
86 data: vec![T::zero(); n * m],
87 storage: MatrixStorage::Full,
88 }
89 }
90
91 pub fn banded(n: usize, ml: usize, mu: usize) -> Self {
94 let rows = ml + mu + 1;
95 let data = vec![T::zero(); rows * n];
96 Matrix {
97 n,
98 m: n,
99 data,
100 storage: MatrixStorage::Banded {
101 ml,
102 mu,
103 zero: T::zero(),
104 },
105 }
106 }
107
108 pub fn diagonal(diag: Vec<T>) -> Self {
110 let n = diag.len();
111 Matrix {
113 n,
114 m: n,
115 data: diag,
116 storage: MatrixStorage::Banded {
117 ml: 0,
118 mu: 0,
119 zero: T::zero(),
120 },
121 }
122 }
123
124 pub fn lower_triangular(n: usize) -> Self {
126 Matrix::banded(n, n.saturating_sub(1), 0)
127 }
128
129 pub fn upper_triangular(n: usize) -> Self {
131 Matrix::banded(n, 0, n.saturating_sub(1))
132 }
133
134 pub fn dims(&self) -> (usize, usize) {
136 (self.n, self.m)
137 }
138
139 pub fn is_identity(&self) -> bool {
141 if let MatrixStorage::Identity = self.storage {
142 return true;
143 } else if let MatrixStorage::Full = self.storage {
144 for i in 0..self.n {
145 for j in 0..self.m {
146 if i == j && self.data[i * self.m + j] != T::one() {
147 return false;
148 } else if i != j && self.data[i * self.m + j] != T::zero() {
149 return false;
150 }
151 }
152 }
153 } else if let MatrixStorage::Banded {
154 ml: _ml,
155 mu: _mu,
156 zero,
157 } = self.storage
158 {
159 for i in 0..self.n {
160 for j in 0..self.m {
161 if i == j && self.data[i * self.m + j] != T::one() {
162 return false;
163 } else if i != j && self.data[i * self.m + j] != zero {
164 return false;
165 }
166 }
167 }
168 }
169 true
170 }
171
172 pub fn swap_rows(&mut self, r1: usize, r2: usize) {
175 assert!(r1 < self.n && r2 < self.n, "row index out of bounds");
176 if r1 == r2 {
177 return;
178 }
179 match &mut self.storage {
180 MatrixStorage::Full => {
181 for j in 0..self.m {
182 self.data.swap(r1 * self.m + j, r2 * self.m + j);
183 }
184 }
185 MatrixStorage::Identity => {
186 }
189 MatrixStorage::Banded { ml, mu, .. } => {
190 let mlv = *ml as isize;
193 let muv = *mu as isize;
194 for j in 0..self.m {
195 let k1 = r1 as isize - j as isize;
196 let k2 = r2 as isize - j as isize;
197 let in1 = k1 >= -muv && k1 <= mlv;
198 let in2 = k2 >= -muv && k2 <= mlv;
199 if in1 && in2 {
200 let row1 = (k1 + *mu as isize) as usize;
201 let row2 = (k2 + *mu as isize) as usize;
202 self.data.swap(row1 * self.m + j, row2 * self.m + j);
203 } else if in1 || in2 {
204 if in1 {
207 let row1 = (k1 + *mu as isize) as usize;
208 let idx1 = row1 * self.m + j;
209 self.data[idx1] = T::zero();
210 } else {
211 let row2 = (k2 + *mu as isize) as usize;
212 let idx2 = row2 * self.m + j;
213 self.data[idx2] = T::zero();
214 }
215 }
216 }
217 }
218 }
219 }
220
221 pub fn fill(&mut self, value: T) {
223 self.data.fill(value);
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::Matrix;

    /// `diagonal` places its argument on the main diagonal and leaves zeros elsewhere.
    #[test]
    fn diagonal_constructor_sets_diagonal() {
        let diag = [1.0f64, 2.0, 3.0];
        let mat = Matrix::diagonal(diag.to_vec());
        for (k, &want) in diag.iter().enumerate() {
            assert_eq!(mat[(k, k)], want);
        }
        for &(row, col) in &[(0, 1), (2, 0)] {
            assert_eq!(mat[(row, col)], 0.0);
        }
    }

    /// The corner opposite each triangle's band must read as zero.
    #[test]
    fn triangular_constructors_shape() {
        let lower = Matrix::<f64>::lower_triangular(4);
        assert_eq!(lower[(0, 3)], 0.0);

        let upper = Matrix::<f64>::upper_triangular(4);
        assert_eq!(upper[(3, 0)], 0.0);
    }
}