1use crate::matrix::DenseMatrix;
8use crate::Scalar;
9use faer::linalg::solvers::Qr;
10use faer::prelude::*;
11use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
12use numra_core::LinalgError;
13
14pub struct QRFactorization<S: Scalar + Entity> {
19 qr: Qr<S>,
21 m: usize,
22 n: usize,
23}
24
25impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> QRFactorization<S> {
26 pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
32 let m = matrix.rows();
33 let n = matrix.cols();
34
35 if m < n {
36 return Err(LinalgError::DimensionMismatch {
37 expected: (n, n),
38 actual: (m, n),
39 });
40 }
41
42 let qr = Qr::new(matrix.as_faer());
44
45 Ok(Self { qr, m, n })
46 }
47
48 pub fn nrows(&self) -> usize {
50 self.m
51 }
52
53 pub fn ncols(&self) -> usize {
55 self.n
56 }
57
58 pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
60 if self.m != self.n {
61 return Err(LinalgError::NotSquare {
62 nrows: self.m,
63 ncols: self.n,
64 });
65 }
66 if b.len() != self.m {
67 return Err(LinalgError::DimensionMismatch {
68 expected: (self.m, 1),
69 actual: (b.len(), 1),
70 });
71 }
72
73 let mut b_mat = Mat::zeros(self.m, 1);
75 for (i, &val) in b.iter().enumerate() {
76 b_mat.write(i, 0, val);
77 }
78
79 let x_mat = self.qr.solve(&b_mat);
81
82 let mut x = Vec::with_capacity(self.n);
84 for i in 0..self.n {
85 x.push(x_mat.read(i, 0));
86 }
87
88 Ok(x)
89 }
90
91 pub fn solve_least_squares(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
96 if b.len() != self.m {
97 return Err(LinalgError::DimensionMismatch {
98 expected: (self.m, 1),
99 actual: (b.len(), 1),
100 });
101 }
102
103 let mut b_mat = Mat::zeros(self.m, 1);
105 for (i, &val) in b.iter().enumerate() {
106 b_mat.write(i, 0, val);
107 }
108
109 let x_mat = self.qr.solve_lstsq(&b_mat);
111
112 let mut x = Vec::with_capacity(self.n);
114 for i in 0..self.n {
115 x.push(x_mat.read(i, 0));
116 }
117
118 Ok(x)
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Absolute tolerance used when comparing computed solutions.
    const TOL: f64 = 1e-10;

    /// Builds a dense matrix from row slices; every row must have the same length.
    fn matrix_from_rows(rows: &[&[f64]]) -> DenseMatrix<f64> {
        let ncols = rows.first().map_or(0, |row| row.len());
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(rows.len(), ncols);
        for (i, row) in rows.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                m.set(i, j, value);
            }
        }
        m
    }

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_qr_square() {
        let m = matrix_from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let qr = QRFactorization::new(&m).unwrap();
        assert_eq!(qr.nrows(), 2);
        assert_eq!(qr.ncols(), 2);

        // [1 2; 3 4] * [1, 2]^T = [5, 11]^T
        let b = vec![5.0, 11.0];
        let x = qr.solve(&b).unwrap();

        assert_eq!(x.len(), 2);
        assert_close(x[0], 1.0);
        assert_close(x[1], 2.0);
    }

    #[test]
    fn test_qr_overdetermined() {
        let m = matrix_from_rows(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);

        let qr = QRFactorization::new(&m).unwrap();
        assert_eq!(qr.nrows(), 3);
        assert_eq!(qr.ncols(), 2);

        let b = vec![1.0, 2.0, 2.0];
        let x = qr.solve_least_squares(&b).unwrap();

        assert_eq!(x.len(), 2);
        assert_close(x[0], 2.0 / 3.0);
        assert_close(x[1], 0.5);
    }

    #[test]
    fn test_qr_dimension_mismatch() {
        let m = matrix_from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let qr = QRFactorization::new(&m).unwrap();

        // Right-hand side has three entries but the matrix has two rows.
        let b = vec![1.0, 2.0, 3.0];

        assert!(matches!(
            qr.solve(&b),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1),
            })
        ));
        assert!(matches!(
            qr.solve_least_squares(&b),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1),
            })
        ));
    }

    #[test]
    fn test_qr_rejects_underdetermined() {
        // Two rows, three columns: fewer rows than columns is refused by `new`.
        let m = matrix_from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);

        assert!(matches!(
            QRFactorization::new(&m),
            Err(LinalgError::DimensionMismatch {
                expected: (3, 3),
                actual: (2, 3),
            })
        ));
    }

    #[test]
    fn test_qr_solve_requires_square() {
        let m = matrix_from_rows(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);
        let qr = QRFactorization::new(&m).unwrap();

        // `solve` is only defined for square systems, even when `b` has the right length.
        let b = vec![1.0, 2.0, 2.0];
        assert!(matches!(
            qr.solve(&b),
            Err(LinalgError::NotSquare { nrows: 3, ncols: 2 })
        ));
    }
}