use std::collections::HashMap;

use crate::{linalg::LinalgError, traits::Real};
4
5fn coalesce_triplets<T: Real>(
6 n: usize,
7 m: usize,
8 triplets: Vec<(usize, usize, T)>,
9) -> Vec<(usize, usize, T)> {
10 let mut coords: Vec<(usize, usize, T)> = Vec::new();
11 for (row, col, val) in triplets {
12 assert!(row < n && col < m, "Sparse triplet index out of bounds");
13 if let Some(entry) = coords.iter_mut().find(|(r, c, _)| *r == row && *c == col) {
14 entry.2 += val;
15 } else {
16 coords.push((row, col, val));
17 }
18 }
19 coords.retain(|(_, _, v)| *v != T::zero());
20 coords
21}
22
/// Storage layout of a [`Matrix`]; it decides how the matrix's `data` is interpreted.
#[derive(PartialEq, Clone, Debug)]
pub enum MatrixStorage<T: Real> {
    /// Identity matrix: ones on the main diagonal, zeros elsewhere. The dense
    /// values are not stored in `data`.
    Identity,
    /// Dense row-major storage: element `(i, j)` is `data[i * m + j]`.
    Full,
    /// Band storage with `ml` sub-diagonals and `mu` super-diagonals.
    ///
    /// Element `(row, col)` inside the band is `data[(row + mu - col) * m + col]`;
    /// every position outside the band is zero.
    Banded {
        /// Number of sub-diagonals (below the main diagonal) kept in the band.
        ml: usize,
        /// Number of super-diagonals (above the main diagonal) kept in the band.
        mu: usize,
        /// A zero value kept with the layout. NOTE(review): presumably lets
        /// indexing return a reference for positions outside the band — confirm.
        zero: T,
    },
    /// Coordinate-list storage. The constructors leave `data` empty, and the
    /// dense expansion does not read it.
    Sparse {
        /// Stored `(row, col, value)` entries. Unlisted positions are zero, and
        /// entries that share a coordinate add up in the dense expansion.
        coords: Vec<(usize, usize, T)>,
        /// A zero value kept with the layout. NOTE(review): presumably lets
        /// indexing return a reference for unlisted positions — confirm.
        zero: T,
    },
}
40
/// An `n × m` matrix of `T` values whose memory layout is chosen by `storage`.
#[derive(PartialEq, Clone, Debug)]
pub struct Matrix<T: Real> {
    /// Number of rows.
    pub n: usize,
    /// Number of columns.
    pub m: usize,
    /// Backing values; their meaning depends on `storage`.
    pub data: Vec<T>,
    /// How `data` (plus any layout-specific fields) encodes the matrix.
    pub storage: MatrixStorage<T>,
}
49
50impl<T: Real> Matrix<T> {
51 pub fn nrows(&self) -> usize {
53 self.n
54 }
55
56 pub fn ncols(&self) -> usize {
58 self.m
59 }
60
61 pub fn identity(n: usize) -> Self {
63 Matrix {
64 n,
65 m: n,
66 data: vec![T::one(), T::zero()],
68 storage: MatrixStorage::Identity,
69 }
70 }
71
72 pub fn from_vec(n: usize, m: usize, data: Vec<T>) -> Result<Self, LinalgError> {
77 if data.len() != n * m {
78 return Err(LinalgError::BadInput {
79 message: format!(
80 "Incompatible data length: expected {}, got {}",
81 n * m,
82 data.len()
83 ),
84 });
85 }
86 Ok(Matrix {
87 n,
88 m,
89 data,
90 storage: MatrixStorage::Full,
91 })
92 }
93
94 pub fn sparse(n: usize, m: usize) -> Self {
96 Matrix {
97 n,
98 m,
99 data: Vec::new(),
100 storage: MatrixStorage::Sparse {
101 coords: Vec::new(),
102 zero: T::zero(),
103 },
104 }
105 }
106
107 pub fn sparse_from_triplets(n: usize, m: usize, triplets: Vec<(usize, usize, T)>) -> Self {
114 let coords = coalesce_triplets(n, m, triplets);
115 Matrix {
116 n,
117 m,
118 data: Vec::new(),
119 storage: MatrixStorage::Sparse {
120 coords,
121 zero: T::zero(),
122 },
123 }
124 }
125
126 pub fn full(n: usize, m: usize) -> Self {
128 let data = vec![T::zero(); n * m];
129 Matrix {
130 n,
131 m,
132 data,
133 storage: MatrixStorage::Full,
134 }
135 }
136
137 pub fn square(n: usize) -> Self {
139 Matrix {
140 n,
141 m: n,
142 data: Vec::with_capacity(n * n),
143 storage: MatrixStorage::Full,
144 }
145 }
146
147 pub fn zeros(n: usize, m: usize) -> Self {
149 Matrix {
150 n,
151 m,
152 data: vec![T::zero(); n * m],
153 storage: MatrixStorage::Full,
154 }
155 }
156
157 pub fn banded(n: usize, ml: usize, mu: usize) -> Self {
160 let rows = ml + mu + 1;
161 let data = vec![T::zero(); rows * n];
162 Matrix {
163 n,
164 m: n,
165 data,
166 storage: MatrixStorage::Banded {
167 ml,
168 mu,
169 zero: T::zero(),
170 },
171 }
172 }
173
174 pub fn diagonal(diag: Vec<T>) -> Self {
176 let n = diag.len();
177 Matrix {
179 n,
180 m: n,
181 data: diag,
182 storage: MatrixStorage::Banded {
183 ml: 0,
184 mu: 0,
185 zero: T::zero(),
186 },
187 }
188 }
189
190 pub fn lower_triangular(n: usize) -> Self {
192 Matrix::banded(n, n.saturating_sub(1), 0)
193 }
194
195 pub fn upper_triangular(n: usize) -> Self {
197 Matrix::banded(n, 0, n.saturating_sub(1))
198 }
199
200 pub fn dims(&self) -> (usize, usize) {
202 (self.n, self.m)
203 }
204
205 pub fn to_dense_vec(&self) -> Vec<T> {
207 match &self.storage {
208 MatrixStorage::Full => self.data.clone(),
209 MatrixStorage::Identity => {
210 let mut dense = vec![T::zero(); self.n * self.m];
211 for i in 0..self.n.min(self.m) {
212 dense[i * self.m + i] = T::one();
213 }
214 dense
215 }
216 MatrixStorage::Banded { ml, mu, .. } => {
217 let mut dense = vec![T::zero(); self.n * self.m];
218 for col in 0..self.m {
219 for band_row in 0..(*ml + *mu + 1) {
220 let offset = band_row as isize - *mu as isize;
221 let row_signed = col as isize + offset;
222 if row_signed >= 0 && (row_signed as usize) < self.n {
223 let row = row_signed as usize;
224 dense[row * self.m + col] += self.data[band_row * self.m + col];
225 }
226 }
227 }
228 dense
229 }
230 MatrixStorage::Sparse { coords, .. } => {
231 let mut dense = vec![T::zero(); self.n * self.m];
232 for &(row, col, value) in coords {
233 dense[row * self.m + col] += value;
234 }
235 dense
236 }
237 }
238 }
239
240 pub fn make_full(&mut self) {
242 self.data = self.to_dense_vec();
243 self.storage = MatrixStorage::Full;
244 }
245
246 pub fn is_identity(&self) -> bool {
248 if let MatrixStorage::Identity = self.storage {
249 return true;
250 } else if let MatrixStorage::Full = self.storage {
251 for i in 0..self.n {
252 for j in 0..self.m {
253 let expected = if i == j { T::one() } else { T::zero() };
254 if self.data[i * self.m + j] != expected {
255 return false;
256 }
257 }
258 }
259 } else if let MatrixStorage::Banded {
260 ml: _ml,
261 mu: _mu,
262 zero,
263 } = self.storage
264 {
265 for i in 0..self.n {
266 for j in 0..self.m {
267 let expected = if i == j { T::one() } else { zero };
268 if self.data[i * self.m + j] != expected {
269 return false;
270 }
271 }
272 }
273 } else if let MatrixStorage::Sparse { ref coords, .. } = self.storage {
274 let diag_count = self.n.min(self.m);
275 if coords.len() != diag_count {
276 return false;
277 }
278 for &(r, c, v) in coords {
279 if r != c || v != T::one() {
280 return false;
281 }
282 }
283 }
284 true
285 }
286
287 pub fn swap_rows(&mut self, r1: usize, r2: usize) {
290 assert!(r1 < self.n && r2 < self.n, "row index out of bounds");
291 if r1 == r2 {
292 return;
293 }
294 match &mut self.storage {
295 MatrixStorage::Full => {
296 for j in 0..self.m {
297 self.data.swap(r1 * self.m + j, r2 * self.m + j);
298 }
299 }
300 MatrixStorage::Identity => {
301 }
304 MatrixStorage::Banded { ml, mu, .. } => {
305 let mlv = *ml as isize;
308 let muv = *mu as isize;
309 for j in 0..self.m {
310 let k1 = r1 as isize - j as isize;
311 let k2 = r2 as isize - j as isize;
312 let in1 = k1 >= -muv && k1 <= mlv;
313 let in2 = k2 >= -muv && k2 <= mlv;
314 if in1 && in2 {
315 let row1 = (k1 + *mu as isize) as usize;
316 let row2 = (k2 + *mu as isize) as usize;
317 self.data.swap(row1 * self.m + j, row2 * self.m + j);
318 } else if in1 || in2 {
319 if in1 {
322 let row1 = (k1 + *mu as isize) as usize;
323 let idx1 = row1 * self.m + j;
324 self.data[idx1] = T::zero();
325 } else {
326 let row2 = (k2 + *mu as isize) as usize;
327 let idx2 = row2 * self.m + j;
328 self.data[idx2] = T::zero();
329 }
330 }
331 }
332 }
333 MatrixStorage::Sparse { coords, .. } => {
334 for item in coords.iter_mut() {
335 if item.0 == r1 {
336 item.0 = r2;
337 } else if item.0 == r2 {
338 item.0 = r1;
339 }
340 }
341 }
342 }
343 }
344
345 pub fn fill(&mut self, value: T) {
347 match &mut self.storage {
348 MatrixStorage::Identity
349 | MatrixStorage::Banded { .. }
350 | MatrixStorage::Sparse { .. }
351 if value != T::zero() =>
352 {
353 self.data = vec![value; self.n * self.m];
354 self.storage = MatrixStorage::Full;
355 }
356 MatrixStorage::Sparse { coords, zero } => {
357 coords.clear();
358 *zero = T::zero();
359 }
360 _ => self.data.fill(value),
361 }
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::{LinalgError, Matrix, MatrixStorage};

    #[test]
    fn diagonal_constructor_sets_diagonal() {
        // Diagonal entries come from the input vector; off-diagonal ones read as zero.
        let diag = Matrix::diagonal(vec![1.0f64, 2.0, 3.0]);
        for (k, &expected) in [1.0, 2.0, 3.0].iter().enumerate() {
            assert_eq!(diag[(k, k)], expected);
        }
        assert_eq!(diag[(0, 1)], 0.0);
        assert_eq!(diag[(2, 0)], 0.0);
    }

    #[test]
    fn triangular_constructors_shape() {
        // The corner opposite each triangle lies outside the band and reads as zero.
        let lower: Matrix<f64> = Matrix::lower_triangular(4);
        assert_eq!(lower[(0, 3)], 0.0);

        let upper: Matrix<f64> = Matrix::upper_triangular(4);
        assert_eq!(upper[(3, 0)], 0.0);
    }

    #[test]
    fn from_vec_rejects_incompatible_data_length() {
        // A 2 x 3 matrix needs six values; four must be rejected with a descriptive error.
        let outcome = Matrix::<f64>::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0]);
        let expected = Err(LinalgError::BadInput {
            message: String::from("Incompatible data length: expected 6, got 4"),
        });
        assert_eq!(outcome, expected);
    }

    #[test]
    fn sparse_triplets_coalesce_duplicates() {
        // (0, 1) appears twice, so it must hold the sum 2 + 3.
        let triplets = vec![(0, 1, 2.0), (0, 1, 3.0), (1, 2, 4.0)];
        let sparse = Matrix::sparse_from_triplets(2, 3, triplets);
        assert_eq!(sparse[(0, 0)], 0.0);
        assert_eq!(sparse[(0, 1)], 5.0);
        assert_eq!(sparse[(1, 2)], 4.0);
        assert_eq!(sparse.to_dense_vec(), vec![0.0, 5.0, 0.0, 0.0, 0.0, 4.0]);
    }

    #[test]
    fn sparse_index_mut_replaces_coalesced_entry() {
        let mut sparse =
            Matrix::sparse_from_triplets(2, 2, vec![(0, 1, 2.0), (0, 1, 3.0), (1, 1, 4.0)]);
        // Assignment overwrites the merged value instead of adding to it.
        sparse[(0, 1)] = 7.0;
        assert_eq!(sparse[(0, 1)], 7.0);
        // Writing zero to an unlisted position still reads back as zero.
        sparse[(1, 0)] = 0.0;
        assert_eq!(sparse[(1, 0)], 0.0);
    }

    #[test]
    fn sparse_fill_zero_preserves_sparse_storage() {
        let mut sparse = Matrix::sparse_from_triplets(2, 2, vec![(0, 1, 2.0)]);
        sparse.fill(0.0);
        assert_eq!(sparse[(0, 1)], 0.0);
        assert!(matches!(sparse.storage, MatrixStorage::Sparse { .. }));
    }

    #[test]
    fn sparse_storage_carries_zero_reference() {
        let sparse = Matrix::<f64>::sparse(2, 2);
        let zero = match &sparse.storage {
            MatrixStorage::Sparse { zero, .. } => *zero,
            _ => panic!("expected sparse storage"),
        };
        assert_eq!(sparse[(1, 1)], zero);
    }
}
437}