1use std::fmt;
26use std::ops::{Add, Index, IndexMut, Mul, Sub};
27
/// Error returned by the fallible element-wise and product operations on `Matrix`
/// (`try_add`, `try_sub`, `try_mul`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MatrixError {
    /// The operand shapes are incompatible for `operation`.
    DimensionMismatch {
        /// Name of the operation that was attempted (for example `"add"`).
        operation: &'static str,
        /// Shape of the left operand as `(rows, cols)`.
        left: (usize, usize),
        /// Shape of the right operand as `(rows, cols)`.
        right: (usize, usize),
    },
}
36
/// Diagnostics reported alongside a QR least-squares solution.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LeastSquaresQrInfo {
    /// Number of diagonal entries of the triangular factor `R` that were judged
    /// numerically nonzero.
    pub rank: usize,
    /// Ratio of the largest to the smallest of those diagonal magnitudes;
    /// `f64::INFINITY` when `rank` is zero.
    pub cond_est: f64,
}
42
43impl fmt::Display for MatrixError {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 MatrixError::DimensionMismatch {
47 operation,
48 left,
49 right,
50 } => write!(
51 f,
52 "Matrix dimension mismatch for {}: left is {}x{}, right is {}x{}",
53 operation, left.0, left.1, right.0, right.1
54 ),
55 }
56 }
57}
58
59impl std::error::Error for MatrixError {}
60
/// Dense matrix of `f64` values.
///
/// Entries are stored in row-major order: entry `(row, col)` lives at
/// `data[row * cols + col]`, and the constructors keep
/// `data.len() == rows * cols`.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
    /// Row-major entry storage.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
