
constraint_solver/matrix.rs

/*
MIT License

Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
24
use std::fmt;
use std::ops::{Add, Index, IndexMut, Mul, Sub};

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixError {
    DimensionMismatch {
        operation: &'static str,
        left: (usize, usize),
        right: (usize, usize),
    },
}

#[derive(Debug, Clone, Copy)]
pub struct LeastSquaresQrInfo {
    pub rank: usize,
    pub cond_est: f64,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MatrixError::DimensionMismatch {
                operation,
                left,
                right,
            } => write!(
                f,
                "Matrix dimension mismatch for {}: left is {}x{}, right is {}x{}",
                operation, left.0, left.1, right.0, right.1
            ),
        }
    }
}

impl std::error::Error for MatrixError {}

#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    data: Vec<f64>,
    rows: usize,
    cols: usize,
}

impl Matrix {
    pub fn new(rows: usize, cols: usize) -> Self {
        Matrix {
            data: vec![0.0; rows * cols],
            rows,
            cols,
        }
    }

    pub fn from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
        if data.len() != rows * cols {
            return Err("Data length doesn't match dimensions".to_string());
        }
        Ok(Matrix { data, rows, cols })
    }

    pub fn identity(size: usize) -> Self {
        let mut m = Matrix::new(size, size);
        for i in 0..size {
            m[(i, i)] = 1.0;
        }
        m
    }

    pub fn rows(&self) -> usize {
        self.rows
    }

    pub fn cols(&self) -> usize {
        self.cols
    }

    pub fn transpose(&self) -> Matrix {
        let mut result = Matrix::new(self.cols, self.rows);
        for i in 0..self.rows {
            for j in 0..self.cols {
                result[(j, i)] = self[(i, j)];
            }
        }
        result
    }

    /// LU decomposition with partial pivoting, using a default relative
    /// singularity tolerance of `1e-12`.
    pub fn lu_decomposition(&self) -> Result<(Matrix, Matrix, Vec<usize>), String> {
        self.lu_decomposition_with_tolerance(1e-12)
    }

    /// LU decomposition with partial pivoting.
    ///
    /// Returns `(L, U, pivot)` such that permuting the rows of `self` by `pivot`
    /// (row `i` of the permuted matrix is row `pivot[i]` of `self`) equals `L * U`.
    /// The matrix is treated as singular when the best pivot in a column is at most
    /// `singular_relative_epsilon` times the largest entry of the chosen pivot row.
    pub fn lu_decomposition_with_tolerance(
        &self,
        singular_relative_epsilon: f64,
    ) -> Result<(Matrix, Matrix, Vec<usize>), String> {
        if self.rows != self.cols {
            return Err("Matrix must be square for LU decomposition".to_string());
        }
        if !singular_relative_epsilon.is_finite() || singular_relative_epsilon < 0.0 {
            return Err("Invalid singular tolerance".to_string());
        }

        let n = self.rows;
        let mut l = Matrix::identity(n);
        let mut u = self.clone();
        let mut pivot = (0..n).collect::<Vec<_>>();

        for k in 0..n {
            let mut max_val = 0.0;
            let mut max_row = k;
            for i in k..n {
                let val = u[(i, k)].abs();
                if val > max_val {
                    max_val = val;
                    max_row = i;
                }
            }

            let mut row_norm: f64 = 0.0;
            for j in k..n {
                row_norm = row_norm.max(u[(max_row, j)].abs());
            }

            if max_val <= singular_relative_epsilon * row_norm {
                return Err("Matrix is singular".to_string());
            }

            if max_row != k {
                pivot.swap(k, max_row);
                for j in 0..n {
                    let temp = u[(k, j)];
                    u[(k, j)] = u[(max_row, j)];
                    u[(max_row, j)] = temp;
                }
                for j in 0..k {
                    let temp = l[(k, j)];
                    l[(k, j)] = l[(max_row, j)];
                    l[(max_row, j)] = temp;
                }
            }

            for i in (k + 1)..n {
                l[(i, k)] = u[(i, k)] / u[(k, k)];
                for j in k..n {
                    u[(i, j)] -= l[(i, k)] * u[(k, j)];
                }
            }
        }

        Ok((l, u, pivot))
    }

    /// Solve `self * x = b` via LU decomposition with partial pivoting, using the
    /// default singularity tolerance.
    pub fn solve_lu(&self, b: &Matrix) -> Result<Matrix, String> {
        self.solve_lu_with_tolerance(b, 1e-12)
    }

    /// Solve `self * x = b` via LU decomposition with partial pivoting.
    ///
    /// `b` must be a column vector with as many rows as `self`.
    pub fn solve_lu_with_tolerance(
        &self,
        b: &Matrix,
        singular_relative_epsilon: f64,
    ) -> Result<Matrix, String> {
        if self.rows != self.cols {
            return Err("Matrix must be square".to_string());
        }
        if b.rows != self.rows || b.cols != 1 {
            return Err("Invalid dimensions for b".to_string());
        }

        let (l, u, pivot) = self.lu_decomposition_with_tolerance(singular_relative_epsilon)?;

        let mut pb = Matrix::new(b.rows, 1);
        for i in 0..b.rows {
            pb[(i, 0)] = b[(pivot[i], 0)];
        }

        let mut y = Matrix::new(self.rows, 1);
        for i in 0..self.rows {
            y[(i, 0)] = pb[(i, 0)];
            for j in 0..i {
                y[(i, 0)] -= l[(i, j)] * y[(j, 0)];
            }
        }

        let mut x = Matrix::new(self.rows, 1);
        for i in (0..self.rows).rev() {
            x[(i, 0)] = y[(i, 0)];
            for j in (i + 1)..self.rows {
                x[(i, 0)] -= u[(i, j)] * x[(j, 0)];
            }
            x[(i, 0)] /= u[(i, i)];
        }

        Ok(x)
    }

    /// Solve a (possibly non-square) least squares problem using Householder QR.
    ///
    /// Solves `min_x ||A x - b||_2`, where `A` is `self`.
    /// - If `rows >= cols`, returns the standard least-squares solution.
    /// - If `rows < cols`, returns the minimum-norm solution among all minimizers.
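    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `use` path assumes this module is exposed as
    /// `constraint_solver::matrix`, so adjust it to the crate's actual layout:
    ///
    /// ```ignore
    /// use constraint_solver::matrix::Matrix;
    ///
    /// // Overdetermined, consistent 3x2 system with exact solution x = [1, 2].
    /// let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
    /// let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
    ///
    /// let x = a.solve_least_squares_qr(&b).unwrap();
    /// assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
    /// assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
    /// ```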
    pub fn solve_least_squares_qr(&self, b: &Matrix) -> Result<Matrix, String> {
        self.solve_least_squares_qr_with_info(b)
            .map(|(solution, _)| solution)
    }

    /// Solve a (possibly non-square) least squares problem using Householder QR,
    /// returning diagnostic information about rank and conditioning.
    ///
    /// `rank` counts the diagonal entries of `R` that exceed a small relative
    /// tolerance, and `cond_est` is the ratio of the largest to the smallest
    /// accepted diagonal entry, a cheap proxy for the true condition number.
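    ///
    /// # Example
    ///
    /// A minimal sketch (same caveat about the `use` path as above):
    ///
    /// ```ignore
    /// use constraint_solver::matrix::Matrix;
    ///
    /// let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
    /// let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
    ///
    /// let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
    /// assert_eq!(info.rank, 2);
    /// assert!(info.cond_est.is_finite());
    /// assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
    /// ```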
    pub fn solve_least_squares_qr_with_info(
        &self,
        b: &Matrix,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        if b.cols != 1 {
            return Err("Invalid dimensions for b (expected column vector)".to_string());
        }
        if b.rows != self.rows {
            return Err("Invalid dimensions for b".to_string());
        }

        let m = self.rows;
        let n = self.cols;

        // Trivial cases.
        if n == 0 {
            return Ok((
                Matrix::new(0, 1),
                LeastSquaresQrInfo {
                    rank: 0,
                    cond_est: f64::INFINITY,
                },
            ));
        }

        if m >= n {
            Self::solve_least_squares_qr_tall_with_info(self, b)
        } else {
            Self::solve_least_squares_qr_wide_with_info(self, b)
        }
    }

    fn solve_least_squares_qr_tall_with_info(
        a: &Matrix,
        b: &Matrix,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        let m = a.rows;
        let n = a.cols;

        let mut r = a.clone();
        let mut qt_b = b.clone();
        let mut taus = Vec::with_capacity(n);

        // Householder QR factorization (thin): apply reflectors to `r` and `qt_b`.
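        // Each reflector has the form H_k = I - tau_k * v_k * v_k^T, where v_k has an
        // implicit leading 1 and its remaining entries are stored below the diagonal in
        // column k of `r`. Applying H_k to a vector z is then z -= tau_k * v_k * (v_k^T z),
        // which is the dot-and-update pattern used in the loop below.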
        for k in 0..n {
            let mut col_norm: f64 = 0.0;
            for i in k..m {
                col_norm = col_norm.hypot(r[(i, k)]);
            }

            if col_norm == 0.0 {
                taus.push(0.0);
                continue;
            }

            let x0 = r[(k, k)];
            let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
            let alpha = -sign * col_norm;
            let v0 = x0 - alpha;

            // Store Householder vector v in-place in column k (v0 is implicit 1.0).
            for i in (k + 1)..m {
                r[(i, k)] /= v0;
            }

            let mut v_sq = 1.0;
            for i in (k + 1)..m {
                let vi = r[(i, k)];
                v_sq += vi * vi;
            }
            let tau = 2.0 / v_sq;
            taus.push(tau);

            // Apply reflector to remaining columns.
            for j in (k + 1)..n {
                let mut dot = r[(k, j)];
                for i in (k + 1)..m {
                    dot += r[(i, k)] * r[(i, j)];
                }
                dot *= tau;

                r[(k, j)] -= dot;
                for i in (k + 1)..m {
                    r[(i, j)] -= r[(i, k)] * dot;
                }
            }

            // Apply reflector to b: qt_b = Q^T b
            let mut dot = qt_b[(k, 0)];
            for i in (k + 1)..m {
                dot += r[(i, k)] * qt_b[(i, 0)];
            }
            dot *= tau;
            qt_b[(k, 0)] -= dot;
            for i in (k + 1)..m {
                qt_b[(i, 0)] -= r[(i, k)] * dot;
            }

            r[(k, k)] = alpha;
        }

        let mut max_diag: f64 = 0.0;
        for i in 0..n {
            max_diag = max_diag.max(r[(i, i)].abs());
        }
        let tol = 1e-12 * max_diag.max(1.0);
        let mut rank = 0;
        let mut min_diag = f64::INFINITY;
        for i in 0..n {
            let diag = r[(i, i)].abs();
            if diag > tol {
                rank += 1;
                if diag < min_diag {
                    min_diag = diag;
                }
            }
        }
        let cond_est = if rank == 0 || !min_diag.is_finite() {
            f64::INFINITY
        } else {
            max_diag / min_diag
        };

        // Back-substitute R x = Q^T b, using the top n entries.
        let mut x = Matrix::new(n, 1);
        for i in (0..n).rev() {
            let mut sum = qt_b[(i, 0)];
            for j in (i + 1)..n {
                sum -= r[(i, j)] * x[(j, 0)];
            }

            let diag = r[(i, i)];
            if !diag.is_finite() {
                return Err("Least squares solve failed: non-finite diagonal in R".to_string());
            }

            let mut row_norm: f64 = 0.0;
            for j in i..n {
                row_norm = row_norm.max(r[(i, j)].abs());
            }

            if diag.abs() <= 1e-12 * row_norm {
                return Err("Least squares solve failed: matrix is rank deficient".to_string());
            }

            x[(i, 0)] = sum / diag;
        }

        Ok((
            x,
            LeastSquaresQrInfo {
                rank,
                cond_est,
            },
        ))
    }

    fn solve_least_squares_qr_wide_with_info(
        a: &Matrix,
        b: &Matrix,
    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
        // Use QR on A^T to compute the minimum-norm least squares solution.
        //
        // Let A be m x n with m < n. Compute QR of A^T (n x m):
        // A^T = Q R, where Q is n x n orthogonal and the top m x m block of R is
        // upper triangular. The minimum-norm solution is x = Q [y; 0], where y
        // solves R^T y = b.
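        //
        // Worked example (matches the 1x2 test below): A = [1, 1], b = [1].
        // A^T = [1; 1], so R = [-sqrt(2)], R^T y = b gives y = -1/sqrt(2), and
        // x = Q [y; 0] = [0.5, 0.5], the minimum-norm solution of x0 + x1 = 1.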
        let m = a.rows;
        let n = a.cols;

        let mut r = a.transpose(); // n x m, tall
        let mut taus = Vec::with_capacity(m);

        // Householder QR on r (n x m), producing R in the top m x m block.
        for k in 0..m {
            let mut col_norm: f64 = 0.0;
            for i in k..n {
                col_norm = col_norm.hypot(r[(i, k)]);
            }

            if col_norm == 0.0 {
                taus.push(0.0);
                continue;
            }

            let x0 = r[(k, k)];
            let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
            let alpha = -sign * col_norm;
            let v0 = x0 - alpha;

            for i in (k + 1)..n {
                r[(i, k)] /= v0;
            }

            let mut v_sq = 1.0;
            for i in (k + 1)..n {
                let vi = r[(i, k)];
                v_sq += vi * vi;
            }
            let tau = 2.0 / v_sq;
            taus.push(tau);

            for j in (k + 1)..m {
                let mut dot = r[(k, j)];
                for i in (k + 1)..n {
                    dot += r[(i, k)] * r[(i, j)];
                }
                dot *= tau;
                r[(k, j)] -= dot;
                for i in (k + 1)..n {
                    r[(i, j)] -= r[(i, k)] * dot;
                }
            }

            r[(k, k)] = alpha;
        }

        let mut max_diag: f64 = 0.0;
        for i in 0..m {
            max_diag = max_diag.max(r[(i, i)].abs());
        }
        let tol = 1e-12 * max_diag.max(1.0);
        let mut rank = 0;
        let mut min_diag = f64::INFINITY;
        for i in 0..m {
            let diag = r[(i, i)].abs();
            if diag > tol {
                rank += 1;
                if diag < min_diag {
                    min_diag = diag;
                }
            }
        }
        let cond_est = if rank == 0 || !min_diag.is_finite() {
            f64::INFINITY
        } else {
            max_diag / min_diag
        };

        // Solve R^T y = b (R is m x m upper triangular, so R^T is lower triangular).
        let mut y = vec![0.0; m];
        for i in 0..m {
            let mut sum = b[(i, 0)];
            for j in 0..i {
                sum -= r[(j, i)] * y[j];
            }

            let diag = r[(i, i)];
            if !diag.is_finite() {
                return Err("Least squares solve failed: non-finite diagonal in R".to_string());
            }

            let mut col_norm: f64 = 0.0;
            for j in 0..=i {
                col_norm = col_norm.max(r[(j, i)].abs());
            }

            if diag.abs() <= 1e-12 * col_norm {
                return Err("Least squares solve failed: matrix is rank deficient".to_string());
            }

            y[i] = sum / diag;
        }

        // Form w = [y; 0] in R^n.
        let mut w = vec![0.0; n];
        w[..m].copy_from_slice(&y[..m]);

        // Apply Q to w: Q = H0 H1 ... H_{m-1}.
        // To compute Q w, apply reflectors in reverse order.
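        // Each application is w -= tau_k * v_k * (v_k^T w), with v_k read from
        // column k of `r` (implicit leading 1, stored entries below row k).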
        for k in (0..m).rev() {
            let tau = taus[k];
            if tau == 0.0 {
                continue;
            }

            let mut dot = w[k];
            for i in (k + 1)..n {
                dot += r[(i, k)] * w[i];
            }
            dot *= tau;

            w[k] -= dot;
            for i in (k + 1)..n {
                w[i] -= r[(i, k)] * dot;
            }
        }

        Matrix::from_vec(w, n, 1).map(|solution| {
            (
                solution,
                LeastSquaresQrInfo {
                    rank,
                    cond_est,
                },
            )
        })
    }

    /// Frobenius norm: the square root of the sum of squared entries.
    pub fn norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    pub fn try_add(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
        if self.rows != rhs.rows || self.cols != rhs.cols {
            return Err(MatrixError::DimensionMismatch {
                operation: "add",
                left: (self.rows, self.cols),
                right: (rhs.rows, rhs.cols),
            });
        }

        let mut result = Matrix::new(self.rows, self.cols);
        for i in 0..self.data.len() {
            result.data[i] = self.data[i] + rhs.data[i];
        }
        Ok(result)
    }

    pub fn try_sub(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
        if self.rows != rhs.rows || self.cols != rhs.cols {
            return Err(MatrixError::DimensionMismatch {
                operation: "sub",
                left: (self.rows, self.cols),
                right: (rhs.rows, rhs.cols),
            });
        }

        let mut result = Matrix::new(self.rows, self.cols);
        for i in 0..self.data.len() {
            result.data[i] = self.data[i] - rhs.data[i];
        }
        Ok(result)
    }

    pub fn try_mul(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
        if self.cols != rhs.rows {
            return Err(MatrixError::DimensionMismatch {
                operation: "mul",
                left: (self.rows, self.cols),
                right: (rhs.rows, rhs.cols),
            });
        }

        let mut result = Matrix::new(self.rows, rhs.cols);
        for i in 0..self.rows {
            for j in 0..rhs.cols {
                for k in 0..self.cols {
                    result[(i, j)] += self[(i, k)] * rhs[(k, j)];
                }
            }
        }
        Ok(result)
    }
}

impl Index<(usize, usize)> for Matrix {
    type Output = f64;

    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        &self.data[row * self.cols + col]
    }
}

impl IndexMut<(usize, usize)> for Matrix {
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        &mut self.data[row * self.cols + col]
    }
}

impl Add for &Matrix {
    type Output = Matrix;

    fn add(self, rhs: Self) -> Self::Output {
        self.try_add(rhs).unwrap_or_else(|err| panic!("{}", err))
    }
}

impl Sub for &Matrix {
    type Output = Matrix;

    fn sub(self, rhs: Self) -> Self::Output {
        self.try_sub(rhs).unwrap_or_else(|err| panic!("{}", err))
    }
}

impl Mul for &Matrix {
    type Output = Matrix;

    fn mul(self, rhs: Self) -> Self::Output {
        self.try_mul(rhs).unwrap_or_else(|err| panic!("{}", err))
    }
}

impl Mul<f64> for &Matrix {
    type Output = Matrix;

    fn mul(self, scalar: f64) -> Self::Output {
        let mut result = self.clone();
        for val in &mut result.data {
            *val *= scalar;
        }
        result
    }
}

impl fmt::Display for Matrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for i in 0..self.rows {
            write!(f, "[")?;
            for j in 0..self.cols {
                if j > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{:8.4}", self[(i, j)])?;
            }
            writeln!(f, "]")?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestRng {
        state: u64,
    }

    impl TestRng {
        fn new(seed: u64) -> Self {
            Self { state: seed }
        }

        fn next_u32(&mut self) -> u32 {
            self.state = self
                .state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1);
            (self.state >> 32) as u32
        }

        fn next_f64(&mut self) -> f64 {
            let v = self.next_u32() as f64 / u32::MAX as f64;
            2.0 * v - 1.0
        }
    }

    fn random_matrix(rows: usize, cols: usize, rng: &mut TestRng) -> Matrix {
        let mut m = Matrix::new(rows, cols);
        for i in 0..rows {
            for j in 0..cols {
                m[(i, j)] = rng.next_f64();
            }
        }
        m
    }

    #[test]
    fn test_matrix_creation() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m[(0, 0)], 0.0);
    }

    #[test]
    fn test_identity() {
        let m = Matrix::identity(3);
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 1.0);
        assert_eq!(m[(2, 2)], 1.0);
        assert_eq!(m[(0, 1)], 0.0);
    }

    #[test]
    fn test_transpose() {
        let mut m = Matrix::new(2, 3);
        m[(0, 0)] = 1.0;
        m[(0, 1)] = 2.0;
        m[(0, 2)] = 3.0;
        m[(1, 0)] = 4.0;
        m[(1, 1)] = 5.0;
        m[(1, 2)] = 6.0;

        let mt = m.transpose();
        assert_eq!(mt.rows(), 3);
        assert_eq!(mt.cols(), 2);
        assert_eq!(mt[(0, 0)], 1.0);
        assert_eq!(mt[(1, 0)], 2.0);
        assert_eq!(mt[(2, 0)], 3.0);
        assert_eq!(mt[(0, 1)], 4.0);
    }

    #[test]
    fn test_matrix_multiplication() {
        let mut a = Matrix::new(2, 3);
        a[(0, 0)] = 1.0;
        a[(0, 1)] = 2.0;
        a[(0, 2)] = 3.0;
        a[(1, 0)] = 4.0;
        a[(1, 1)] = 5.0;
        a[(1, 2)] = 6.0;

        let mut b = Matrix::new(3, 2);
        b[(0, 0)] = 7.0;
        b[(0, 1)] = 8.0;
        b[(1, 0)] = 9.0;
        b[(1, 1)] = 10.0;
        b[(2, 0)] = 11.0;
        b[(2, 1)] = 12.0;

        let c = &a * &b;
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 2);
        assert_eq!(c[(0, 0)], 58.0);
        assert_eq!(c[(0, 1)], 64.0);
        assert_eq!(c[(1, 0)], 139.0);
        assert_eq!(c[(1, 1)], 154.0);
    }

    #[test]
    fn test_lu_solve() {
        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();

        let x = a.solve_lu(&b).unwrap();

        let verify = &a * &x;

        assert!((verify[(0, 0)] - b[(0, 0)]).abs() < 1e-10);
        assert!((verify[(1, 0)] - b[(1, 0)]).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_is_scale_invariant() {
        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();

        let x = a.solve_lu(&b).unwrap();

        let scale = 1e-12;
        let a_scaled = &a * scale;
        let b_scaled = &b * scale;

        let x_scaled = a_scaled.solve_lu(&b_scaled).unwrap();
        assert!((x_scaled[(0, 0)] - x[(0, 0)]).abs() < 1e-8);
        assert!((x_scaled[(1, 0)] - x[(1, 0)]).abs() < 1e-8);
    }

    #[test]
    fn test_try_add_dimension_mismatch() {
        let a = Matrix::new(2, 2);
        let b = Matrix::new(2, 3);

        let err = a.try_add(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "add",
                left: (2, 2),
                right: (2, 3),
            }
        );
    }

    #[test]
    fn test_try_sub_dimension_mismatch() {
        let a = Matrix::new(3, 2);
        let b = Matrix::new(2, 2);

        let err = a.try_sub(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "sub",
                left: (3, 2),
                right: (2, 2),
            }
        );
    }

    #[test]
    fn test_try_mul_dimension_mismatch() {
        let a = Matrix::new(2, 3);
        let b = Matrix::new(2, 2);

        let err = a.try_mul(&b).expect_err("expected dimension mismatch");
        assert_eq!(
            err,
            MatrixError::DimensionMismatch {
                operation: "mul",
                left: (2, 3),
                right: (2, 2),
            }
        );
    }

    #[test]
    fn test_solve_least_squares_qr_overdetermined_exact() {
        // A is 3x2, consistent system with exact solution x=[1,2].
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let x = a.solve_least_squares_qr(&b).unwrap();
        assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_underdetermined_min_norm() {
        // A is 1x2: x0 + x1 = 1. The minimum-norm solution is [0.5, 0.5].
        let a = Matrix::from_vec(vec![1.0, 1.0], 1, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0], 1, 1).unwrap();

        let x = a.solve_least_squares_qr(&b).unwrap();
        assert!((x[(0, 0)] - 0.5).abs() < 1e-12);
        assert!((x[(1, 0)] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_random_tall_full_rank() {
        let mut rng = TestRng::new(0x5eed_1234_5678_9abc);
        for _ in 0..5 {
            let mut a = random_matrix(6, 3, &mut rng);
            for i in 0..3 {
                a[(i, i)] += 2.0;
            }
            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
            let x_mat = Matrix::from_vec(x_true.clone(), 3, 1).unwrap();
            let b = &a * &x_mat;

            let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
            assert_eq!(info.rank, 3);
            assert!(info.cond_est.is_finite());
            for i in 0..3 {
                assert!((x[(i, 0)] - x_true[i]).abs() < 1e-8);
            }
        }
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_random_wide_min_norm() {
        let mut rng = TestRng::new(0x1234_5678_9abc_def0);
        let mut a = Matrix::new(2, 4);
        a[(0, 0)] = 1.0;
        a[(1, 1)] = 1.0;

        for _ in 0..5 {
            let b0 = rng.next_f64();
            let b1 = rng.next_f64();
            let b = Matrix::from_vec(vec![b0, b1], 2, 1).unwrap();

            let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
            assert_eq!(info.rank, 2);
            assert!(info.cond_est.is_finite());
            assert!((x[(0, 0)] - b0).abs() < 1e-12);
            assert!((x[(1, 0)] - b1).abs() < 1e-12);
            assert!((x[(2, 0)]).abs() < 1e-12);
            assert!((x[(3, 0)]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_detects_ill_conditioning() {
        let mut a = Matrix::identity(3);
        for i in 0..3 {
            a[(i, 2)] *= 1e-8;
        }
        let b = Matrix::from_vec(vec![0.0, 0.0, 0.0], 3, 1).unwrap();

        let (_x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(info.rank, 3);
        assert!(info.cond_est > 1e6, "cond_est was {}", info.cond_est);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_tall_full_rank() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();

        let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(info.rank, 2);
        assert!(info.cond_est.is_finite());
        assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_wide_full_rank() {
        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();

        let (x, info) = a.solve_least_squares_qr_with_info(&b).unwrap();
        assert_eq!(info.rank, 2);
        assert!(info.cond_est.is_finite());
        assert!((x[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((x[(1, 0)] - 2.0).abs() < 1e-12);
        assert!((x[(2, 0)] - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_solve_least_squares_qr_with_info_rank_deficient() {
        let a = Matrix::from_vec(vec![1.0, 1.0, 2.0, 2.0], 2, 2).unwrap();
        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();

        let err = a
            .solve_least_squares_qr_with_info(&b)
            .expect_err("expected rank-deficient QR solve to fail");
        assert!(err.contains("rank deficient"), "{err}");
    }
}