Skip to main content

numra_linalg/
qr.rs

1//! QR factorization and solver.
2//!
3//! Author: Moussa Leblouba
4//! Date: 8 February 2026
5//! Modified: 2 May 2026
6
7use crate::matrix::DenseMatrix;
8use crate::Scalar;
9use faer::linalg::solvers::Qr;
10use faer::prelude::*;
11use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
12use numra_core::LinalgError;
13
14/// QR factorization of a matrix.
15///
16/// Caches the decomposition computed in `new()` and reuses it in every `solve()` /
17/// `solve_least_squares()` call, avoiding redundant recomputation.
18pub struct QRFactorization<S: Scalar + Entity> {
19    /// Cached QR decomposition from faer
20    qr: Qr<S>,
21    m: usize,
22    n: usize,
23}
24
25impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> QRFactorization<S> {
26    /// Compute QR factorization of a matrix.
27    ///
28    /// Requires `m >= n` (the matrix must have at least as many rows as columns).
29    /// The decomposition is cached and reused for subsequent `solve()` /
30    /// `solve_least_squares()` calls.
31    pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
32        let m = matrix.rows();
33        let n = matrix.cols();
34
35        if m < n {
36            return Err(LinalgError::DimensionMismatch {
37                expected: (n, n),
38                actual: (m, n),
39            });
40        }
41
42        // Compute and cache the QR factorization
43        let qr = Qr::new(matrix.as_faer());
44
45        Ok(Self { qr, m, n })
46    }
47
48    /// Number of rows of the original matrix.
49    pub fn nrows(&self) -> usize {
50        self.m
51    }
52
53    /// Number of columns of the original matrix.
54    pub fn ncols(&self) -> usize {
55        self.n
56    }
57
58    /// Solve Ax = b using the cached QR factorization (square systems only).
59    pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
60        if self.m != self.n {
61            return Err(LinalgError::NotSquare {
62                nrows: self.m,
63                ncols: self.n,
64            });
65        }
66        if b.len() != self.m {
67            return Err(LinalgError::DimensionMismatch {
68                expected: (self.m, 1),
69                actual: (b.len(), 1),
70            });
71        }
72
73        // Create column vector from b
74        let mut b_mat = Mat::zeros(self.m, 1);
75        for (i, &val) in b.iter().enumerate() {
76            b_mat.write(i, 0, val);
77        }
78
79        // Solve using the cached QR factorization (no recomputation)
80        let x_mat = self.qr.solve(&b_mat);
81
82        // Extract result
83        let mut x = Vec::with_capacity(self.n);
84        for i in 0..self.n {
85            x.push(x_mat.read(i, 0));
86        }
87
88        Ok(x)
89    }
90
91    /// Solve the least-squares problem min ||Ax - b||_2.
92    ///
93    /// Works for both square and overdetermined systems (m >= n).
94    /// Uses the cached QR factorization.
95    pub fn solve_least_squares(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
96        if b.len() != self.m {
97            return Err(LinalgError::DimensionMismatch {
98                expected: (self.m, 1),
99                actual: (b.len(), 1),
100            });
101        }
102
103        // Create column vector from b
104        let mut b_mat = Mat::zeros(self.m, 1);
105        for (i, &val) in b.iter().enumerate() {
106            b_mat.write(i, 0, val);
107        }
108
109        // Solve least-squares using the cached QR factorization (no recomputation)
110        let x_mat = self.qr.solve_lstsq(&b_mat);
111
112        // Extract result (x_mat has n rows after lstsq resize)
113        let mut x = Vec::with_capacity(self.n);
114        for i in 0..self.n {
115            x.push(x_mat.read(i, 0));
116        }
117
118        Ok(x)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Build a dense `f64` matrix from row slices; all rows must share one length.
    fn dense(rows: &[&[f64]]) -> DenseMatrix<f64> {
        let ncols = rows.first().map_or(0, |r| r.len());
        let mut mat: DenseMatrix<f64> = DenseMatrix::zeros(rows.len(), ncols);
        for (i, row) in rows.iter().enumerate() {
            assert_eq!(row.len(), ncols, "ragged test matrix");
            for (j, &v) in row.iter().enumerate() {
                mat.set(i, j, v);
            }
        }
        mat
    }

    /// Assert element-wise closeness within 1e-10.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < 1e-10,
                "got {:?}, expected {:?}",
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_qr_square() {
        // Solve [1 2; 3 4] * x = [5; 11]  =>  x = [1; 2]
        let m = dense(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let qr = QRFactorization::new(&m).unwrap();
        assert_eq!(qr.nrows(), 2);
        assert_eq!(qr.ncols(), 2);

        let x = qr.solve(&[5.0, 11.0]).unwrap();
        assert_close(&x, &[1.0, 2.0]);
    }

    #[test]
    fn test_qr_square_least_squares_matches_solve() {
        // For a full-rank square system the least-squares solution is the exact one.
        let m = dense(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let qr = QRFactorization::new(&m).unwrap();

        let x = qr.solve_least_squares(&[5.0, 11.0]).unwrap();
        assert_close(&x, &[1.0, 2.0]);
    }

    #[test]
    fn test_qr_overdetermined() {
        // Least-squares: A = [1 1; 1 2; 1 3], b = [1; 2; 2]
        // Normal equations: A^T A x = A^T b
        // A^T A = [3 6; 6 14], A^T b = [5; 11]
        // x = [2/3; 1/2]
        let m = dense(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);

        let qr = QRFactorization::new(&m).unwrap();
        assert_eq!(qr.nrows(), 3);
        assert_eq!(qr.ncols(), 2);

        let x = qr.solve_least_squares(&[1.0, 2.0, 2.0]).unwrap();
        assert_close(&x, &[2.0 / 3.0, 0.5]);
    }

    #[test]
    fn test_qr_rejects_wide_matrix() {
        // 2 x 3 has fewer rows than columns, which `new` must reject.
        let wide = dense(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        assert!(matches!(
            QRFactorization::new(&wide),
            Err(LinalgError::DimensionMismatch {
                expected: (3, 3),
                actual: (2, 3),
            })
        ));
    }

    #[test]
    fn test_qr_solve_requires_square() {
        // `solve` is square-only even though the tall matrix factorizes fine.
        let tall = dense(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);
        let qr = QRFactorization::new(&tall).unwrap();

        assert!(matches!(
            qr.solve(&[1.0, 2.0, 2.0]),
            Err(LinalgError::NotSquare { nrows: 3, ncols: 2 })
        ));
    }

    #[test]
    fn test_qr_dimension_mismatch() {
        let m = dense(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let qr = QRFactorization::new(&m).unwrap();

        // Wrong b size for both solvers.
        let b = [1.0, 2.0, 3.0];
        assert!(matches!(
            qr.solve(&b),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1),
            })
        ));
        assert!(matches!(
            qr.solve_least_squares(&b),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1),
            })
        ));
    }
}