1use crate::Scalar;
8use faer::prelude::*;
9use faer::{ComplexField, Conjugate, Entity, Mat, MatMut, MatRef, SimpleEntity};
10use numra_core::LinalgError;
11
12pub trait Matrix<S: Scalar>: Clone + Sized {
28 fn zeros(rows: usize, cols: usize) -> Self;
30
31 fn identity(n: usize) -> Self;
33
34 fn nrows(&self) -> usize;
36
37 fn ncols(&self) -> usize;
39
40 fn get(&self, i: usize, j: usize) -> S;
42
43 fn set(&mut self, i: usize, j: usize, value: S);
45
46 fn fill_zero(&mut self);
48
49 fn scale(&mut self, alpha: S);
51
52 fn mul_vec(&self, x: &[S], y: &mut [S]);
54
55 fn add_scaled(&mut self, alpha: S, other: &Self);
57
58 fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError>;
60
61 fn is_square(&self) -> bool {
63 self.nrows() == self.ncols()
64 }
65}
66
67#[derive(Clone, Debug)]
69pub struct DenseMatrix<S: Scalar + Entity> {
70 data: Mat<S>,
71}
72
73impl<S: Scalar + Entity> DenseMatrix<S> {
74 pub fn from_faer(mat: Mat<S>) -> Self {
76 Self { data: mat }
77 }
78
79 pub fn as_faer(&self) -> MatRef<'_, S> {
81 self.data.as_ref()
82 }
83
84 pub fn as_faer_mut(&mut self) -> MatMut<'_, S> {
86 self.data.as_mut()
87 }
88
89 pub fn from_row_major(rows: usize, cols: usize, data: &[S]) -> Self {
91 assert_eq!(data.len(), rows * cols);
92 let mut mat = Mat::zeros(rows, cols);
93 for i in 0..rows {
94 for j in 0..cols {
95 mat.write(i, j, data[i * cols + j]);
96 }
97 }
98 Self { data: mat }
99 }
100
101 pub fn from_col_major(rows: usize, cols: usize, data: &[S]) -> Self {
103 assert_eq!(data.len(), rows * cols);
104 let mut mat = Mat::zeros(rows, cols);
105 for j in 0..cols {
106 for i in 0..rows {
107 mat.write(i, j, data[j * rows + i]);
108 }
109 }
110 Self { data: mat }
111 }
112
113 pub fn to_row_major(&self) -> Vec<S> {
115 let (rows, cols) = (self.data.nrows(), self.data.ncols());
116 let mut data = Vec::with_capacity(rows * cols);
117 for i in 0..rows {
118 for j in 0..cols {
119 data.push(self.data.read(i, j));
120 }
121 }
122 data
123 }
124
125 pub fn norm_frobenius(&self) -> S {
127 let mut sum = S::ZERO;
128 for i in 0..self.data.nrows() {
129 for j in 0..self.data.ncols() {
130 let v = self.data.read(i, j);
131 sum += v * v;
132 }
133 }
134 sum.sqrt()
135 }
136
137 pub fn norm_inf(&self) -> S {
139 let mut max_sum = S::ZERO;
140 for i in 0..self.data.nrows() {
141 let mut row_sum = S::ZERO;
142 for j in 0..self.data.ncols() {
143 row_sum += self.data.read(i, j).abs();
144 }
145 max_sum = max_sum.max(row_sum);
146 }
147 max_sum
148 }
149
150 pub fn rows(&self) -> usize {
152 self.data.nrows()
153 }
154
155 pub fn cols(&self) -> usize {
157 self.data.ncols()
158 }
159
160 pub fn is_square(&self) -> bool {
162 self.data.nrows() == self.data.ncols()
163 }
164}
165
166impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Matrix<S>
167 for DenseMatrix<S>
168{
169 fn zeros(rows: usize, cols: usize) -> Self {
170 Self {
171 data: Mat::zeros(rows, cols),
172 }
173 }
174
175 fn identity(n: usize) -> Self {
176 let mut mat = Mat::zeros(n, n);
177 for i in 0..n {
178 mat.write(i, i, S::ONE);
179 }
180 Self { data: mat }
181 }
182
183 fn nrows(&self) -> usize {
184 self.data.nrows()
185 }
186
187 fn ncols(&self) -> usize {
188 self.data.ncols()
189 }
190
191 fn get(&self, i: usize, j: usize) -> S {
192 self.data.read(i, j)
193 }
194
195 fn set(&mut self, i: usize, j: usize, value: S) {
196 self.data.write(i, j, value);
197 }
198
199 fn fill_zero(&mut self) {
200 for i in 0..self.nrows() {
201 for j in 0..self.ncols() {
202 self.data.write(i, j, S::ZERO);
203 }
204 }
205 }
206
207 fn scale(&mut self, alpha: S) {
208 for i in 0..self.nrows() {
209 for j in 0..self.ncols() {
210 let v = self.data.read(i, j);
211 self.data.write(i, j, alpha * v);
212 }
213 }
214 }
215
216 fn mul_vec(&self, x: &[S], y: &mut [S]) {
217 assert_eq!(x.len(), self.ncols());
218 assert_eq!(y.len(), self.nrows());
219
220 for (i, y_i) in y.iter_mut().enumerate().take(self.nrows()) {
221 let mut sum = S::ZERO;
222 for (j, &x_j) in x.iter().enumerate().take(self.ncols()) {
223 sum += self.data.read(i, j) * x_j;
224 }
225 *y_i = sum;
226 }
227 }
228
229 fn add_scaled(&mut self, alpha: S, other: &Self) {
230 assert_eq!(self.nrows(), other.nrows());
231 assert_eq!(self.ncols(), other.ncols());
232
233 for i in 0..self.nrows() {
234 for j in 0..self.ncols() {
235 let v = self.data.read(i, j) + alpha * other.data.read(i, j);
236 self.data.write(i, j, v);
237 }
238 }
239 }
240
241 fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
242 if !self.is_square() {
243 return Err(LinalgError::NotSquare {
244 nrows: self.nrows(),
245 ncols: self.ncols(),
246 });
247 }
248 if b.len() != self.nrows() {
249 return Err(LinalgError::DimensionMismatch {
250 expected: (self.nrows(), 1),
251 actual: (b.len(), 1),
252 });
253 }
254
255 let lu = self.data.as_ref().partial_piv_lu();
257
258 let mut b_mat = Mat::zeros(b.len(), 1);
260 for (i, &val) in b.iter().enumerate() {
261 b_mat.write(i, 0, val);
262 }
263
264 let x_mat = lu.solve(&b_mat);
266
267 let mut x = Vec::with_capacity(b.len());
269 for i in 0..b.len() {
270 x.push(x_mat.read(i, 0));
271 }
272
273 Ok(x)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(3, 4);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        for i in 0..3 {
            for j in 0..4 {
                assert!((m.get(i, j) - 0.0).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn test_identity() {
        let m: DenseMatrix<f64> = DenseMatrix::identity(3);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((m.get(i, j) - expected).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn test_set_get() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        assert!((m.get(0, 0) - 1.0).abs() < 1e-15);
        assert!((m.get(0, 1) - 2.0).abs() < 1e-15);
        assert!((m.get(1, 0) - 3.0).abs() < 1e-15);
        assert!((m.get(1, 1) - 4.0).abs() < 1e-15);
    }

    #[test]
    fn test_mul_vec() {
        // A = [[1, 2], [3, 4]], x = [1, 2]  =>  A x = [5, 11].
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let x = [1.0, 2.0];
        let mut y = [0.0, 0.0];
        m.mul_vec(&x, &mut y);

        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 11.0).abs() < 1e-10);
    }

    #[test]
    fn test_mul_vec_rectangular() {
        // A = [[1, 2, 3], [4, 5, 6]], x = [1, 0, -1]  =>  A x = [-2, -2].
        let m: DenseMatrix<f64> =
            DenseMatrix::from_row_major(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let x = [1.0, 0.0, -1.0];
        let mut y = [0.0; 2];
        m.mul_vec(&x, &mut y);

        assert!((y[0] + 2.0).abs() < 1e-12);
        assert!((y[1] + 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_scale() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(2);
        m.scale(3.0);
        assert!((m.get(0, 0) - 3.0).abs() < 1e-15);
        assert!((m.get(1, 1) - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_add_scaled() {
        // I + 2 I = 3 I; off-diagonal entries must stay zero.
        let mut a: DenseMatrix<f64> = DenseMatrix::identity(2);
        let b: DenseMatrix<f64> = DenseMatrix::identity(2);
        a.add_scaled(2.0, &b);

        assert!((a.get(0, 0) - 3.0).abs() < 1e-15);
        assert!(a.get(0, 1).abs() < 1e-15);
        assert!(a.get(1, 0).abs() < 1e-15);
        assert!((a.get(1, 1) - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_solve_diagonal() {
        // diag(2, 3, 4) x = [1, 2, 3]  =>  x = [1/2, 2/3, 3/4].
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(3, 3);
        m.set(0, 0, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 2, 4.0);

        let b = vec![1.0, 2.0, 3.0];
        let x = m.solve(&b).unwrap();

        assert!((x[0] - 0.5).abs() < 1e-10);
        assert!((x[1] - 2.0 / 3.0).abs() < 1e-10);
        assert!((x[2] - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_solve_general() {
        // [[1, 2], [3, 4]] x = [5, 11]  =>  x = [1, 2].
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let b = vec![5.0, 11.0];
        let x = m.solve(&b).unwrap();

        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_solve_requires_pivoting() {
        // A zero in position (0, 0) forces a row interchange during factorization.
        let m: DenseMatrix<f64> = DenseMatrix::from_row_major(2, 2, &[0.0, 1.0, 1.0, 0.0]);
        let x = m.solve(&[2.0, 3.0]).unwrap();

        assert!((x[0] - 3.0).abs() < 1e-12);
        assert!((x[1] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_from_row_major() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m: DenseMatrix<f64> = DenseMatrix::from_row_major(2, 3, &data);

        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        assert!((m.get(0, 0) - 1.0).abs() < 1e-15);
        assert!((m.get(0, 2) - 3.0).abs() < 1e-15);
        assert!((m.get(1, 0) - 4.0).abs() < 1e-15);
        assert!((m.get(1, 2) - 6.0).abs() < 1e-15);
    }

    #[test]
    #[should_panic]
    fn test_from_row_major_wrong_length_panics() {
        let _m: DenseMatrix<f64> = DenseMatrix::from_row_major(2, 2, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_row_major_round_trip() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m: DenseMatrix<f64> = DenseMatrix::from_row_major(3, 2, &data);
        assert_eq!(m.to_row_major(), data.to_vec());
    }

    #[test]
    fn test_norm_frobenius() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let norm = m.norm_frobenius();
        // sqrt(1 + 4 + 9 + 16) = sqrt(30)
        assert!((norm - 30.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_1x1_matrix() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(1, 1);
        m.set(0, 0, 5.0);
        assert!(m.is_square());
        assert!((m.get(0, 0) - 5.0).abs() < 1e-15);

        let b = vec![10.0];
        let x = m.solve(&b).unwrap();
        assert!((x[0] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_identity_1x1() {
        let m: DenseMatrix<f64> = DenseMatrix::identity(1);
        assert!((m.get(0, 0) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_rectangular_not_square() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(!m.is_square());
    }

    #[test]
    fn test_solve_non_square_error() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        let b = vec![1.0, 2.0];
        let result = m.solve(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_solve_dimension_mismatch() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let b = vec![1.0, 2.0, 3.0];
        let result = m.solve(&b);
        assert!(result.is_err());
    }

    #[test]
    fn test_fill_zero() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(3);
        m.fill_zero();
        for i in 0..3 {
            for j in 0..3 {
                assert!(m.get(i, j).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn test_scale_by_zero() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(2);
        m.scale(0.0);
        for i in 0..2 {
            for j in 0..2 {
                assert!(m.get(i, j).abs() < 1e-15);
            }
        }
    }

    #[test]
    fn test_scale_by_negative() {
        let mut m: DenseMatrix<f64> = DenseMatrix::identity(2);
        m.scale(-1.0);
        assert!((m.get(0, 0) + 1.0).abs() < 1e-15);
        assert!((m.get(1, 1) + 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_mul_vec_with_zeros() {
        // y must be overwritten, not accumulated into.
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let x = [100.0, 200.0];
        let mut y = [999.0, 999.0];
        m.mul_vec(&x, &mut y);
        assert!(y[0].abs() < 1e-15);
        assert!(y[1].abs() < 1e-15);
    }

    #[test]
    fn test_norm_inf() {
        // Row sums of absolute values: |−1| + |2| = 3, |3| + |−4| = 7.
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, -1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, -4.0);

        assert!((m.norm_inf() - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_zeros_large() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(100, 100);
        assert_eq!(m.nrows(), 100);
        assert_eq!(m.ncols(), 100);
    }

    #[test]
    fn test_from_col_major() {
        // Column-major [1, 3, 2, 4] is the matrix [[1, 2], [3, 4]].
        let data = [1.0, 3.0, 2.0, 4.0];
        let m: DenseMatrix<f64> = DenseMatrix::from_col_major(2, 2, &data);

        assert!((m.get(0, 0) - 1.0).abs() < 1e-15);
        assert!((m.get(1, 0) - 3.0).abs() < 1e-15);
        assert!((m.get(0, 1) - 2.0).abs() < 1e-15);
        assert!((m.get(1, 1) - 4.0).abs() < 1e-15);
    }

    #[test]
    fn test_to_row_major() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(0, 2, 3.0);
        m.set(1, 0, 4.0);
        m.set(1, 1, 5.0);
        m.set(1, 2, 6.0);

        let data = m.to_row_major();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_solve_ill_conditioned() {
        // Near-Hilbert 2x2 system whose exact solution is x = [1, 1].
        // Assert the actual solution, not merely that a value was returned.
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 0.5);
        m.set(1, 0, 0.5);
        m.set(1, 1, 0.333333333333);

        let b = vec![1.5, 0.833333333333];
        let x = m.solve(&b).expect("system is nonsingular");
        assert!((x[0] - 1.0).abs() < 1e-8);
        assert!((x[1] - 1.0).abs() < 1e-8);
    }

    #[test]
    fn test_f32_solve() {
        let mut m: DenseMatrix<f32> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 2.0);
        m.set(0, 1, 0.0);
        m.set(1, 0, 0.0);
        m.set(1, 1, 3.0);

        let b = vec![4.0f32, 9.0f32];
        let x = m.solve(&b).unwrap();

        assert!((x[0] - 2.0).abs() < 1e-5);
        assert!((x[1] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_f32_identity() {
        let m: DenseMatrix<f32> = DenseMatrix::identity(3);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0f32 } else { 0.0f32 };
                assert!((m.get(i, j) - expected).abs() < 1e-7);
            }
        }
    }

    #[test]
    fn test_f32_mul_vec() {
        let mut m: DenseMatrix<f32> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let x = [1.0f32, 2.0f32];
        let mut y = [0.0f32, 0.0f32];
        m.mul_vec(&x, &mut y);

        assert!((y[0] - 5.0).abs() < 1e-5);
        assert!((y[1] - 11.0).abs() < 1e-5);
    }

    #[test]
    fn test_solve_non_square_returns_not_square_error() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        let b = vec![1.0, 2.0];
        match m.solve(&b) {
            Err(LinalgError::NotSquare { nrows: 2, ncols: 3 }) => {}
            other => panic!("Expected NotSquare error, got {:?}", other),
        }
    }

    #[test]
    fn test_solve_dimension_mismatch_returns_typed_error() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        let b = vec![1.0, 2.0, 3.0];
        match m.solve(&b) {
            Err(LinalgError::DimensionMismatch { .. }) => {}
            other => panic!("Expected DimensionMismatch error, got {:?}", other),
        }
    }
}