Skip to main content

constraint_solver/
matrix.rs

/*
MIT License

Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
24
25use std::fmt;
26use std::ops::{Add, Index, IndexMut, Mul, Sub};
27
28use rayon::prelude::*;
29
/// Element count (`rows * cols`) from which the parallel code paths may be taken.
const PARALLEL_THRESHOLD: usize = 16_384;

/// Returns `true` when a `rows` x `cols` workload has at least
/// [`PARALLEL_THRESHOLD`] elements.
///
/// A product that does not fit in `usize` is treated as large enough.
fn should_parallelize(rows: usize, cols: usize) -> bool {
    match rows.checked_mul(cols) {
        Some(elements) => elements >= PARALLEL_THRESHOLD,
        None => true,
    }
}
35
/// Errors reported by the fallible matrix arithmetic operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixError {
    /// The operand shapes are incompatible for the requested operation.
    DimensionMismatch {
        /// Short name of the operation that was attempted (for example `"add"`).
        operation: &'static str,
        /// `(rows, cols)` of the left-hand operand.
        left: (usize, usize),
        /// `(rows, cols)` of the right-hand operand.
        right: (usize, usize),
    },
}
44
/// Diagnostics returned alongside a QR least-squares solution.
#[derive(Debug, Clone, Copy)]
pub struct LeastSquaresQrInfo {
    /// Number of diagonal entries of `R` whose magnitude exceeds the solver's tolerance.
    pub rank: usize,
    /// Ratio of the largest diagonal magnitude of `R` to the smallest one above
    /// the tolerance; `f64::INFINITY` when no entry exceeds the tolerance.
    pub cond_est: f64,
}
50
51impl fmt::Display for MatrixError {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        match self {
54            MatrixError::DimensionMismatch {
55                operation,
56                left,
57                right,
58            } => write!(
59                f,
60                "Matrix dimension mismatch for {}: left is {}x{}, right is {}x{}",
61                operation, left.0, left.1, right.0, right.1
62            ),
63        }
64    }
65}
66
67impl std::error::Error for MatrixError {}
68
/// Dense `f64` matrix stored in row-major order.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Elements laid out row by row; element `(r, c)` lives at `r * cols + c`.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
75
76impl Matrix {
77    pub fn new(rows: usize, cols: usize) -> Self {
78        Matrix {
79            data: vec![0.0; rows * cols],
80            rows,
81            cols,
82        }
83    }
84
85    pub fn from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
86        if data.len() != rows * cols {
87            return Err("Data length doesn't match dimensions".to_string());
88        }
89        Ok(Matrix { data, rows, cols })
90    }
91
92    pub fn identity(size: usize) -> Self {
93        let mut m = Matrix::new(size, size);
94        for i in 0..size {
95            m[(i, i)] = 1.0;
96        }
97        m
98    }
99
100    pub fn rows(&self) -> usize {
101        self.rows
102    }
103
104    pub fn cols(&self) -> usize {
105        self.cols
106    }
107
108    pub fn transpose(&self) -> Matrix {
109        self.transpose_with_parallel(false)
110    }
111
112    pub fn transpose_with_parallel(&self, parallel: bool) -> Matrix {
113        let mut result = Matrix::new(self.cols, self.rows);
114        if parallel && should_parallelize(self.rows, self.cols) {
115            result
116                .data
117                .par_chunks_mut(self.rows)
118                .enumerate()
119                .for_each(|(j, row)| {
120                    for i in 0..self.rows {
121                        row[i] = self[(i, j)];
122                    }
123                });
124            return result;
125        }
126
127        for i in 0..self.rows {
128            for j in 0..self.cols {
129                result[(j, i)] = self[(i, j)];
130            }
131        }
132        result
133    }
134
135    pub fn lu_decomposition(&self) -> Result<(Matrix, Matrix, Vec<usize>), String> {
136        self.lu_decomposition_with_tolerance(1e-12)
137    }
138
139    pub fn lu_decomposition_with_tolerance(
140        &self,
141        singular_relative_epsilon: f64,
142    ) -> Result<(Matrix, Matrix, Vec<usize>), String> {
143        if self.rows != self.cols {
144            return Err("Matrix must be square for LU decomposition".to_string());
145        }
146        if !singular_relative_epsilon.is_finite() || singular_relative_epsilon < 0.0 {
147            return Err("Invalid singular tolerance".to_string());
148        }
149
150        let n = self.rows;
151        let mut l = Matrix::identity(n);
152        let mut u = self.clone();
153        let mut pivot = (0..n).collect::<Vec<_>>();
154
155        for k in 0..n {
156            let mut max_val = 0.0;
157            let mut max_row = k;
158            for i in k..n {
159                let val = u[(i, k)].abs();
160                if val > max_val {
161                    max_val = val;
162                    max_row = i;
163                }
164            }
165
166            let mut row_norm: f64 = 0.0;
167            for j in k..n {
168                row_norm = row_norm.max(u[(max_row, j)].abs());
169            }
170
171            if max_val <= singular_relative_epsilon * row_norm {
172                return Err("Matrix is singular".to_string());
173            }
174
175            if max_row != k {
176                pivot.swap(k, max_row);
177                for j in 0..n {
178                    let temp = u[(k, j)];
179                    u[(k, j)] = u[(max_row, j)];
180                    u[(max_row, j)] = temp;
181                }
182                for j in 0..k {
183                    let temp = l[(k, j)];
184                    l[(k, j)] = l[(max_row, j)];
185                    l[(max_row, j)] = temp;
186                }
187            }
188
189            for i in (k + 1)..n {
190                l[(i, k)] = u[(i, k)] / u[(k, k)];
191                for j in k..n {
192                    u[(i, j)] -= l[(i, k)] * u[(k, j)];
193                }
194            }
195        }
196
197        Ok((l, u, pivot))
198    }
199
200    pub fn solve_lu(&self, b: &Matrix) -> Result<Matrix, String> {
201        self.solve_lu_with_tolerance(b, 1e-12)
202    }
203
204    pub fn solve_lu_with_tolerance(
205        &self,
206        b: &Matrix,
207        singular_relative_epsilon: f64,
208    ) -> Result<Matrix, String> {
209        if self.rows != self.cols {
210            return Err("Matrix must be square".to_string());
211        }
212        if b.rows != self.rows || b.cols != 1 {
213            return Err("Invalid dimensions for b".to_string());
214        }
215
216        let (l, u, pivot) = self.lu_decomposition_with_tolerance(singular_relative_epsilon)?;
217
218        let mut pb = Matrix::new(b.rows, 1);
219        for i in 0..b.rows {
220            pb[(i, 0)] = b[(pivot[i], 0)];
221        }
222
223        let mut y = Matrix::new(self.rows, 1);
224        for i in 0..self.rows {
225            y[(i, 0)] = pb[(i, 0)];
226            for j in 0..i {
227                y[(i, 0)] -= l[(i, j)] * y[(j, 0)];
228            }
229        }
230
231        let mut x = Matrix::new(self.rows, 1);
232        for i in (0..self.rows).rev() {
233            x[(i, 0)] = y[(i, 0)];
234            for j in (i + 1)..self.rows {
235                x[(i, 0)] -= u[(i, j)] * x[(j, 0)];
236            }
237            x[(i, 0)] /= u[(i, i)];
238        }
239
240        Ok(x)
241    }
242
243    /// Solve a (possibly non-square) least squares problem using Householder QR.
244    ///
245    /// Solves `min_x ||A x - b||_2`, where `A` is `self`.
246    /// - If `rows >= cols`, returns the standard least-squares solution.
247    /// - If `rows < cols`, returns the minimum-norm solution among all minimizers.
248    pub fn solve_least_squares_qr(&self, b: &Matrix) -> Result<Matrix, String> {
249        self.solve_least_squares_qr_with_info(b)
250            .map(|(solution, _)| solution)
251    }
252
253    pub fn solve_least_squares_qr_with_parallel(
254        &self,
255        b: &Matrix,
256        parallel: bool,
257    ) -> Result<Matrix, String> {
258        self.solve_least_squares_qr_with_info_with_parallel(b, parallel)
259            .map(|(solution, _)| solution)
260    }
261
262    /// Solve a (possibly non-square) least squares problem using Householder QR,
263    /// returning diagnostic information about rank and conditioning.
264    pub fn solve_least_squares_qr_with_info(
265        &self,
266        b: &Matrix,
267    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
268        self.solve_least_squares_qr_with_info_with_parallel(b, false)
269    }
270
271    pub fn solve_least_squares_qr_with_info_with_parallel(
272        &self,
273        b: &Matrix,
274        parallel: bool,
275    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
276        if b.cols != 1 {
277            return Err("Invalid dimensions for b (expected column vector)".to_string());
278        }
279        if b.rows != self.rows {
280            return Err("Invalid dimensions for b".to_string());
281        }
282
283        let m = self.rows;
284        let n = self.cols;
285
286        // Trivial cases.
287        if n == 0 {
288            return Ok((
289                Matrix::new(0, 1),
290                LeastSquaresQrInfo {
291                    rank: 0,
292                    cond_est: f64::INFINITY,
293                },
294            ));
295        }
296
297        if m >= n {
298            Self::solve_least_squares_qr_tall_with_info(self, b, parallel)
299        } else {
300            Self::solve_least_squares_qr_wide_with_info(self, b, parallel)
301        }
302    }
303
304    fn solve_least_squares_qr_tall_with_info(
305        a: &Matrix,
306        b: &Matrix,
307        parallel: bool,
308    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
309        let m = a.rows;
310        let n = a.cols;
311
312        let mut r = a.clone();
313        let mut qt_b = b.clone();
314        let mut taus = Vec::with_capacity(n);
315
316        // Householder QR factorization (thin): apply reflectors to `r` and `qt_b`.
317        for k in 0..n {
318            let mut col_norm: f64 = 0.0;
319            for i in k..m {
320                col_norm = col_norm.hypot(r[(i, k)]);
321            }
322
323            if col_norm == 0.0 {
324                taus.push(0.0);
325                continue;
326            }
327
328            let x0 = r[(k, k)];
329            let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
330            let alpha = -sign * col_norm;
331            let v0 = x0 - alpha;
332
333            // Store Householder vector v in-place in column k (v0 is implicit 1.0).
334            for i in (k + 1)..m {
335                r[(i, k)] /= v0;
336            }
337
338            let mut v_sq = 1.0;
339            for i in (k + 1)..m {
340                let vi = r[(i, k)];
341                v_sq += vi * vi;
342            }
343            let tau = 2.0 / v_sq;
344            taus.push(tau);
345
346            // Apply reflector to remaining columns.
347            let cols_left = n.saturating_sub(k + 1);
348            if cols_left > 0 {
349                let use_parallel = parallel && should_parallelize(m - (k + 1), cols_left);
350                let dots: Vec<f64> = if use_parallel {
351                    (0..cols_left)
352                        .into_par_iter()
353                        .map(|offset| {
354                            let j = k + 1 + offset;
355                            let mut dot = r[(k, j)];
356                            for i in (k + 1)..m {
357                                dot += r[(i, k)] * r[(i, j)];
358                            }
359                            dot * tau
360                        })
361                        .collect()
362                } else {
363                    let mut dots = Vec::with_capacity(cols_left);
364                    for j in (k + 1)..n {
365                        let mut dot = r[(k, j)];
366                        for i in (k + 1)..m {
367                            dot += r[(i, k)] * r[(i, j)];
368                        }
369                        dots.push(dot * tau);
370                    }
371                    dots
372                };
373
374                for (offset, dot) in dots.iter().enumerate() {
375                    let j = k + 1 + offset;
376                    r[(k, j)] -= dot;
377                }
378
379                if use_parallel {
380                    let cols = r.cols;
381                    let j_start = k + 1;
382                    let k_col = k;
383                    r.data[(k + 1) * cols..]
384                        .par_chunks_mut(cols)
385                        .for_each(|row| {
386                            let vik = row[k_col];
387                            if vik != 0.0 {
388                                for (offset, dot) in dots.iter().enumerate() {
389                                    row[j_start + offset] -= vik * dot;
390                                }
391                            }
392                        });
393                } else {
394                    for i in (k + 1)..m {
395                        let vik = r[(i, k)];
396                        if vik != 0.0 {
397                            for (offset, dot) in dots.iter().enumerate() {
398                                let j = k + 1 + offset;
399                                r[(i, j)] -= vik * dot;
400                            }
401                        }
402                    }
403                }
404            }
405
406            // Apply reflector to b: qt_b = Q^T b
407            let mut dot = qt_b[(k, 0)];
408            for i in (k + 1)..m {
409                dot += r[(i, k)] * qt_b[(i, 0)];
410            }
411            dot *= tau;
412            qt_b[(k, 0)] -= dot;
413
414            let use_parallel = parallel && should_parallelize(m - (k + 1), 1);
415            if use_parallel {
416                let cols = r.cols;
417                let k_col = k;
418                let r_data = &r.data;
419                qt_b.data[(k + 1)..]
420                    .par_iter_mut()
421                    .enumerate()
422                    .for_each(|(idx, val)| {
423                        let i = k + 1 + idx;
424                        let vik = r_data[i * cols + k_col];
425                        *val -= vik * dot;
426                    });
427            } else {
428                for i in (k + 1)..m {
429                    qt_b[(i, 0)] -= r[(i, k)] * dot;
430                }
431            }
432
433            r[(k, k)] = alpha;
434        }
435
436        let mut max_diag: f64 = 0.0;
437        for i in 0..n {
438            max_diag = max_diag.max(r[(i, i)].abs());
439        }
440        let tol = 1e-12 * max_diag.max(1.0);
441        let mut rank = 0;
442        let mut min_diag = f64::INFINITY;
443        for i in 0..n {
444            let diag = r[(i, i)].abs();
445            if diag > tol {
446                rank += 1;
447                if diag < min_diag {
448                    min_diag = diag;
449                }
450            }
451        }
452        let cond_est = if rank == 0 || !min_diag.is_finite() {
453            f64::INFINITY
454        } else {
455            max_diag / min_diag
456        };
457
458        // Back-substitute R x = Q^T b, using the top n entries.
459        let mut x = Matrix::new(n, 1);
460        for i in (0..n).rev() {
461            let mut sum = qt_b[(i, 0)];
462            for j in (i + 1)..n {
463                sum -= r[(i, j)] * x[(j, 0)];
464            }
465
466            let diag = r[(i, i)];
467            if !diag.is_finite() {
468                return Err("Least squares solve failed: non-finite diagonal in R".to_string());
469            }
470
471            let mut row_norm: f64 = 0.0;
472            for j in i..n {
473                row_norm = row_norm.max(r[(i, j)].abs());
474            }
475
476            if diag.abs() <= 1e-12 * row_norm {
477                return Err("Least squares solve failed: matrix is rank deficient".to_string());
478            }
479
480            x[(i, 0)] = sum / diag;
481        }
482
483        Ok((
484            x,
485            LeastSquaresQrInfo {
486                rank,
487                cond_est,
488            },
489        ))
490    }
491
492    fn solve_least_squares_qr_wide_with_info(
493        a: &Matrix,
494        b: &Matrix,
495        parallel: bool,
496    ) -> Result<(Matrix, LeastSquaresQrInfo), String> {
497        // Use QR on A^T to compute the minimum-norm least squares solution.
498        //
499        // Let A be m x n with m < n. Compute QR of A^T (n x m):
500        // A^T = Q R, where Q is n x n orthogonal and R has top-left m x m upper-triangular.
501        // The minimum-norm solution is x = Q [y; 0], where y solves R^T y = b.
502        let m = a.rows;
503        let n = a.cols;
504
505        let mut r = a.transpose(); // n x m, tall
506        let mut taus = Vec::with_capacity(m);
507
508        // Householder QR on r (n x m), producing R in the top m x m block.
509        for k in 0..m {
510            let mut col_norm: f64 = 0.0;
511            for i in k..n {
512                col_norm = col_norm.hypot(r[(i, k)]);
513            }
514
515            if col_norm == 0.0 {
516                taus.push(0.0);
517                continue;
518            }
519
520            let x0 = r[(k, k)];
521            let sign = if x0 >= 0.0 { 1.0 } else { -1.0 };
522            let alpha = -sign * col_norm;
523            let v0 = x0 - alpha;
524
525            for i in (k + 1)..n {
526                r[(i, k)] /= v0;
527            }
528
529            let mut v_sq = 1.0;
530            for i in (k + 1)..n {
531                let vi = r[(i, k)];
532                v_sq += vi * vi;
533            }
534            let tau = 2.0 / v_sq;
535            taus.push(tau);
536
537            let cols_left = m.saturating_sub(k + 1);
538            if cols_left > 0 {
539                let use_parallel = parallel && should_parallelize(n - (k + 1), cols_left);
540                let dots: Vec<f64> = if use_parallel {
541                    (0..cols_left)
542                        .into_par_iter()
543                        .map(|offset| {
544                            let j = k + 1 + offset;
545                            let mut dot = r[(k, j)];
546                            for i in (k + 1)..n {
547                                dot += r[(i, k)] * r[(i, j)];
548                            }
549                            dot * tau
550                        })
551                        .collect()
552                } else {
553                    let mut dots = Vec::with_capacity(cols_left);
554                    for j in (k + 1)..m {
555                        let mut dot = r[(k, j)];
556                        for i in (k + 1)..n {
557                            dot += r[(i, k)] * r[(i, j)];
558                        }
559                        dots.push(dot * tau);
560                    }
561                    dots
562                };
563
564                for (offset, dot) in dots.iter().enumerate() {
565                    let j = k + 1 + offset;
566                    r[(k, j)] -= dot;
567                }
568
569                if use_parallel {
570                    let cols = r.cols;
571                    let j_start = k + 1;
572                    let k_col = k;
573                    r.data[(k + 1) * cols..]
574                        .par_chunks_mut(cols)
575                        .for_each(|row| {
576                            let vik = row[k_col];
577                            if vik != 0.0 {
578                                for (offset, dot) in dots.iter().enumerate() {
579                                    row[j_start + offset] -= vik * dot;
580                                }
581                            }
582                        });
583                } else {
584                    for i in (k + 1)..n {
585                        let vik = r[(i, k)];
586                        if vik != 0.0 {
587                            for (offset, dot) in dots.iter().enumerate() {
588                                let j = k + 1 + offset;
589                                r[(i, j)] -= vik * dot;
590                            }
591                        }
592                    }
593                }
594            }
595
596            r[(k, k)] = alpha;
597        }
598
599        let mut max_diag: f64 = 0.0;
600        for i in 0..m {
601            max_diag = max_diag.max(r[(i, i)].abs());
602        }
603        let tol = 1e-12 * max_diag.max(1.0);
604        let mut rank = 0;
605        let mut min_diag = f64::INFINITY;
606        for i in 0..m {
607            let diag = r[(i, i)].abs();
608            if diag > tol {
609                rank += 1;
610                if diag < min_diag {
611                    min_diag = diag;
612                }
613            }
614        }
615        let cond_est = if rank == 0 || !min_diag.is_finite() {
616            f64::INFINITY
617        } else {
618            max_diag / min_diag
619        };
620
621        // Solve R^T y = b (R is m x m upper triangular, so R^T is lower triangular).
622        let mut y = vec![0.0; m];
623        for i in 0..m {
624            let mut sum = b[(i, 0)];
625            for j in 0..i {
626                sum -= r[(j, i)] * y[j];
627            }
628
629            let diag = r[(i, i)];
630            if !diag.is_finite() {
631                return Err("Least squares solve failed: non-finite diagonal in R".to_string());
632            }
633
634            let mut col_norm: f64 = 0.0;
635            for j in 0..=i {
636                col_norm = col_norm.max(r[(j, i)].abs());
637            }
638
639            if diag.abs() <= 1e-12 * col_norm {
640                return Err("Least squares solve failed: matrix is rank deficient".to_string());
641            }
642
643            y[i] = sum / diag;
644        }
645
646        // Form w = [y; 0] in R^n.
647        let mut w = vec![0.0; n];
648        w[..m].copy_from_slice(&y[..m]);
649
650        // Apply Q to w: Q = H0 H1 ... H_{m-1}.
651        // To compute Q w, apply reflectors in reverse order.
652        for k in (0..m).rev() {
653            let tau = taus[k];
654            if tau == 0.0 {
655                continue;
656            }
657
658            let mut dot = w[k];
659            for i in (k + 1)..n {
660                dot += r[(i, k)] * w[i];
661            }
662            dot *= tau;
663
664            w[k] -= dot;
665
666            let use_parallel = parallel && should_parallelize(n - (k + 1), 1);
667            if use_parallel {
668                let cols = r.cols;
669                let k_col = k;
670                let r_data = &r.data;
671                w[(k + 1)..]
672                    .par_iter_mut()
673                    .enumerate()
674                    .for_each(|(idx, val)| {
675                        let i = k + 1 + idx;
676                        let vik = r_data[i * cols + k_col];
677                        *val -= vik * dot;
678                    });
679            } else {
680                for i in (k + 1)..n {
681                    w[i] -= r[(i, k)] * dot;
682                }
683            }
684        }
685
686        Matrix::from_vec(w, n, 1).map(|solution| {
687            (
688                solution,
689                LeastSquaresQrInfo {
690                    rank,
691                    cond_est,
692                },
693            )
694        })
695    }
696
697    pub fn norm(&self) -> f64 {
698        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
699    }
700
701    pub fn try_add(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
702        if self.rows != rhs.rows || self.cols != rhs.cols {
703            return Err(MatrixError::DimensionMismatch {
704                operation: "add",
705                left: (self.rows, self.cols),
706                right: (rhs.rows, rhs.cols),
707            });
708        }
709
710        let mut result = Matrix::new(self.rows, self.cols);
711        for i in 0..self.data.len() {
712            result.data[i] = self.data[i] + rhs.data[i];
713        }
714        Ok(result)
715    }
716
717    pub fn try_sub(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
718        if self.rows != rhs.rows || self.cols != rhs.cols {
719            return Err(MatrixError::DimensionMismatch {
720                operation: "sub",
721                left: (self.rows, self.cols),
722                right: (rhs.rows, rhs.cols),
723            });
724        }
725
726        let mut result = Matrix::new(self.rows, self.cols);
727        for i in 0..self.data.len() {
728            result.data[i] = self.data[i] - rhs.data[i];
729        }
730        Ok(result)
731    }
732
733    pub fn try_mul(&self, rhs: &Matrix) -> Result<Matrix, MatrixError> {
734        self.try_mul_with_parallel(rhs, false)
735    }
736
737    pub fn try_mul_with_parallel(&self, rhs: &Matrix, parallel: bool) -> Result<Matrix, MatrixError> {
738        if self.cols != rhs.rows {
739            return Err(MatrixError::DimensionMismatch {
740                operation: "mul",
741                left: (self.rows, self.cols),
742                right: (rhs.rows, rhs.cols),
743            });
744        }
745
746        let mut result = Matrix::new(self.rows, rhs.cols);
747        if parallel && should_parallelize(self.rows, rhs.cols) {
748            let rhs_cols = rhs.cols;
749            result
750                .data
751                .par_chunks_mut(rhs_cols)
752                .enumerate()
753                .for_each(|(i, row)| {
754                    for k in 0..self.cols {
755                        let a = self[(i, k)];
756                        let rhs_row = &rhs.data[k * rhs_cols..(k + 1) * rhs_cols];
757                        for j in 0..rhs_cols {
758                            row[j] += a * rhs_row[j];
759                        }
760                    }
761                });
762            return Ok(result);
763        }
764
765        for i in 0..self.rows {
766            for k in 0..self.cols {
767                let a = self[(i, k)];
768                for j in 0..rhs.cols {
769                    result[(i, j)] += a * rhs[(k, j)];
770                }
771            }
772        }
773        Ok(result)
774    }
775}
776
777impl Index<(usize, usize)> for Matrix {
778    type Output = f64;
779
780    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
781        &self.data[row * self.cols + col]
782    }
783}
784
785impl IndexMut<(usize, usize)> for Matrix {
786    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
787        &mut self.data[row * self.cols + col]
788    }
789}
790
791impl Add for &Matrix {
792    type Output = Matrix;
793
794    fn add(self, rhs: Self) -> Self::Output {
795        self.try_add(rhs).unwrap_or_else(|err| panic!("{}", err))
796    }
797}
798
799impl Sub for &Matrix {
800    type Output = Matrix;
801
802    fn sub(self, rhs: Self) -> Self::Output {
803        self.try_sub(rhs).unwrap_or_else(|err| panic!("{}", err))
804    }
805}
806
807impl Mul for &Matrix {
808    type Output = Matrix;
809
810    fn mul(self, rhs: Self) -> Self::Output {
811        self.try_mul(rhs).unwrap_or_else(|err| panic!("{}", err))
812    }
813}
814
815impl Mul<f64> for &Matrix {
816    type Output = Matrix;
817
818    fn mul(self, scalar: f64) -> Self::Output {
819        let mut result = self.clone();
820        for val in &mut result.data {
821            *val *= scalar;
822        }
823        result
824    }
825}
826
827impl fmt::Display for Matrix {
828    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
829        for i in 0..self.rows {
830            write!(f, "[")?;
831            for j in 0..self.cols {
832                if j > 0 {
833                    write!(f, ", ")?;
834                }
835                write!(f, "{:8.4}", self[(i, j)])?;
836            }
837            writeln!(f, "]")?;
838        }
839        Ok(())
840    }
841}
842
843#[cfg(test)]
844mod tests {
845    use super::*;
846
847    struct TestRng {
848        state: u64,
849    }
850
851    impl TestRng {
852        fn new(seed: u64) -> Self {
853            Self { state: seed }
854        }
855
856        fn next_u32(&mut self) -> u32 {
857            self.state = self
858                .state
859                .wrapping_mul(6364136223846793005)
860                .wrapping_add(1);
861            (self.state >> 32) as u32
862        }
863
864        fn next_f64(&mut self) -> f64 {
865            let v = self.next_u32() as f64 / u32::MAX as f64;
866            2.0 * v - 1.0
867        }
868    }
869
870    fn random_matrix(rows: usize, cols: usize, rng: &mut TestRng) -> Matrix {
871        let mut m = Matrix::new(rows, cols);
872        for i in 0..rows {
873            for j in 0..cols {
874                m[(i, j)] = rng.next_f64();
875            }
876        }
877        m
878    }
879
880    fn assert_matrix_close(a: &Matrix, b: &Matrix, tol: f64) {
881        assert_eq!(a.rows(), b.rows());
882        assert_eq!(a.cols(), b.cols());
883        for i in 0..a.rows() {
884            for j in 0..a.cols() {
885                let diff = (a[(i, j)] - b[(i, j)]).abs();
886                assert!(diff <= tol, "mismatch at ({i},{j}): {diff}");
887            }
888        }
889    }
890
891    #[test]
892    fn test_matrix_creation() {
893        let m = Matrix::new(2, 3);
894        assert_eq!(m.rows(), 2);
895        assert_eq!(m.cols(), 3);
896        assert_eq!(m[(0, 0)], 0.0);
897    }
898
899    #[test]
900    fn test_identity() {
901        let m = Matrix::identity(3);
902        assert_eq!(m[(0, 0)], 1.0);
903        assert_eq!(m[(1, 1)], 1.0);
904        assert_eq!(m[(2, 2)], 1.0);
905        assert_eq!(m[(0, 1)], 0.0);
906    }
907
908    #[test]
909    fn test_transpose() {
910        let mut m = Matrix::new(2, 3);
911        m[(0, 0)] = 1.0;
912        m[(0, 1)] = 2.0;
913        m[(0, 2)] = 3.0;
914        m[(1, 0)] = 4.0;
915        m[(1, 1)] = 5.0;
916        m[(1, 2)] = 6.0;
917
918        let mt_serial = m.transpose_with_parallel(false);
919        let mt_parallel = m.transpose_with_parallel(true);
920        assert_matrix_close(&mt_serial, &mt_parallel, 0.0);
921        assert_eq!(mt_serial.rows(), 3);
922        assert_eq!(mt_serial.cols(), 2);
923        assert_eq!(mt_serial[(0, 0)], 1.0);
924        assert_eq!(mt_serial[(1, 0)], 2.0);
925        assert_eq!(mt_serial[(2, 0)], 3.0);
926        assert_eq!(mt_serial[(0, 1)], 4.0);
927    }
928
929    #[test]
930    fn test_matrix_multiplication() {
931        let mut a = Matrix::new(2, 3);
932        a[(0, 0)] = 1.0;
933        a[(0, 1)] = 2.0;
934        a[(0, 2)] = 3.0;
935        a[(1, 0)] = 4.0;
936        a[(1, 1)] = 5.0;
937        a[(1, 2)] = 6.0;
938
939        let mut b = Matrix::new(3, 2);
940        b[(0, 0)] = 7.0;
941        b[(0, 1)] = 8.0;
942        b[(1, 0)] = 9.0;
943        b[(1, 1)] = 10.0;
944        b[(2, 0)] = 11.0;
945        b[(2, 1)] = 12.0;
946
947        let c_serial = a.try_mul_with_parallel(&b, false).unwrap();
948        let c_parallel = a.try_mul_with_parallel(&b, true).unwrap();
949        assert_matrix_close(&c_serial, &c_parallel, 0.0);
950        assert_eq!(c_serial.rows(), 2);
951        assert_eq!(c_serial.cols(), 2);
952        assert_eq!(c_serial[(0, 0)], 58.0);
953        assert_eq!(c_serial[(0, 1)], 64.0);
954        assert_eq!(c_serial[(1, 0)], 139.0);
955        assert_eq!(c_serial[(1, 1)], 154.0);
956    }
957
958    #[test]
959    fn test_parallel_equivalence_threshold_size() {
960        let size = 128;
961        let mut rng = TestRng::new(0x3e5a_9f21_d00d_cafe);
962
963        let a = random_matrix(size, size, &mut rng);
964        let b = random_matrix(size, size, &mut rng);
965
966        let c_serial = a.try_mul_with_parallel(&b, false).unwrap();
967        let c_parallel = a.try_mul_with_parallel(&b, true).unwrap();
968        assert_matrix_close(&c_serial, &c_parallel, 1e-10);
969
970        let t_serial = a.transpose_with_parallel(false);
971        let t_parallel = a.transpose_with_parallel(true);
972        assert_matrix_close(&t_serial, &t_parallel, 0.0);
973
974        let m = size * 4;
975        let n = size;
976        let tall = random_matrix(m, n, &mut rng);
977        let tall_b = random_matrix(m, 1, &mut rng);
978        let (x_serial, info_serial) =
979            tall.solve_least_squares_qr_with_info_with_parallel(&tall_b, false)
980                .unwrap();
981        let (x_parallel, info_parallel) =
982            tall.solve_least_squares_qr_with_info_with_parallel(&tall_b, true)
983                .unwrap();
984        assert_eq!(info_serial.rank, info_parallel.rank);
985        assert!(info_serial.cond_est.is_finite());
986        assert!(info_parallel.cond_est.is_finite());
987        assert_matrix_close(&x_serial, &x_parallel, 1e-7);
988
989        let wide = random_matrix(n, m, &mut rng);
990        let wide_b = random_matrix(n, 1, &mut rng);
991        let (x_serial, info_serial) =
992            wide.solve_least_squares_qr_with_info_with_parallel(&wide_b, false)
993                .unwrap();
994        let (x_parallel, info_parallel) =
995            wide.solve_least_squares_qr_with_info_with_parallel(&wide_b, true)
996                .unwrap();
997        assert_eq!(info_serial.rank, info_parallel.rank);
998        assert!(info_serial.cond_est.is_finite());
999        assert!(info_parallel.cond_est.is_finite());
1000        assert_matrix_close(&x_serial, &x_parallel, 1e-7);
1001    }
1002
1003    #[test]
1004    fn test_lu_solve() {
1005        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
1006        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();
1007
1008        let x = a.solve_lu(&b).unwrap();
1009
1010        let verify_serial = a.try_mul_with_parallel(&x, false).unwrap();
1011        let verify_parallel = a.try_mul_with_parallel(&x, true).unwrap();
1012        assert_matrix_close(&verify_serial, &verify_parallel, 0.0);
1013
1014        assert!((verify_serial[(0, 0)] - b[(0, 0)]).abs() < 1e-10);
1015        assert!((verify_serial[(1, 0)] - b[(1, 0)]).abs() < 1e-10);
1016    }
1017
1018    #[test]
1019    fn test_lu_solve_is_scale_invariant() {
1020        let a = Matrix::from_vec(vec![2.0, 1.0, 3.0, 4.0], 2, 2).unwrap();
1021        let b = Matrix::from_vec(vec![5.0, 11.0], 2, 1).unwrap();
1022
1023        let x = a.solve_lu(&b).unwrap();
1024
1025        let scale = 1e-12;
1026        let a_scaled = &a * scale;
1027        let b_scaled = &b * scale;
1028
1029        let x_scaled = a_scaled.solve_lu(&b_scaled).unwrap();
1030        assert!((x_scaled[(0, 0)] - x[(0, 0)]).abs() < 1e-8);
1031        assert!((x_scaled[(1, 0)] - x[(1, 0)]).abs() < 1e-8);
1032    }
1033
1034    #[test]
1035    fn test_try_add_dimension_mismatch() {
1036        let a = Matrix::new(2, 2);
1037        let b = Matrix::new(2, 3);
1038
1039        let err = a.try_add(&b).expect_err("expected dimension mismatch");
1040        assert_eq!(
1041            err,
1042            MatrixError::DimensionMismatch {
1043                operation: "add",
1044                left: (2, 2),
1045                right: (2, 3),
1046            }
1047        );
1048    }
1049
1050    #[test]
1051    fn test_try_sub_dimension_mismatch() {
1052        let a = Matrix::new(3, 2);
1053        let b = Matrix::new(2, 2);
1054
1055        let err = a.try_sub(&b).expect_err("expected dimension mismatch");
1056        assert_eq!(
1057            err,
1058            MatrixError::DimensionMismatch {
1059                operation: "sub",
1060                left: (3, 2),
1061                right: (2, 2),
1062            }
1063        );
1064    }
1065
1066    #[test]
1067    fn test_try_mul_dimension_mismatch() {
1068        let a = Matrix::new(2, 3);
1069        let b = Matrix::new(2, 2);
1070
1071        let err_serial = a
1072            .try_mul_with_parallel(&b, false)
1073            .expect_err("expected dimension mismatch");
1074        let err_parallel = a
1075            .try_mul_with_parallel(&b, true)
1076            .expect_err("expected dimension mismatch");
1077        assert_eq!(
1078            err_serial,
1079            MatrixError::DimensionMismatch {
1080                operation: "mul",
1081                left: (2, 3),
1082                right: (2, 2),
1083            }
1084        );
1085        assert_eq!(err_serial, err_parallel);
1086    }
1087
1088    #[test]
1089    fn test_solve_least_squares_qr_overdetermined_exact() {
1090        // A is 3x2, consistent system with exact solution x=[1,2].
1091        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
1092        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
1093
1094        let x_serial = a.solve_least_squares_qr_with_parallel(&b, false).unwrap();
1095        let x_parallel = a.solve_least_squares_qr_with_parallel(&b, true).unwrap();
1096        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
1097        assert!((x_serial[(0, 0)] - 1.0).abs() < 1e-12);
1098        assert!((x_serial[(1, 0)] - 2.0).abs() < 1e-12);
1099    }
1100
1101    #[test]
1102    fn test_solve_least_squares_qr_underdetermined_min_norm() {
1103        // A is 1x2: x0 + x1 = 1. The minimum-norm solution is [0.5, 0.5].
1104        let a = Matrix::from_vec(vec![1.0, 1.0], 1, 2).unwrap();
1105        let b = Matrix::from_vec(vec![1.0], 1, 1).unwrap();
1106
1107        let x_serial = a.solve_least_squares_qr_with_parallel(&b, false).unwrap();
1108        let x_parallel = a.solve_least_squares_qr_with_parallel(&b, true).unwrap();
1109        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
1110        assert!((x_serial[(0, 0)] - 0.5).abs() < 1e-12);
1111        assert!((x_serial[(1, 0)] - 0.5).abs() < 1e-12);
1112    }
1113
1114    #[test]
1115    fn test_solve_least_squares_qr_with_info_random_tall_full_rank() {
1116        let mut rng = TestRng::new(0x5eed_1234_5678_9abc);
1117        for _ in 0..5 {
1118            let mut a = random_matrix(6, 3, &mut rng);
1119            for i in 0..3 {
1120                a[(i, i)] += 2.0;
1121            }
1122            let x_true = vec![rng.next_f64(), rng.next_f64(), rng.next_f64()];
1123            let x_mat = Matrix::from_vec(x_true.clone(), 3, 1).unwrap();
1124            let b = &a * &x_mat;
1125
1126            let (x_serial, info_serial) =
1127                a.solve_least_squares_qr_with_info_with_parallel(&b, false)
1128                    .unwrap();
1129            let (x_parallel, info_parallel) =
1130                a.solve_least_squares_qr_with_info_with_parallel(&b, true)
1131                    .unwrap();
1132            assert_eq!(info_serial.rank, 3);
1133            assert_eq!(info_parallel.rank, 3);
1134            assert!(info_serial.cond_est.is_finite());
1135            assert!(info_parallel.cond_est.is_finite());
1136            assert_matrix_close(&x_serial, &x_parallel, 1e-8);
1137            for i in 0..3 {
1138                assert!((x_serial[(i, 0)] - x_true[i]).abs() < 1e-8);
1139                assert!((x_parallel[(i, 0)] - x_true[i]).abs() < 1e-8);
1140            }
1141        }
1142    }
1143
1144    #[test]
1145    fn test_solve_least_squares_qr_with_info_random_wide_min_norm() {
1146        let mut rng = TestRng::new(0x1234_5678_9abc_def0);
1147        let mut a = Matrix::new(2, 4);
1148        a[(0, 0)] = 1.0;
1149        a[(1, 1)] = 1.0;
1150
1151        for _ in 0..5 {
1152            let b0 = rng.next_f64();
1153            let b1 = rng.next_f64();
1154            let b = Matrix::from_vec(vec![b0, b1], 2, 1).unwrap();
1155
1156            let (x_serial, info_serial) =
1157                a.solve_least_squares_qr_with_info_with_parallel(&b, false)
1158                    .unwrap();
1159            let (x_parallel, info_parallel) =
1160                a.solve_least_squares_qr_with_info_with_parallel(&b, true)
1161                    .unwrap();
1162            assert_eq!(info_serial.rank, 2);
1163            assert_eq!(info_parallel.rank, 2);
1164            assert!(info_serial.cond_est.is_finite());
1165            assert!(info_parallel.cond_est.is_finite());
1166            assert_matrix_close(&x_serial, &x_parallel, 1e-12);
1167            assert!((x_serial[(0, 0)] - b0).abs() < 1e-12);
1168            assert!((x_serial[(1, 0)] - b1).abs() < 1e-12);
1169            assert!((x_serial[(2, 0)]).abs() < 1e-12);
1170            assert!((x_serial[(3, 0)]).abs() < 1e-12);
1171            assert!((x_parallel[(0, 0)] - b0).abs() < 1e-12);
1172            assert!((x_parallel[(1, 0)] - b1).abs() < 1e-12);
1173            assert!((x_parallel[(2, 0)]).abs() < 1e-12);
1174            assert!((x_parallel[(3, 0)]).abs() < 1e-12);
1175        }
1176    }
1177
1178    #[test]
1179    fn test_solve_least_squares_qr_with_info_detects_ill_conditioning() {
1180        let mut a = Matrix::identity(3);
1181        for i in 0..3 {
1182            a[(i, 2)] *= 1e-8;
1183        }
1184        let b = Matrix::from_vec(vec![0.0, 0.0, 0.0], 3, 1).unwrap();
1185
1186        let (_x_serial, info_serial) =
1187            a.solve_least_squares_qr_with_info_with_parallel(&b, false)
1188                .unwrap();
1189        let (_x_parallel, info_parallel) =
1190            a.solve_least_squares_qr_with_info_with_parallel(&b, true)
1191                .unwrap();
1192        assert_eq!(info_serial.rank, 3);
1193        assert_eq!(info_parallel.rank, 3);
1194        assert!(info_serial.cond_est > 1e6, "cond_est was {}", info_serial.cond_est);
1195        assert!(info_parallel.cond_est > 1e6, "cond_est was {}", info_parallel.cond_est);
1196    }
1197
1198    #[test]
1199    fn test_solve_least_squares_qr_with_info_tall_full_rank() {
1200        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 3, 2).unwrap();
1201        let b = Matrix::from_vec(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
1202
1203        let (x_serial, info_serial) =
1204            a.solve_least_squares_qr_with_info_with_parallel(&b, false)
1205                .unwrap();
1206        let (x_parallel, info_parallel) =
1207            a.solve_least_squares_qr_with_info_with_parallel(&b, true)
1208                .unwrap();
1209        assert_eq!(info_serial.rank, 2);
1210        assert_eq!(info_parallel.rank, 2);
1211        assert!(info_serial.cond_est.is_finite());
1212        assert!(info_parallel.cond_est.is_finite());
1213        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
1214        assert!((x_serial[(0, 0)] - 1.0).abs() < 1e-12);
1215        assert!((x_serial[(1, 0)] - 2.0).abs() < 1e-12);
1216    }
1217
1218    #[test]
1219    fn test_solve_least_squares_qr_with_info_wide_full_rank() {
1220        let a = Matrix::from_vec(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
1221        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();
1222
1223        let (x_serial, info_serial) =
1224            a.solve_least_squares_qr_with_info_with_parallel(&b, false)
1225                .unwrap();
1226        let (x_parallel, info_parallel) =
1227            a.solve_least_squares_qr_with_info_with_parallel(&b, true)
1228                .unwrap();
1229        assert_eq!(info_serial.rank, 2);
1230        assert_eq!(info_parallel.rank, 2);
1231        assert!(info_serial.cond_est.is_finite());
1232        assert!(info_parallel.cond_est.is_finite());
1233        assert_matrix_close(&x_serial, &x_parallel, 1e-12);
1234        assert!((x_serial[(0, 0)] - 1.0).abs() < 1e-12);
1235        assert!((x_serial[(1, 0)] - 2.0).abs() < 1e-12);
1236        assert!((x_serial[(2, 0)] - 0.0).abs() < 1e-12);
1237    }
1238
1239    #[test]
1240    fn test_solve_least_squares_qr_with_info_rank_deficient() {
1241        let a = Matrix::from_vec(vec![1.0, 1.0, 2.0, 2.0], 2, 2).unwrap();
1242        let b = Matrix::from_vec(vec![1.0, 2.0], 2, 1).unwrap();
1243
1244        let err_serial = a
1245            .solve_least_squares_qr_with_info_with_parallel(&b, false)
1246            .expect_err("expected rank-deficient QR solve to fail");
1247        let err_parallel = a
1248            .solve_least_squares_qr_with_info_with_parallel(&b, true)
1249            .expect_err("expected rank-deficient QR solve to fail");
1250        assert!(err_serial.contains("rank deficient"), "{err_serial}");
1251        assert!(err_parallel.contains("rank deficient"), "{err_parallel}");
1252    }
1253}