67
68impl Matrix {
69 pub fn new(rows: usize, cols: usize) -> Self {
70 Matrix {
71 data: vec![0.0; rows * cols],
72 rows,
73 cols,
74 }
75 }
76
77 pub fn from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
78 if data.len() != rows * cols {
79 return Err("Data length doesn't match dimensions".to_string());
80 }
81 Ok(Matrix { data, rows, cols })
82 }
83
84 pub fn identity(size: usize) -> Self {
85 let mut m = Matrix::new(size, size);
86 for i in 0..size {
87 m[(i, i)] = 1.0;
88 }
89 m
90 }
91
92 pub fn rows(&self) -> usize {
93 self.rows
94 }
95
96 pub fn cols(&self) -> usize {
97 self.cols
98 }
99
100 pub fn transpose(&self) -> Matrix {
101 let mut result = Matrix::new(self.cols, self.rows);
102 for i in 0..self.rows {
103 for j in 0..self.cols {
104 result[(j, i)] = self[(i, j)];
105 }
106 }
107 result
108 }
109
110 pub fn lu_decomposition(&self) -> Result<(Matrix, Matrix, Vec<usize>), String> {
111 self.lu_decomposition_with_tolerance(1e-12)
112 }
113
114 pub fn lu_decomposition_with_tolerance(
115 &self,
116 singular_relative_epsilon: f64,
117 ) -> Result<(Matrix, Matrix, Vec<usize>), String> {
118 if self.rows != self.cols {
119 return Err("Matrix must be square for LU decomposition".to_string());
120 }
121 if !singular_relative_epsilon.is_finite() || singular_relative_epsilon < 0.0 {
122 return Err("Invalid singular tolerance".to_string());
123 }
124
125 let n = self.rows;
126 let mut l = Matrix::identity(n);
127 let mut u = self.clone();
128 let mut pivot = (0..n).collect::<Vec<_>>();
129
130 for k in 0..n {
131 let mut max_val = 0.0;
132 let mut max_row = k;
133 for i in k..n {
134 let val = u[(i, k)].abs();
135 if val > max_val {
136 max_val = val;
137 max_row = i;
138 }
139 }
140
141 let mut row_norm: f64 = 0.0;
142 for j in k..n {
143 row_norm = row_norm.max(u[(max_row, j)].abs());
144 }
145
146 if max_val <= singular_relative_epsilon * row_norm {
147 return Err("Matrix is singular".to_string());
148 }
149
150 if max_row != k {
151 pivot.swap(k, max_row);
152 for j in 0..n {
153 let temp = u[(k, j)];
154 u[(k, j)] = u[(max_row, j)];
155 u[(max_row, j)] = temp;
156 }
157 for j in 0..k {
158 let temp = l[(k, j)];
159 l[(k, j)] = l[(max_row, j)];
160 l[(max_row, j)] = temp;
161 }
162 }
163
164 for i in (k + 1)..n {
165 l[(i, k)] = u[(i, k)] / u[(k, k)];
166 for j in k..n {
167 u[(i, j)] -= l[(i, k)] * u[(k, j)];
168 }
169 }
170 }
171
172 Ok((l, u, pivot))
173 }
174
175 pub fn solve_lu(&self, b: &Matrix) -> Result<Matrix, String> {
176 self.solve_lu_with_tolerance(b, 1e-12)
177 }
178
179 pub fn solve_lu_with_tolerance(
180 &self,
181 b: &Matrix,
182 singular_relative_epsilon: f64,
183 ) -> Result<Matrix, String> {
184 if self.rows != self.cols {
185 return Err("Matrix must be square".to_string());
186 }
187 if b.rows != self.rows || b.cols != 1 {
188 return Err("Invalid dimensions for b".to_string());
189 }
190
191 let (l, u, pivot) = self.lu_decomposition_with_tolerance(singular_relative_epsilon)?;
192
193 let mut pb = Matrix::new(b.rows, 1);
194 for i in 0..b.rows {
195 pb[(i, 0)] = b[(pivot[i], 0)];
196 }
197
198 let mut y = Matrix::new(self.rows, 1);
199 for i in 0..self.rows {
200 y[(i, 0)] = pb[(i, 0)];
201 for j in 0..i {
202 y[(i, 0)] -= l[(i, j)] * y[(j, 0)];
203 }
204 }
205
206 let mut x = Matrix::new(self.rows, 1);
207 for i in (0..self.rows).rev() {
208 x[(i, 0)] = y[(i, 0)];
209 for j in (i + 1)..self.rows {
210 x[(i, 0)] -= u[(i, j)] * x[(j, 0)];
211 }
212 x[(i, 0)] /= u[(i, i)];
213 }
214
215 Ok(x)
216 }
217
218 pub fn solve_least_squares_qr(&self, b: &Matrix) -> Result<Matrix, String> {
224 self.solve_least_squares_qr_with_info(b)
225 .map(|(solution, _)| solution)
226 }
227
228 pub fn solve_least_squares_qr_with_info(
231 &self,
232 b: &Matrix,
233 ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
234 if b.cols != 1 {
235 return Err("Invalid dimensions for b (expected column vector)".to_string());
236 }
237 if b.rows != self.rows {
238 return Err("Invalid dimensions for b".to_string());
239 }
240
241 let m = self.rows;
242 let n = self.cols;
243
244 if n == 0 {
246 return Ok((
247 Matrix::new(0, 1),
248 LeastSquaresQrInfo {
249 rank: 0,
250 cond_est: f64::INFINITY,
251 },
252 ));
253 }
254
255 if m >= n {
256 Self::solve_least_squares_qr_tall_with_info(self, b)
257 } else {
258 Self::solve_least_squares_qr_wide_with_info(self, b)
259 }
260 }
261
262 fn solve_least_squares_qr_tall_with_info(
263 a: &Matrix,
264 b: &Matrix,
265 ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
266 let m = a.rows;
267 let n = a.cols;
268
269 let mut r = a.clone();
270 let mut qt_b = b.clone();
271 let mut taus = Vec::with_capacity(n);
272
273 for k in 0..n {
275 let mut col_norm: f64 = 0.0;
276 for i in k..m {
277 col_norm = col_norm.hypot(r[(i, k)]);
278 }
279
280 if col_norm == 0.0 {
281 taus.push(0.0);
282 continue;
283 }
284
285 let x0 = r[(k, k)];
286 let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
287 let alpha = -sign * col_norm;
288 let v0 = x0 - alpha;
289
290 for i in (k + 1)..m {
292 r[(i, k)] /= v0;
293 }
294
295 let mut v_sq = 1.0;
296 for i in (k + 1)..m {
297 let vi = r[(i, k)];
298 v_sq += vi * vi;
299 }
300 let tau = 2.0 / v_sq;
301 taus.push(tau);
302
303 for j in (k + 1)..n {
305 let mut dot = r[(k, j)];
306 for i in (k + 1)..m {
307 dot += r[(i, k)] * r[(i, j)];
308 }
309 dot *= tau;
310
311 r[(k, j)] -= dot;
312 for i in (k + 1)..m {
313 r[(i, j)] -= r[(i, k)] * dot;
314 }
315 }
316
317 let mut dot = qt_b[(k, 0)];
319 for i in (k + 1)..m {
320 dot += r[(i, k)] * qt_b[(i, 0)];
321 }
322 dot *= tau;
323 qt_b[(k, 0)] -= dot;
324 for i in (k + 1)..m {
325 qt_b[(i, 0)] -= r[(i, k)] * dot;
326 }
327
328 r[(k, k)] = alpha;
329 }
330
331 let mut max_diag: f64 = 0.0;
332 for i in 0..n {
333 max_diag = max_diag.max(r[(i, i)].abs());
334 }
335 let tol = 1e-12 * max_diag.max(1.0);
336 let mut rank = 0;
337 let mut min_diag = f64::INFINITY;
338 for i in 0..n {
339 let diag = r[(i, i)].abs();
340 if diag > tol {
341 rank += 1;
342 if diag < min_diag {
343 min_diag = diag;
344 }
345 }
346 }
347 let cond_est = if rank == 0 || !min_diag.is_finite() {
348 f64::INFINITY
349 } else {
350 max_diag / min_diag
351 };
352
353 let mut x = Matrix::new(n, 1);
355 for i in (0..n).rev() {
356 let mut sum = qt_b[(i, 0)];
357 for j in (i + 1)..n {
358 sum -= r[(i, j)] * x[(j, 0)];
359 }
360
361 let diag = r[(i, i)];
362 if !diag.is_finite() {
363 return Err("Least squares solve failed: non-finite diagonal in R".to_string());
364 }
365
366 let mut row_norm: f64 = 0.0;
367 for j in i..n {
368 row_norm = row_norm.max(r[(i, j)].abs());
369 }
370
371 if diag.abs() <= 1e-12 * row_norm {
372 return Err("Least squares solve failed: matrix is rank deficient".to_string());
373 }
374
375 x[(i, 0)] = sum / diag;
376 }
377
378 Ok((
379 x,
380 LeastSquaresQrInfo {
381 rank,
382 cond_est,
383 },
384 ))
385 }
386
387 fn solve_least_squares_qr_wide_with_info(
388 a: &Matrix,
389 b: &Matrix,
390 ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
391 let m = a.rows;
397 let n = a.cols;
398
399 let mut r = a.transpose(); let mut taus = Vec::with_capacity(m);
401
402 for k in 0..m {
404 let mut col_norm: f64 = 0.0;
405 for i in k..n {
406 col_norm = col_norm.hypot(r[(i, k)]);
407 }
408
409 if col_norm == 0.0 {
410 taus.push(0.0);
411 continue;
412 }
413
414 let x0 = r[(k, k)];
415 let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
416 let alpha = -sign * col_norm;
417 let v0 = x0 - alpha;
418
419 for i in (k + 1)..n {
420 r[(i, k)] /= v0;
421 }
422
423 let mut v_sq = 1.0;
424 for i in (k + 1)..n {
425 let vi = r[(i, k)];
426 v_sq += vi * vi;
427 }
428 let tau = 2.0 / v_sq;
429 taus.push(tau);
430
431 for j in (k + 1)..m {
432 let mut dot = r[(k, j)];
433 for i in (k + 1)..n {
434 dot += r[(i, k)] * r[(i, j)];
435 }
436 dot *= tau;
437 r[(k, j)] -= dot;
438 for i in (k + 1)..n {
439 r[(i, j)] -= r[(i, k)] * dot;
440 }
441 }
442
443 r[(k, k)] = alpha;
444 }
445
446 let mut max_diag: f64 = 0.0;
447 for i in 0..m {
448 max_diag = max_diag.max(r[(i, i)].abs());
449 }
450 let tol = 1e-12 * max_diag.max(1.0);
451 let mut rank = 0;
452 let mut min_diag = f64::INFINITY;
453 for i in 0..m {
454 let diag = r[(i, i)].abs();
455 if diag > tol {
456 rank += 1;
457 if diag < min_diag {
458 min_diag = diag;
459 }
460 }
461 }
462 let cond_est = if rank == 0 || !min_diag.is_finite() {
463 f64::INFINITY
464 } else {
465 max_diag / min_diag
466 };
467
468 let mut y = vec![0.0; m];
470 for i in 0..m {
471 let mut sum = b[(i, 0)];
472 for j in 0..i {
473 sum -= r[(j, i)] * y[j];
474 }
475
476 let diag = r[(i, i)];
477 if !diag.is_finite() {
478 return Err("Least squares solve failed: non-finite diagonal in R".to_string());
479 }
480
481 let mut col_norm: f64 = 0.0;
482 for j in 0..=i {
483 col_norm = col_norm.max(r[(j, i)].abs());
484 }
485
486 if diag.abs() <= 1e-12 * col_norm {
487 return Err("Least squares solve failed: matrix is rank deficient".to_string());
488 }
489
490 y[i] = sum / diag;
491 }
492
493 let mut w = vec![0.0; n];
495 w[..m].copy_from_slice(&y[..m]);
496
497 for k in (0..m).rev() {
500 let tau = taus[k];
501 if tau == 0.0 {
502 continue;
503 }
504
505 let mut dot = w[k];
506 for i in (k + 1)..n {
507 dot += r[(i, k)] * w[i];
508 }
509 dot *= tau;
510
511 w[k] -= dot;
512 for i in (k + 1)..n {
513 w[i] -= r[(i, k)] * dot;
514 }
515 }
516
517 Matrix::from_vec(w, n, 1).map(|solution| {
518 (
519 solution,
520 LeastSquaresQrInfo {
521 rank,
522 cond_est,
523 },
524 )
525 })
526 }
527
528 pub fn norm(&self) -> f64 {
529 self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
530 }
531
532 pub fn try_add(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
533 if self.rows != rhs.rows || self.cols != rhs.cols {
534 return Err(MatrixError::DimensionMismatch {
535 operation: "add",
536 left: (self.rows, self.cols),
537 right: (rhs.rows, rhs.cols),
538 });
539 }
540
541 let mut result = Matrix::new(self.rows, self.cols);
542 for i in 0..self.data.len() {
543 result.data[i] = self.data[i] + rhs.data[i];
544 }
545 Ok(result)
546 }
547
548 pub fn try_sub(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
549 if self.rows != rhs.rows || self.cols != rhs.cols {
550 return Err(MatrixError::DimensionMismatch {
551 operation: "sub",
552 left: (self.rows, self.cols),
553 right: (rhs.rows, rhs.cols),
554 });
555 }
556
557 let mut result = Matrix::new(self.rows, self.cols);
558 for i in 0..self.data.len() {
559 result.data[i] = self.data[i] - rhs.data[i];
560 }
561 Ok(result)
562 }
563
564 pub fn try_mul(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
565 if self.cols != rhs.rows {
566 return Err(MatrixError::DimensionMismatch {
567 operation: "mul",
568 left: (self.rows, self.cols),
569 right: (rhs.rows, rhs.cols),
570 });
571 }
572
573 let mut result = Matrix::new(self.rows, rhs.cols);
574 for i in 0..self.rows {
575 for j in 0..rhs.cols {
576 for k in 0..self.cols {
577 result[(i, j)] += self[(i, k)] * rhs[(k, j)];
578 }
579 }
580 }
581 Ok(result)
582 }
583}
584
585impl Index<(usize, usize)> for Matrix {
586 type Output = f64;
587
588 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
589 &self.data[row * self.cols + col]
590 }
591}
592
593impl IndexMut<(usize, usize)> for Matrix {
594 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
595 &mut self.data[row * self.cols + col]
596 }
597}
598
599impl Add for &Matrix {
600 type Output = Matrix;
601
602 fn add(self, rhs: Self) -> Self::Output {
603 self.try_add(rhs).unwrap_or_else(|err| panic!("{}", err))
604 }
605}
606
607impl Sub for &Matrix {
608 type Output = Matrix;
609
610 fn sub(self, rhs: Self) -> Self::Output {
611 self.try_sub(rhs).unwrap_or_else(|err| panic!("{}", err))
612 }
613}
614
615impl Mul for &Matrix {
616 type Output = Matrix;
617
618 fn mul(self, rhs: Self) -> Self::Output {
619 self.try_mul(rhs).unwrap_or_else(|err| panic!("{}", err))
620 }
621}
622
623impl Mul<f64> for &Matrix {
624 type Output = Matrix;
625
626 fn mul(self, scalar: f64) -> Self::Output {
627 let mut result = self.clone();
628 for val in &mut result.data {
629 *val *= scalar;
630 }
631 result
632 }
633}
634
635impl fmt::Display for Matrix {
636 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
637 for i in 0..self.rows {
638 write!(f, "[")?;
639 for j in 0..self.cols {
640 if j > 0 {
641 write!(f, ", ")?;
642 }
643 write!(f, "{:8.4}", self[(i, j)])?;
644 }
645 writeln!(f, "]")?;
646 }
647 Ok(())
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Small deterministic LCG so the randomized tests are reproducible without
    /// external crates.
    struct TestRng {
        state: u64,
    }

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_u32(&mut self) -> u32 {
            self.state = self
                .state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            (self.state >> 32) as u32
        }

        /// Uniform-ish value in `[-1, 1]`.
        fn next_f64(&mut self) -> f64 {
            let v = self.next_u32() as f64 / u32::MAX as f64;
            2.0 * v - 1.0
        }
    }

    fn random_matrix(rows: usize, cols: usize, rng: &mut TestRng) -> Matrix {
        let mut m = Matrix::new(rows, cols);
        for i in 0..rows {
            for j in 0..cols {
                m[(i, j)] = rng.next_f64();
            }
        }
        m
    }

    #[test]
    fn test_matrix_creation() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m[(0, 0)], 0.0);
    }

    #[test]
    fn test_from_vec_rejects_length_mismatch() {
        let err = Matrix::from_vec(vec![1.0, 2.0, 3.0], 2, 2)
            .expect_err("expected length mismatch");
        assert_eq!(err, "Data length doesn't match dimensions");
    }

    #[test]
    fn test_identity() {
        let m = Matrix::identity(3);
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 1.0);
        assert_eq!(m[(2, 2)], 1.0);
        assert_eq!(m[(0, 1)], 0.0);
    }

    #[test]
    fn test_transpose() {
        let mut m = Matrix::new(2, 3);
        m[(0, 0)] = 1.0;
        m[(0, 1)] = 2.0;
        m[(0, 2)] = 3.0;
        m[(1, 0)] = 4.0;
        m[(1, 1)] = 5.0;
        m[(1, 2)] = 6.0;

        let mt = m.transpose();
        assert_eq!(mt.rows(), 3);
        assert_eq!(mt.cols(), 2);
        assert_eq!(mt[(0, 0)], 1.0);
        assert_eq!(mt[(1, 0)], 2.0);
        assert_eq!(mt[(2, 0)], 3.0);
        assert_eq!(mt[(0, 1)], 4.0);
    }

    #[test]
    fn test_matrix_multiplication() {
        let mut a = Matrix::new(2, 3);
        a[(0, 0)] = 1.0;
        a[(0, 1)] = 2.0;
        a[(0, 2)] = 3.0;
        a[(1, 0)] = 4.0;
        a[(1, 1)] = 5.0;
        a[(1, 2)] = 6.0;

        let mut b = Matrix::new(3, 2);
        b[(0, 0)] = 7.0;
        b[(0, 1)] = 8.0;
        b[(1, 0)] = 9.0;
        b[(1, 1)] = 10.0;
        b[(2, 0)] = 11.0;
        b[(2, 1)] = 12.0;

        let c = &a * &b;
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 2);
        assert_eq!(c[(0, 0)], 58.0);
        assert_eq!(c[(0, 1)], 64.0);
        assert_eq!(c[(1, 0)], 139.0);
        assert_eq!(c[(1, 1)], 154.0);
    }

    #[test]
    fn test_lu_solve() {
        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();

        let x = a.solve_lu(&b).unwrap();

        let verify = &a * &x;

        assert!((verify[(0, 0)] - b[(0, 0)]).abs() < 1e-10);
        assert!((verify[(1, 0)] - b[(1, 0)]).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_is_scale_invariant() {
        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();

        let x = a.solve_lu(&b).unwrap();

        let scale = 1e-12;
        let a_scaled = &a * scale;
        let b_scaled = &b * scale;

        let x_scaled = a_scaled.solve_lu(&b_scaled).unwrap();
        assert!((x_scaled[(0, 0)] - x[(0, 0)]).abs() < 1e-8);
        assert!((x_scaled[(1, 0)] - x[(1, 0)]).abs() < 1e-8);
    }

    #[test]
    fn test_lu_decomposition_reconstructs_permuted_matrix() {
        let a = Matrix::from_vec(vec![0.0, 2.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0], 3, 3)
            .unwrap();
        let (l, u, pivot) = a.lu_decomposition().unwrap();

        let mut sorted = pivot.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2]);

        for i in 0..3 {
            // L is unit lower triangular, U is (numerically) upper triangular.
            assert_eq!(l[(i, i)], 1.0);
            for j in (i + 1)..3 {
                assert_eq!(l[(i, j)], 0.0);
            }
            for j in 0..i {
                assert!(u[(i, j)].abs() < 1e-12);
            }
        }

        // Row i of L * U must equal row pivot[i] of A.
        let lu = &l * &u;
        for i in 0..3 {
            for j in 0..3 {
                assert!((lu[(i, j)] - a[(pivot[i], j)]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_lu_decomposition_detects_singular() {
        let a = Matrix::from_vec(vec![1.0, 2.0, 2.0, 4.0], 2, 2).unwrap();
        let err = a.lu_decomposition().expect_err("expected singular matrix");
        assert!(err.contains("singular"), "{err}");
    }

    #[test]
    fn test_lu_decomposition_rejects_bad_input() {
        assert!(Matrix::new(2, 3).lu_decomposition().is_err());
        assert!(Matrix::identity(2)
            .lu_decomposition_with_tolerance(-1.0)
            .is_err());
        assert!(Matrix::identity(2)
            .lu_decomposition_with_tolerance(f64::NAN)
            .is_err());
    }

    #[test]
    fn test_try_add_dimension_mismatch() {
        let a = Matrix::new(2, 2);
        let b = Matrix::new(2, 3);

        let err = a.try_add(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "add",
                left: (2, 2),
                right: (2, 3),
            }
        );
    }

    #[test]
    fn test_try_sub_dimension_mismatch() {
        let a = Matrix::new(3, 2);
        let b = Matrix::new(2, 2);

        let err = a.try_sub(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "sub",
                left: (3, 2),
                right: (2, 2),
            }
        );
    }

    #[test]
    fn test_try_mul_dimension_mismatch() {
        let a = Matrix::new(2, 3);
        let b = Matrix::new(2, 2);

        let err = a.try_mul(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "mul",
                left: (2, 3),
                right: (2, 2),
            }
        );
    }

    #[test]
    fn test_matrix_error_display() {
        let err = MatrixError::DimensionMismatch {
            operation: "mul",
            left: (2, 3),
            right: (2, 2),
        };
        assert_eq!(
            err.to_string(),
            "Matrix dimension mismatch for mul: left is 2x3, right is 2x2"
        );
    }

    #[test]
    fn test_solve_least_squares_qr_overdetermined_exact() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let x = a.solve_least_squares_qr(&b).unwrap();
        assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_underdetermined_min_norm() {
        let a = Matrix::from_vec(vec![1.0, 1.0], 1, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0], 1, 1).unwrap();

        let x = a.solve_least_squares_qr(&b).unwrap();
        assert!((x[(0, 0)] - 0.5).abs() < 1e-12);
        assert!((x[(1, 0)] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_rejects_invalid_b() {
        let a = Matrix::identity(2);
        assert!(a.solve_least_squares_qr(&Matrix::new(2, 2)).is_err());
        assert!(a.solve_least_squares_qr(&Matrix::new(3, 1)).is_err());
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_zero_columns() {
        let a = Matrix::new(3, 0);
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(x.rows(), 0);
        assert_eq!(x.cols(), 1);
        assert_eq!(info.rank, 0);
        assert!(info.cond_est.is_infinite());
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_random_tall_full_rank() {
        let mut rng = TestRng::new(0x5eed_1234_5678_9abc);
        for _ in 0..5 {
            let mut a = random_matrix(6, 3, &mut rng);
            // Strengthen the diagonal so every draw is comfortably full rank.
            for i in 0..3 {
                a[(i, i)] += 2.0;
            }
            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
            let x_mat = Matrix::from_vec(x_true.clone(), 3, 1).unwrap();
            let b = &a * &x_mat;

            let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
            assert_eq!(info.rank, 3);
            assert!(info.cond_est.is_finite());
            for i in 0..3 {
                assert!((x[(i, 0)] - x_true[i]).abs() < 1e-8);
            }
        }
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_random_wide_min_norm() {
        let mut rng = TestRng::new(0x1234_5678_9abc_def0);
        let mut a = Matrix::new(2, 4);
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;

        for _ in 0..5 {
            let b0 = rng.next_f64();
            let b1 = rng.next_f64();
            let b = Matrix::from_vec(vec![b0, b1], 2, 1).unwrap();

            let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
            assert_eq!(info.rank, 2);
            assert!(info.cond_est.is_finite());
            assert!((x[(0, 0)] - b0).abs() < 1e-12);
            assert!((x[(1, 0)] - b1).abs() < 1e-12);
            assert!((x[(2, 0)]).abs() < 1e-12);
            assert!((x[(3, 0)]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_detects_ill_conditioning() {
        let mut a = Matrix::identity(3);
        for i in 0..3 {
            a[(i, 2)] *= 1e-8;
        }
        let b = Matrix::from_vec(vec![0.0, 0.0, 0.0], 3, 1).unwrap();

        let (_x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(info.rank, 3);
        assert!(info.cond_est > 1e6, "cond_est was {}", info.cond_est);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_tall_full_rank() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(info.rank, 2);
        assert!(info.cond_est.is_finite());
        assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_wide_full_rank() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();

        let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(info.rank, 2);
        assert!(info.cond_est.is_finite());
        assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
        assert!((x[(2, 0)] - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_rank_deficient() {
        let a = Matrix::from_vec(vec![1.0, 1.0, 2.0, 2.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();

        let err = a
            .solve_least_squares_qr_with_info(&b)
            .expect_err("expected rank-deficient QR solve to fail");
        assert!(err.contains("rank deficient"), "{err}");
    }
}