use std::fmt;
use std::ops::{Add, Index, IndexMut, Mul, Sub};

use rayon::prelude::*;

const PARALLEL_THRESHOLD: usize = 16_384;
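/// Heuristic gate for the `*_with_parallel` code paths: splitting work across
/// rayon threads only pays off for larger inputs, so this returns true only
/// when `rows * cols` reaches `PARALLEL_THRESHOLD`.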
fn should_parallelize(rows: usize, cols: usize) -> bool {
    rows.saturating_mul(cols) >= PARALLEL_THRESHOLD
}
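/// Error returned by the fallible arithmetic APIs (`try_add`, `try_sub`,
/// `try_mul`) when the operand shapes are incompatible.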
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixError {
    DimensionMismatch {
        operation: &'static str,
        left: (usize, usize),
        right: (usize, usize),
    },
}
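/// Diagnostics returned alongside a QR least-squares solution: the numerical
/// rank detected from the diagonal of `R`, and a rough conditioning estimate
/// taken as the ratio of the largest to the smallest retained `|R(i, i)|`.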
#[derive(Debug, Clone, Copy)]
pub struct LeastSquaresQrInfo {
    pub rank: usize,
    pub cond_est: f64,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MatrixError::DimensionMismatch {
                operation,
                left,
                right,
            } => write!(
                f,
                "Matrix dimension mismatch for {}: left is {}x{}, right is {}x{}",
                operation, left.0, left.1, right.0, right.1
            ),
        }
    }
}

impl std::error::Error for MatrixError {}
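/// Dense, row-major matrix of `f64`; element `(i, j)` is stored at
/// `data[i * cols + j]`.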
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    data: Vec<f64>,
    rows: usize,
    cols: usize,
}

impl Matrix {
    pub fn new(rows: usize, cols: usize) -> Self {
        Matrix {
            data: vec![0.0; rows * cols],
            rows,
            cols,
        }
    }
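    /// Builds a matrix from a flat row-major buffer, failing if `data.len()`
    /// does not equal `rows * cols`.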
    pub fn from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
        if data.len() != rows * cols {
            return Err("Data length doesn't match dimensions".to_string());
        }
        Ok(Matrix { data, rows, cols })
    }

    pub fn identity(size: usize) -> Self {
        let mut m = Matrix::new(size, size);
        for i in 0..size {
            m[(i, i)] = 1.0;
        }
        m
    }

    pub fn rows(&self) -> usize {
        self.rows
    }

    pub fn cols(&self) -> usize {
        self.cols
    }

    pub fn transpose(&self) -> Matrix {
        self.transpose_with_parallel(false)
    }
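    /// Transpose with an optional rayon path; the parallel branch is used only
    /// when `parallel` is true and the matrix is large enough according to
    /// `should_parallelize`.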
    pub fn transpose_with_parallel(&self, parallel: bool) -> Matrix {
        let mut result = Matrix::new(self.cols, self.rows);
        if parallel && should_parallelize(self.rows, self.cols) {
            result
                .data
                .par_chunks_mut(self.rows)
                .enumerate()
                .for_each(|(j, row)| {
                    for i in 0..self.rows {
                        row[i] = self[(i, j)];
                    }
                });
            return result;
        }

        for i in 0..self.rows {
            for j in 0..self.cols {
                result[(j, i)] = self[(i, j)];
            }
        }
        result
    }
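    /// LU decomposition with partial pivoting, using the default relative
    /// singularity tolerance of `1e-12`.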
    pub fn lu_decomposition(&self) -> Result<(Matrix, Matrix, Vec<usize>), String> {
        self.lu_decomposition_with_tolerance(1e-12)
    }
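    /// Factors a square matrix with partial (row) pivoting, returning the
    /// unit-lower-triangular `L`, the upper-triangular `U`, and the row
    /// permutation as a list of source row indices. A pivot whose magnitude is
    /// at most `singular_relative_epsilon` times the max-norm of the remaining
    /// pivot row is treated as singular.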
    pub fn lu_decomposition_with_tolerance(
        &self,
        singular_relative_epsilon: f64,
    ) -> Result<(Matrix, Matrix, Vec<usize>), String> {
        if self.rows != self.cols {
            return Err("Matrix must be square for LU decomposition".to_string());
        }
        if !singular_relative_epsilon.is_finite() || singular_relative_epsilon < 0.0 {
            return Err("Invalid singular tolerance".to_string());
        }

        let n = self.rows;
        let mut l = Matrix::identity(n);
        let mut u = self.clone();
        let mut pivot = (0..n).collect::<Vec<_>>();

        for k in 0..n {
            let mut max_val = 0.0;
            let mut max_row = k;
            for i in k..n {
                let val = u[(i, k)].abs();
                if val > max_val {
                    max_val = val;
                    max_row = i;
                }
            }

            let mut row_norm: f64 = 0.0;
            for j in k..n {
                row_norm = row_norm.max(u[(max_row, j)].abs());
            }

            if max_val <= singular_relative_epsilon * row_norm {
                return Err("Matrix is singular".to_string());
            }

            if max_row != k {
                pivot.swap(k, max_row);
                for j in 0..n {
                    let temp = u[(k, j)];
                    u[(k, j)] = u[(max_row, j)];
                    u[(max_row, j)] = temp;
                }
                for j in 0..k {
                    let temp = l[(k, j)];
                    l[(k, j)] = l[(max_row, j)];
                    l[(max_row, j)] = temp;
                }
            }

            for i in (k + 1)..n {
                l[(i, k)] = u[(i, k)] / u[(k, k)];
                for j in k..n {
                    u[(i, j)] -= l[(i, k)] * u[(k, j)];
                }
            }
        }

        Ok((l, u, pivot))
    }

    pub fn solve_lu(&self, b: &Matrix) -> Result<Matrix, String> {
        self.solve_lu_with_tolerance(b, 1e-12)
    }
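    /// Solves `A * x = b` for square `A` and a single-column `b` using the
    /// pivoted LU factorization followed by forward and backward substitution.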
    pub fn solve_lu_with_tolerance(
        &self,
        b: &Matrix,
        singular_relative_epsilon: f64,
    ) -> Result<Matrix, String> {
        if self.rows != self.cols {
            return Err("Matrix must be square".to_string());
        }
        if b.rows != self.rows || b.cols != 1 {
            return Err("Invalid dimensions for b".to_string());
        }

        let (l, u, pivot) = self.lu_decomposition_with_tolerance(singular_relative_epsilon)?;

        let mut pb = Matrix::new(b.rows, 1);
        for i in 0..b.rows {
            pb[(i, 0)] = b[(pivot[i], 0)];
        }

        let mut y = Matrix::new(self.rows, 1);
        for i in 0..self.rows {
            y[(i, 0)] = pb[(i, 0)];
            for j in 0..i {
                y[(i, 0)] -= l[(i, j)] * y[(j, 0)];
            }
        }

        let mut x = Matrix::new(self.rows, 1);
        for i in (0..self.rows).rev() {
            x[(i, 0)] = y[(i, 0)];
            for j in (i + 1)..self.rows {
                x[(i, 0)] -= u[(i, j)] * x[(j, 0)];
            }
            x[(i, 0)] /= u[(i, i)];
        }

        Ok(x)
    }
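    /// Least-squares solve of `A * x = b` via Householder QR, discarding the
    /// `LeastSquaresQrInfo` diagnostics. Tall and square systems minimize the
    /// residual norm; wide systems return the minimum-norm solution.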
    pub fn solve_least_squares_qr(&self, b: &Matrix) -> Result<Matrix, String> {
        self.solve_least_squares_qr_with_info(b)
            .map(|(solution, _)| solution)
    }

    pub fn solve_least_squares_qr_with_parallel(
        &self,
        b: &Matrix,
        parallel: bool,
    ) -> Result<Matrix, String> {
        self.solve_least_squares_qr_with_info_with_parallel(b, parallel)
            .map(|(solution, _)| solution)
    }

    pub fn solve_least_squares_qr_with_info(
        &self,
        b: &Matrix,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        self.solve_least_squares_qr_with_info_with_parallel(b, false)
    }
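    /// Full entry point for the QR least-squares solver: `b` must be a column
    /// vector with `self.rows` entries. Dispatches to the tall (`m >= n`) or
    /// wide (`m < n`) Householder routine and, when `parallel` is true, lets
    /// the larger reflector updates run on rayon.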
    pub fn solve_least_squares_qr_with_info_with_parallel(
        &self,
        b: &Matrix,
        parallel: bool,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        if b.cols != 1 {
            return Err("Invalid dimensions for b (expected column vector)".to_string());
        }
        if b.rows != self.rows {
            return Err("Invalid dimensions for b".to_string());
        }

        let m = self.rows;
        let n = self.cols;

        if n == 0 {
            return Ok((
                Matrix::new(0, 1),
                LeastSquaresQrInfo {
                    rank: 0,
                    cond_est: f64::INFINITY,
                },
            ));
        }

        if m >= n {
            Self::solve_least_squares_qr_tall_with_info(self, b, parallel)
        } else {
            Self::solve_least_squares_qr_wide_with_info(self, b, parallel)
        }
    }
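    // Householder QR for the tall case (m >= n): `r` is overwritten with R in
    // its upper triangle and the reflector vectors below the diagonal, Q^T * b
    // is accumulated in `qt_b` as the reflectors are generated, and the
    // solution comes from back-substituting R * x = Q^T * b. Q is never formed
    // explicitly.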
    fn solve_least_squares_qr_tall_with_info(
        a: &Matrix,
        b: &Matrix,
        parallel: bool,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        let m = a.rows;
        let n = a.cols;

        let mut r = a.clone();
        let mut qt_b = b.clone();
        let mut taus = Vec::with_capacity(n);

        for k in 0..n {
            let mut col_norm: f64 = 0.0;
            for i in k..m {
                col_norm = col_norm.hypot(r[(i, k)]);
            }

            if col_norm == 0.0 {
                taus.push(0.0);
                continue;
            }

            let x0 = r[(k, k)];
            let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
            let alpha = -sign * col_norm;
            let v0 = x0 - alpha;

            for i in (k + 1)..m {
                r[(i, k)] /= v0;
            }

            let mut v_sq = 1.0;
            for i in (k + 1)..m {
                let vi = r[(i, k)];
                v_sq += vi * vi;
            }
            let tau = 2.0 / v_sq;
            taus.push(tau);

            let cols_left = n.saturating_sub(k + 1);
            if cols_left > 0 {
                let use_parallel = parallel && should_parallelize(m - (k + 1), cols_left);
                let dots: Vec<f64> = if use_parallel {
                    (0..cols_left)
                        .into_par_iter()
                        .map(|offset| {
                            let j = k + 1 + offset;
                            let mut dot = r[(k, j)];
                            for i in (k + 1)..m {
                                dot += r[(i, k)] * r[(i, j)];
                            }
                            dot * tau
                        })
                        .collect()
                } else {
                    let mut dots = Vec::with_capacity(cols_left);
                    for j in (k + 1)..n {
                        let mut dot = r[(k, j)];
                        for i in (k + 1)..m {
                            dot += r[(i, k)] * r[(i, j)];
                        }
                        dots.push(dot * tau);
                    }
                    dots
                };

                for (offset, dot) in dots.iter().enumerate() {
                    let j = k + 1 + offset;
                    r[(k, j)] -= dot;
                }

                if use_parallel {
                    let cols = r.cols;
                    let j_start = k + 1;
                    let k_col = k;
                    r.data[(k + 1) * cols..]
                        .par_chunks_mut(cols)
                        .for_each(|row| {
                            let vik = row[k_col];
                            if vik != 0.0 {
                                for (offset, dot) in dots.iter().enumerate() {
                                    row[j_start + offset] -= vik * dot;
                                }
                            }
                        });
                } else {
                    for i in (k + 1)..m {
                        let vik = r[(i, k)];
                        if vik != 0.0 {
                            for (offset, dot) in dots.iter().enumerate() {
                                let j = k + 1 + offset;
                                r[(i, j)] -= vik * dot;
                            }
                        }
                    }
                }
            }

            let mut dot = qt_b[(k, 0)];
            for i in (k + 1)..m {
                dot += r[(i, k)] * qt_b[(i, 0)];
            }
            dot *= tau;
            qt_b[(k, 0)] -= dot;

            let use_parallel = parallel && should_parallelize(m - (k + 1), 1);
            if use_parallel {
                let cols = r.cols;
                let k_col = k;
                let r_data = &r.data;
                qt_b.data[(k + 1)..]
                    .par_iter_mut()
                    .enumerate()
                    .for_each(|(idx, val)| {
                        let i = k + 1 + idx;
                        let vik = r_data[i * cols + k_col];
                        *val -= vik * dot;
                    });
            } else {
                for i in (k + 1)..m {
                    qt_b[(i, 0)] -= r[(i, k)] * dot;
                }
            }

            r[(k, k)] = alpha;
        }

        let mut max_diag: f64 = 0.0;
        for i in 0..n {
            max_diag = max_diag.max(r[(i, i)].abs());
        }
        let tol = 1e-12 * max_diag.max(1.0);
        let mut rank = 0;
        let mut min_diag = f64::INFINITY;
        for i in 0..n {
            let diag = r[(i, i)].abs();
            if diag > tol {
                rank += 1;
                if diag < min_diag {
                    min_diag = diag;
                }
            }
        }
        let cond_est = if rank == 0 || !min_diag.is_finite() {
            f64::INFINITY
        } else {
            max_diag / min_diag
        };

        let mut x = Matrix::new(n, 1);
        for i in (0..n).rev() {
            let mut sum = qt_b[(i, 0)];
            for j in (i + 1)..n {
                sum -= r[(i, j)] * x[(j, 0)];
            }

            let diag = r[(i, i)];
            if !diag.is_finite() {
                return Err("Least squares solve failed: non-finite diagonal in R".to_string());
            }

            let mut row_norm: f64 = 0.0;
            for j in i..n {
                row_norm = row_norm.max(r[(i, j)].abs());
            }

            if diag.abs() <= 1e-12 * row_norm {
                return Err("Least squares solve failed: matrix is rank deficient".to_string());
            }

            x[(i, 0)] = sum / diag;
        }

        Ok((
            x,
            LeastSquaresQrInfo {
                rank,
                cond_est,
            },
        ))
    }
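    // Minimum-norm solve for the wide case (m < n): Householder QR is applied
    // to `a.transpose()` (equivalently an LQ factorization of `a`), the
    // triangular system is forward-substituted for `y`, and the stored
    // reflectors are then applied to expand `y` into the n-dimensional
    // minimum-norm solution `w`.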
    fn solve_least_squares_qr_wide_with_info(
        a: &Matrix,
        b: &Matrix,
        parallel: bool,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        let m = a.rows;
        let n = a.cols;

        let mut r = a.transpose();
        let mut taus = Vec::with_capacity(m);

        for k in 0..m {
            let mut col_norm: f64 = 0.0;
            for i in k..n {
                col_norm = col_norm.hypot(r[(i, k)]);
            }

            if col_norm == 0.0 {
                taus.push(0.0);
                continue;
            }

            let x0 = r[(k, k)];
            let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
            let alpha = -sign * col_norm;
            let v0 = x0 - alpha;

            for i in (k + 1)..n {
                r[(i, k)] /= v0;
            }

            let mut v_sq = 1.0;
            for i in (k + 1)..n {
                let vi = r[(i, k)];
                v_sq += vi * vi;
            }
            let tau = 2.0 / v_sq;
            taus.push(tau);

            let cols_left = m.saturating_sub(k + 1);
            if cols_left > 0 {
                let use_parallel = parallel && should_parallelize(n - (k + 1), cols_left);
                let dots: Vec<f64> = if use_parallel {
                    (0..cols_left)
                        .into_par_iter()
                        .map(|offset| {
                            let j = k + 1 + offset;
                            let mut dot = r[(k, j)];
                            for i in (k + 1)..n {
                                dot += r[(i, k)] * r[(i, j)];
                            }
                            dot * tau
                        })
                        .collect()
                } else {
                    let mut dots = Vec::with_capacity(cols_left);
                    for j in (k + 1)..m {
                        let mut dot = r[(k, j)];
                        for i in (k + 1)..n {
                            dot += r[(i, k)] * r[(i, j)];
                        }
                        dots.push(dot * tau);
                    }
                    dots
                };

                for (offset, dot) in dots.iter().enumerate() {
                    let j = k + 1 + offset;
                    r[(k, j)] -= dot;
                }

                if use_parallel {
                    let cols = r.cols;
                    let j_start = k + 1;
                    let k_col = k;
                    r.data[(k + 1) * cols..]
                        .par_chunks_mut(cols)
                        .for_each(|row| {
                            let vik = row[k_col];
                            if vik != 0.0 {
                                for (offset, dot) in dots.iter().enumerate() {
                                    row[j_start + offset] -= vik * dot;
                                }
                            }
                        });
                } else {
                    for i in (k + 1)..n {
                        let vik = r[(i, k)];
                        if vik != 0.0 {
                            for (offset, dot) in dots.iter().enumerate() {
                                let j = k + 1 + offset;
                                r[(i, j)] -= vik * dot;
                            }
                        }
                    }
                }
            }

            r[(k, k)] = alpha;
        }

        let mut max_diag: f64 = 0.0;
        for i in 0..m {
            max_diag = max_diag.max(r[(i, i)].abs());
        }
        let tol = 1e-12 * max_diag.max(1.0);
        let mut rank = 0;
        let mut min_diag = f64::INFINITY;
        for i in 0..m {
            let diag = r[(i, i)].abs();
            if diag > tol {
                rank += 1;
                if diag < min_diag {
                    min_diag = diag;
                }
            }
        }
        let cond_est = if rank == 0 || !min_diag.is_finite() {
            f64::INFINITY
        } else {
            max_diag / min_diag
        };

        let mut y = vec![0.0; m];
        for i in 0..m {
            let mut sum = b[(i, 0)];
            for j in 0..i {
                sum -= r[(j, i)] * y[j];
            }

            let diag = r[(i, i)];
            if !diag.is_finite() {
                return Err("Least squares solve failed: non-finite diagonal in R".to_string());
            }

            let mut col_norm: f64 = 0.0;
            for j in 0..=i {
                col_norm = col_norm.max(r[(j, i)].abs());
            }

            if diag.abs() <= 1e-12 * col_norm {
                return Err("Least squares solve failed: matrix is rank deficient".to_string());
            }

            y[i] = sum / diag;
        }

        let mut w = vec![0.0; n];
        w[..m].copy_from_slice(&y[..m]);

        for k in (0..m).rev() {
            let tau = taus[k];
            if tau == 0.0 {
                continue;
            }

            let mut dot = w[k];
            for i in (k + 1)..n {
                dot += r[(i, k)] * w[i];
            }
            dot *= tau;

            w[k] -= dot;

            let use_parallel = parallel && should_parallelize(n - (k + 1), 1);
            if use_parallel {
                let cols = r.cols;
                let k_col = k;
                let r_data = &r.data;
                w[(k + 1)..]
                    .par_iter_mut()
                    .enumerate()
                    .for_each(|(idx, val)| {
                        let i = k + 1 + idx;
                        let vik = r_data[i * cols + k_col];
                        *val -= vik * dot;
                    });
            } else {
                for i in (k + 1)..n {
                    w[i] -= r[(i, k)] * dot;
                }
            }
        }

        Matrix::from_vec(w, n, 1).map(|solution| {
            (
                solution,
                LeastSquaresQrInfo {
                    rank,
                    cond_est,
                },
            )
        })
    }
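    /// Frobenius norm: the square root of the sum of squared entries.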
    pub fn norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    pub fn try_add(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
        if self.rows != rhs.rows || self.cols != rhs.cols {
            return Err(MatrixError::DimensionMismatch {
                operation: "add",
                left: (self.rows, self.cols),
                right: (rhs.rows, rhs.cols),
            });
        }

        let mut result = Matrix::new(self.rows, self.cols);
        for i in 0..self.data.len() {
            result.data[i] = self.data[i] + rhs.data[i];
        }
        Ok(result)
    }

    pub fn try_sub(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
        if self.rows != rhs.rows || self.cols != rhs.cols {
            return Err(MatrixError::DimensionMismatch {
                operation: "sub",
                left: (self.rows, self.cols),
                right: (rhs.rows, rhs.cols),
            });
        }

        let mut result = Matrix::new(self.rows, self.cols);
        for i in 0..self.data.len() {
            result.data[i] = self.data[i] - rhs.data[i];
        }
        Ok(result)
    }

    pub fn try_mul(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
        self.try_mul_with_parallel(rhs, false)
    }
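    /// Matrix product with an optional rayon path. The parallel branch splits
    /// the output by rows and both branches use the same i-k-j loop order, so
    /// the serial and parallel results accumulate in the same order.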
    pub fn try_mul_with_parallel(&self, rhs: &Matrix, parallel: bool) -> Result<Matrix, MatrixError> {
        if self.cols != rhs.rows {
            return Err(MatrixError::DimensionMismatch {
                operation: "mul",
                left: (self.rows, self.cols),
                right: (rhs.rows, rhs.cols),
            });
        }

        let mut result = Matrix::new(self.rows, rhs.cols);
        if parallel && should_parallelize(self.rows, rhs.cols) {
            let rhs_cols = rhs.cols;
            result
                .data
                .par_chunks_mut(rhs_cols)
                .enumerate()
                .for_each(|(i, row)| {
                    for k in 0..self.cols {
                        let a = self[(i, k)];
                        let rhs_row = &rhs.data[k * rhs_cols..(k + 1) * rhs_cols];
                        for j in 0..rhs_cols {
                            row[j] += a * rhs_row[j];
                        }
                    }
                });
            return Ok(result);
        }

        for i in 0..self.rows {
            for k in 0..self.cols {
                let a = self[(i, k)];
                for j in 0..rhs.cols {
                    result[(i, j)] += a * rhs[(k, j)];
                }
            }
        }
        Ok(result)
    }
}

impl Index<(usize, usize)> for Matrix {
    type Output = f64;

    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        &self.data[row * self.cols + col]
    }
}

impl IndexMut<(usize, usize)> for Matrix {
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        &mut self.data[row * self.cols + col]
    }
}

impl Add for &Matrix {
    type Output = Matrix;

    fn add(self, rhs: Self) -> Self::Output {
        self.try_add(rhs).unwrap_or_else(|err| panic!("{}", err))
    }
}

impl Sub for &Matrix {
    type Output = Matrix;

    fn sub(self, rhs: Self) -> Self::Output {
        self.try_sub(rhs).unwrap_or_else(|err| panic!("{}", err))
    }
}

impl Mul for &Matrix {
    type Output = Matrix;

    fn mul(self, rhs: Self) -> Self::Output {
        self.try_mul(rhs).unwrap_or_else(|err| panic!("{}", err))
    }
}

impl Mul<f64> for &Matrix {
    type Output = Matrix;

    fn mul(self, scalar: f64) -> Self::Output {
        let mut result = self.clone();
        for val in &mut result.data {
            *val *= scalar;
        }
        result
    }
}

impl fmt::Display for Matrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for i in 0..self.rows {
            write!(f, "[")?;
            for j in 0..self.cols {
                if j > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{:8.4}", self[(i, j)])?;
            }
            writeln!(f, "]")?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    struct TestRng {
        state: u64,
    }

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_u32(&mut self) -> u32 {
            self.state = self
                .state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            (self.state >> 32) as u32
        }

        fn next_f64(&mut self) -> f64 {
            let v = self.next_u32() as f64 / u32::MAX as f64;
            2.0 * v - 1.0
        }
    }

    fn random_matrix(rows: usize, cols: usize, rng: &mut TestRng) -> Matrix {
        let mut m = Matrix::new(rows, cols);
        for i in 0..rows {
            for j in 0..cols {
                m[(i, j)] = rng.next_f64();
            }
        }
        m
    }

    fn assert_matrix_close(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!(a.rows(), b.rows());
        assert_eq!(a.cols(), b.cols());
        for i in 0..a.rows() {
            for j in 0..a.cols() {
                let diff = (a[(i, j)] - b[(i, j)]).abs();
                assert!(diff <= tol, "mismatch at ({i},{j}): {diff}");
            }
        }
    }
    #[test]
    fn test_matrix_creation() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m[(0, 0)], 0.0);
    }

    #[test]
    fn test_identity() {
        let m = Matrix::identity(3);
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 1.0);
        assert_eq!(m[(2, 2)], 1.0);
        assert_eq!(m[(0, 1)], 0.0);
    }

    #[test]
    fn test_transpose() {
        let mut m = Matrix::new(2, 3);
        m[(0, 0)] = 1.0;
        m[(0, 1)] = 2.0;
        m[(0, 2)] = 3.0;
        m[(1, 0)] = 4.0;
        m[(1, 1)] = 5.0;
        m[(1, 2)] = 6.0;

        let mt_serial = m.transpose_with_parallel(false);
        let mt_parallel = m.transpose_with_parallel(true);
        assert_matrix_close(&mt_serial, &mt_parallel, 0.0);
        assert_eq!(mt_serial.rows(), 3);
        assert_eq!(mt_serial.cols(), 2);
        assert_eq!(mt_serial[(0, 0)], 1.0);
        assert_eq!(mt_serial[(1, 0)], 2.0);
        assert_eq!(mt_serial[(2, 0)], 3.0);
        assert_eq!(mt_serial[(0, 1)], 4.0);
    }

    #[test]
    fn test_matrix_multiplication() {
        let mut a = Matrix::new(2, 3);
        a[(0, 0)] = 1.0;
        a[(0, 1)] = 2.0;
        a[(0, 2)] = 3.0;
        a[(1, 0)] = 4.0;
        a[(1, 1)] = 5.0;
        a[(1, 2)] = 6.0;

        let mut b = Matrix::new(3, 2);
        b[(0, 0)] = 7.0;
        b[(0, 1)] = 8.0;
        b[(1, 0)] = 9.0;
        b[(1, 1)] = 10.0;
        b[(2, 0)] = 11.0;
        b[(2, 1)] = 12.0;

        let c_serial = a.try_mul_with_parallel(&b, false).unwrap();
        let c_parallel = a.try_mul_with_parallel(&b, true).unwrap();
        assert_matrix_close(&c_serial, &c_parallel, 0.0);
        assert_eq!(c_serial.rows(), 2);
        assert_eq!(c_serial.cols(), 2);
        assert_eq!(c_serial[(0, 0)], 58.0);
        assert_eq!(c_serial[(0, 1)], 64.0);
        assert_eq!(c_serial[(1, 0)], 139.0);
        assert_eq!(c_serial[(1, 1)], 154.0);
    }
    #[test]
    fn test_parallel_equivalence_threshold_size() {
        let size = 128;
        let mut rng = TestRng::new(0x3e5a_9f21_d00d_cafe);

        let a = random_matrix(size, size, &mut rng);
        let b = random_matrix(size, size, &mut rng);

        let c_serial = a.try_mul_with_parallel(&b, false).unwrap();
        let c_parallel = a.try_mul_with_parallel(&b, true).unwrap();
        assert_matrix_close(&c_serial, &c_parallel, 1e-10);

        let t_serial = a.transpose_with_parallel(false);
        let t_parallel = a.transpose_with_parallel(true);
        assert_matrix_close(&t_serial, &t_parallel, 0.0);

        let m = size * 4;
        let n = size;
        let tall = random_matrix(m, n, &mut rng);
        let tall_b = random_matrix(m, 1, &mut rng);
        let (x_serial, info_serial) =
            tall.solve_least_squares_qr_with_info_with_parallel(&tall_b, false)
                .unwrap();
        let (x_parallel, info_parallel) =
            tall.solve_least_squares_qr_with_info_with_parallel(&tall_b, true)
                .unwrap();
        assert_eq!(info_serial.rank, info_parallel.rank);
        assert!(info_serial.cond_est.is_finite());
        assert!(info_parallel.cond_est.is_finite());
        assert_matrix_close(&x_serial, &x_parallel, 1e-7);

        let wide = random_matrix(n, m, &mut rng);
        let wide_b = random_matrix(n, 1, &mut rng);
        let (x_serial, info_serial) =
            wide.solve_least_squares_qr_with_info_with_parallel(&wide_b, false)
                .unwrap();
        let (x_parallel, info_parallel) =
            wide.solve_least_squares_qr_with_info_with_parallel(&wide_b, true)
                .unwrap();
        assert_eq!(info_serial.rank, info_parallel.rank);
        assert!(info_serial.cond_est.is_finite());
        assert!(info_parallel.cond_est.is_finite());
        assert_matrix_close(&x_serial, &x_parallel, 1e-7);
    }

    #[test]
    fn test_lu_solve() {
        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();

        let x = a.solve_lu(&b).unwrap();

        let verify_serial = a.try_mul_with_parallel(&x, false).unwrap();
        let verify_parallel = a.try_mul_with_parallel(&x, true).unwrap();
        assert_matrix_close(&verify_serial, &verify_parallel, 0.0);

        assert!((verify_serial[(0, 0)] - b[(0, 0)]).abs() < 1e-10);
        assert!((verify_serial[(1, 0)] - b[(1, 0)]).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_is_scale_invariant() {
        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();

        let x = a.solve_lu(&b).unwrap();

        let scale = 1e-12;
        let a_scaled = &a * scale;
        let b_scaled = &b * scale;

        let x_scaled = a_scaled.solve_lu(&b_scaled).unwrap();
        assert!((x_scaled[(0, 0)] - x[(0, 0)]).abs() < 1e-8);
        assert!((x_scaled[(1, 0)] - x[(1, 0)]).abs() < 1e-8);
    }

    #[test]
    fn test_try_add_dimension_mismatch() {
        let a = Matrix::new(2, 2);
        let b = Matrix::new(2, 3);

        let err = a.try_add(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "add",
                left: (2, 2),
                right: (2, 3),
            }
        );
    }
    #[test]
    fn test_try_sub_dimension_mismatch() {
        let a = Matrix::new(3, 2);
        let b = Matrix::new(2, 2);

        let err = a.try_sub(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "sub",
                left: (3, 2),
                right: (2, 2),
            }
        );
    }

    #[test]
    fn test_try_mul_dimension_mismatch() {
        let a = Matrix::new(2, 3);
        let b = Matrix::new(2, 2);

        let err_serial = a
            .try_mul_with_parallel(&b, false)
            .expect_err("expected dimension mismatch");
        let err_parallel = a
            .try_mul_with_parallel(&b, true)
            .expect_err("expected dimension mismatch");
        assert_eq!(
            err_serial,
            MatrixError::DimensionMismatch {
                operation: "mul",
                left: (2, 3),
                right: (2, 2),
            }
        );
        assert_eq!(err_serial, err_parallel);
    }

    #[test]
    fn test_solve_least_squares_qr_overdetermined_exact() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let x_serial = a.solve_least_squares_qr_with_parallel(&b, false).unwrap();
        let x_parallel = a.solve_least_squares_qr_with_parallel(&b, true).unwrap();
        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
        assert!((x_serial[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x_serial[(1, 0)] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_underdetermined_min_norm() {
        let a = Matrix::from_vec(vec![1.0, 1.0], 1, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0], 1, 1).unwrap();

        let x_serial = a.solve_least_squares_qr_with_parallel(&b, false).unwrap();
        let x_parallel = a.solve_least_squares_qr_with_parallel(&b, true).unwrap();
        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
        assert!((x_serial[(0, 0)] - 0.5).abs() < 1e-12);
        assert!((x_serial[(1, 0)] - 0.5).abs() < 1e-12);
    }
    #[test]
    fn test_solve_least_squares_qr_with_info_random_tall_full_rank() {
        let mut rng = TestRng::new(0x5eed_1234_5678_9abc);
        for _ in 0..5 {
            let mut a = random_matrix(6, 3, &mut rng);
            for i in 0..3 {
                a[(i, i)] += 2.0;
            }
            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
            let x_mat = Matrix::from_vec(x_true.clone(), 3, 1).unwrap();
            let b = &a * &x_mat;

            let (x_serial, info_serial) =
                a.solve_least_squares_qr_with_info_with_parallel(&b, false)
                    .unwrap();
            let (x_parallel, info_parallel) =
                a.solve_least_squares_qr_with_info_with_parallel(&b, true)
                    .unwrap();
            assert_eq!(info_serial.rank, 3);
            assert_eq!(info_parallel.rank, 3);
            assert!(info_serial.cond_est.is_finite());
            assert!(info_parallel.cond_est.is_finite());
            assert_matrix_close(&x_serial, &x_parallel, 1e-8);
            for i in 0..3 {
                assert!((x_serial[(i, 0)] - x_true[i]).abs() < 1e-8);
                assert!((x_parallel[(i, 0)] - x_true[i]).abs() < 1e-8);
            }
        }
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_random_wide_min_norm() {
        let mut rng = TestRng::new(0x1234_5678_9abc_def0);
        let mut a = Matrix::new(2, 4);
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;

        for _ in 0..5 {
            let b0 = rng.next_f64();
            let b1 = rng.next_f64();
            let b = Matrix::from_vec(vec![b0, b1], 2, 1).unwrap();

            let (x_serial, info_serial) =
                a.solve_least_squares_qr_with_info_with_parallel(&b, false)
                    .unwrap();
            let (x_parallel, info_parallel) =
                a.solve_least_squares_qr_with_info_with_parallel(&b, true)
                    .unwrap();
            assert_eq!(info_serial.rank, 2);
            assert_eq!(info_parallel.rank, 2);
            assert!(info_serial.cond_est.is_finite());
            assert!(info_parallel.cond_est.is_finite());
            assert_matrix_close(&x_serial, &x_parallel, 1e-12);
            assert!((x_serial[(0, 0)] - b0).abs() < 1e-12);
            assert!((x_serial[(1, 0)] - b1).abs() < 1e-12);
            assert!((x_serial[(2, 0)]).abs() < 1e-12);
            assert!((x_serial[(3, 0)]).abs() < 1e-12);
            assert!((x_parallel[(0, 0)] - b0).abs() < 1e-12);
            assert!((x_parallel[(1, 0)] - b1).abs() < 1e-12);
            assert!((x_parallel[(2, 0)]).abs() < 1e-12);
            assert!((x_parallel[(3, 0)]).abs() < 1e-12);
        }
    }
    #[test]
    fn test_solve_least_squares_qr_with_info_detects_ill_conditioning() {
        let mut a = Matrix::identity(3);
        for i in 0..3 {
            a[(i, 2)] *= 1e-8;
        }
        let b = Matrix::from_vec(vec![0.0, 0.0, 0.0], 3, 1).unwrap();

        let (_x_serial, info_serial) =
            a.solve_least_squares_qr_with_info_with_parallel(&b, false)
                .unwrap();
        let (_x_parallel, info_parallel) =
            a.solve_least_squares_qr_with_info_with_parallel(&b, true)
                .unwrap();
        assert_eq!(info_serial.rank, 3);
        assert_eq!(info_parallel.rank, 3);
        assert!(info_serial.cond_est > 1e6, "cond_est was {}", info_serial.cond_est);
        assert!(info_parallel.cond_est > 1e6, "cond_est was {}", info_parallel.cond_est);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_tall_full_rank() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let (x_serial, info_serial) =
            a.solve_least_squares_qr_with_info_with_parallel(&b, false)
                .unwrap();
        let (x_parallel, info_parallel) =
            a.solve_least_squares_qr_with_info_with_parallel(&b, true)
                .unwrap();
        assert_eq!(info_serial.rank, 2);
        assert_eq!(info_parallel.rank, 2);
        assert!(info_serial.cond_est.is_finite());
        assert!(info_parallel.cond_est.is_finite());
        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
        assert!((x_serial[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x_serial[(1, 0)] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_wide_full_rank() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();

        let (x_serial, info_serial) =
            a.solve_least_squares_qr_with_info_with_parallel(&b, false)
                .unwrap();
        let (x_parallel, info_parallel) =
            a.solve_least_squares_qr_with_info_with_parallel(&b, true)
                .unwrap();
        assert_eq!(info_serial.rank, 2);
        assert_eq!(info_parallel.rank, 2);
        assert!(info_serial.cond_est.is_finite());
        assert!(info_parallel.cond_est.is_finite());
        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
        assert!((x_serial[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x_serial[(1, 0)] - 2.0).abs() < 1e-12);
        assert!((x_serial[(2, 0)] - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_rank_deficient() {
        let a = Matrix::from_vec(vec![1.0, 1.0, 2.0, 2.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();

        let err_serial = a
            .solve_least_squares_qr_with_info_with_parallel(&b, false)
            .expect_err("expected rank-deficient QR solve to fail");
        let err_parallel = a
            .solve_least_squares_qr_with_info_with_parallel(&b, true)
            .expect_err("expected rank-deficient QR solve to fail");
        assert!(err_serial.contains("rank deficient"), "{err_serial}");
        assert!(err_parallel.contains("rank deficient"), "{err_parallel}");
    }
}