1use faer::{Mat, linalg::solvers::Solve};
2
3use crate::linalg::{DenseMode, LinAlgResult, LinearSolver};
4
/// Dense solver for the normal equations `(JᵀJ) dx = -Jᵀr`, optionally damped as
/// `(JᵀJ + λI) dx = -Jᵀr`, using a column-pivoted QR factorization.
///
/// The results of the most recent solve are cached, together with lazily computed
/// covariance and standard-error statistics.
#[derive(Debug, Clone)]
pub struct DenseQRSolver {
    /// Column-pivoted QR factorization of the system matrix from the most recent solve
    /// (`JᵀJ` for a normal solve, `JᵀJ + λI` for an augmented solve).
    factorizer: Option<faer::linalg::solvers::ColPivQr<f64>>,

    /// Undamped Hessian approximation `JᵀJ` from the most recent solve.
    hessian: Option<Mat<f64>>,

    /// Gradient `Jᵀr` from the most recent solve.
    gradient: Option<Mat<f64>>,

    /// Cached covariance matrix; cleared by every solve and by `reset_covariance`.
    covariance_matrix: Option<Mat<f64>>,

    /// Cached `n x 1` standard errors; cleared together with the covariance matrix.
    standard_errors: Option<Mat<f64>>,
}
27
28impl DenseQRSolver {
29 pub fn new() -> Self {
30 Self {
31 factorizer: None,
32 hessian: None,
33 gradient: None,
34 covariance_matrix: None,
35 standard_errors: None,
36 }
37 }
38
39 pub fn hessian(&self) -> Option<&Mat<f64>> {
40 self.hessian.as_ref()
41 }
42
43 pub fn gradient(&self) -> Option<&Mat<f64>> {
44 self.gradient.as_ref()
45 }
46
47 pub fn compute_standard_errors(&mut self) -> Option<&Mat<f64>> {
49 if self.covariance_matrix.is_none() {
50 LinearSolver::<DenseMode>::compute_covariance_matrix(self);
51 }
52
53 let n = self.hessian.as_ref()?.ncols();
54 if let Some(cov) = &self.covariance_matrix {
55 let mut std_errors = Mat::zeros(n, 1);
56 for i in 0..n {
57 let diag_val = cov[(i, i)];
58 if diag_val >= 0.0 {
59 std_errors[(i, 0)] = diag_val.sqrt();
60 } else {
61 return None;
62 }
63 }
64 self.standard_errors = Some(std_errors);
65 }
66 self.standard_errors.as_ref()
67 }
68
69 pub fn reset_covariance(&mut self) {
71 self.covariance_matrix = None;
72 self.standard_errors = None;
73 }
74
75 fn solve_dense_normal(
77 &mut self,
78 residuals: &Mat<f64>,
79 jacobian: &Mat<f64>,
80 ) -> LinAlgResult<Mat<f64>> {
81 let hessian = jacobian.transpose() * jacobian;
83 let gradient = jacobian.transpose() * residuals;
85
86 let qr = hessian.as_ref().col_piv_qr();
88
89 let dx = qr.solve(-&gradient);
91
92 self.factorizer = Some(qr);
93 self.hessian = Some(hessian);
94 self.gradient = Some(gradient);
95 self.covariance_matrix = None;
96 self.standard_errors = None;
97
98 Ok(dx)
99 }
100
101 fn solve_dense_augmented(
103 &mut self,
104 residuals: &Mat<f64>,
105 jacobian: &Mat<f64>,
106 lambda: f64,
107 ) -> LinAlgResult<Mat<f64>> {
108 let hessian = jacobian.transpose() * jacobian;
110 let gradient = jacobian.transpose() * residuals;
112
113 let n = hessian.nrows();
115 let mut augmented = hessian.clone();
116 for i in 0..n {
117 augmented[(i, i)] += lambda;
118 }
119
120 let qr = augmented.as_ref().col_piv_qr();
122
123 let dx = qr.solve(-&gradient);
125
126 self.factorizer = Some(qr);
128 self.hessian = Some(hessian);
129 self.gradient = Some(gradient);
130 self.covariance_matrix = None;
131 self.standard_errors = None;
132
133 Ok(dx)
134 }
135}
136
137impl Default for DenseQRSolver {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143impl LinearSolver<DenseMode> for DenseQRSolver {
148 fn solve_normal_equation(
149 &mut self,
150 residuals: &Mat<f64>,
151 jacobian: &Mat<f64>,
152 ) -> LinAlgResult<Mat<f64>> {
153 self.solve_dense_normal(residuals, jacobian)
154 }
155
156 fn solve_augmented_equation(
157 &mut self,
158 residuals: &Mat<f64>,
159 jacobian: &Mat<f64>,
160 lambda: f64,
161 ) -> LinAlgResult<Mat<f64>> {
162 self.solve_dense_augmented(residuals, jacobian, lambda)
163 }
164
165 fn get_hessian(&self) -> Option<&Mat<f64>> {
166 self.hessian.as_ref()
167 }
168
169 fn get_gradient(&self) -> Option<&Mat<f64>> {
170 self.gradient.as_ref()
171 }
172
173 fn compute_covariance_matrix(&mut self) -> Option<&Mat<f64>> {
174 if self.covariance_matrix.is_none()
175 && let Some(hessian) = &self.hessian
176 && let Some(factorizer) = &self.factorizer
177 {
178 let n = hessian.nrows();
179 let identity = Mat::identity(n, n);
180 let cov = factorizer.solve(&identity);
181 self.covariance_matrix = Some(cov);
182 }
183 self.covariance_matrix.as_ref()
184 }
185
186 fn get_covariance_matrix(&self) -> Option<&Mat<f64>> {
187 self.covariance_matrix.as_ref()
188 }
189}
190
191#[cfg(test)]
192mod tests {
193 use super::*;
194
195 const TOLERANCE: f64 = 1e-10;
196
197 type TestResult = Result<(), Box<dyn std::error::Error>>;
198
199 fn create_test_data() -> (Mat<f64>, Mat<f64>) {
200 let mut j = Mat::zeros(4, 2);
202 j[(0, 0)] = 2.0;
203 j[(0, 1)] = 1.0;
204 j[(1, 0)] = 1.0;
205 j[(1, 1)] = 3.0;
206 j[(2, 0)] = 1.0;
207 j[(2, 1)] = 1.0;
208 j[(3, 0)] = 0.5;
209 j[(3, 1)] = 2.0;
210
211 let mut r = Mat::zeros(4, 1);
212 r[(0, 0)] = 1.0;
213 r[(1, 0)] = 2.0;
214 r[(2, 0)] = 0.5;
215 r[(3, 0)] = 1.5;
216
217 (j, r)
218 }
219
220 #[test]
221 fn test_dense_qr_solver_creation() {
222 let solver = DenseQRSolver::new();
223 assert!(solver.factorizer.is_none());
224 assert!(solver.hessian.is_none());
225 assert!(solver.gradient.is_none());
226
227 let default_solver = DenseQRSolver::default();
228 assert!(default_solver.factorizer.is_none());
229 }
230
231 #[test]
232 fn test_dense_qr_solve_normal_equation() -> TestResult {
233 let (j, r) = create_test_data();
234 let mut solver = DenseQRSolver::new();
235
236 let dx = LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
237
238 let jtj = j.transpose() * &j;
240 let jtr = j.transpose() * &r;
241 let residual = &jtj * &dx + &jtr;
242
243 for i in 0..dx.nrows() {
244 assert!(
245 residual[(i, 0)].abs() < TOLERANCE,
246 "Residual at index {i}: {}",
247 residual[(i, 0)]
248 );
249 }
250
251 assert!(solver.hessian.is_some());
252 assert!(solver.gradient.is_some());
253 assert!(solver.factorizer.is_some());
254
255 Ok(())
256 }
257
258 #[test]
259 fn test_dense_qr_solve_augmented_equation() -> TestResult {
260 let (j, r) = create_test_data();
261 let lambda = 0.1;
262 let mut solver = DenseQRSolver::new();
263
264 let dx = LinearSolver::<DenseMode>::solve_augmented_equation(&mut solver, &r, &j, lambda)?;
265
266 let mut jtj = j.transpose() * &j;
268 let jtr = j.transpose() * &r;
269 for i in 0..jtj.nrows() {
270 jtj[(i, i)] += lambda;
271 }
272 let residual = &jtj * &dx + &jtr;
273
274 for i in 0..dx.nrows() {
275 assert!(
276 residual[(i, 0)].abs() < TOLERANCE,
277 "Residual at index {i}: {}",
278 residual[(i, 0)]
279 );
280 }
281
282 Ok(())
283 }
284
285 #[test]
286 fn test_dense_qr_augmented_different_lambdas() -> TestResult {
287 let (j, r) = create_test_data();
288 let mut solver = DenseQRSolver::new();
289
290 let dx1 = LinearSolver::<DenseMode>::solve_augmented_equation(&mut solver, &r, &j, 0.01)?;
291 let dx2 = LinearSolver::<DenseMode>::solve_augmented_equation(&mut solver, &r, &j, 1.0)?;
292
293 let mut different = false;
294 for i in 0..dx1.nrows() {
295 if (dx1[(i, 0)] - dx2[(i, 0)]).abs() > TOLERANCE {
296 different = true;
297 break;
298 }
299 }
300 assert!(
301 different,
302 "Solutions should differ with different lambda values"
303 );
304
305 Ok(())
306 }
307
308 #[test]
309 fn test_dense_qr_rank_deficient_matrix() -> TestResult {
310 let mut solver = DenseQRSolver::new();
311
312 let mut j = Mat::zeros(3, 3);
314 j[(0, 0)] = 1.0;
315 j[(0, 1)] = 2.0;
316 j[(0, 2)] = 3.0;
317 j[(1, 0)] = 2.0;
318 j[(1, 1)] = 4.0;
319 j[(1, 2)] = 6.0;
320 j[(2, 0)] = 0.0;
321 j[(2, 1)] = 0.0;
322 j[(2, 2)] = 1.0;
323
324 let mut r = Mat::zeros(3, 1);
325 r[(0, 0)] = 1.0;
326 r[(1, 0)] = 2.0;
327 r[(2, 0)] = 3.0;
328
329 let result = LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j);
330 assert!(result.is_ok(), "QR should handle rank-deficient matrices");
331
332 Ok(())
333 }
334
335 #[test]
336 fn test_dense_qr_numerical_accuracy() -> TestResult {
337 let mut solver = DenseQRSolver::new();
338
339 let mut j = Mat::zeros(3, 3);
341 j[(0, 0)] = 1.0;
342 j[(1, 1)] = 1.0;
343 j[(2, 2)] = 1.0;
344
345 let mut r = Mat::zeros(3, 1);
346 r[(0, 0)] = -1.0;
347 r[(1, 0)] = -2.0;
348 r[(2, 0)] = -3.0;
349
350 let dx = LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
351
352 for i in 0..3 {
353 let expected = (i + 1) as f64;
354 assert!(
355 (dx[(i, 0)] - expected).abs() < TOLERANCE,
356 "Expected {expected}, got {}",
357 dx[(i, 0)]
358 );
359 }
360
361 Ok(())
362 }
363
364 #[test]
365 fn test_dense_qr_solver_clone() {
366 let solver1 = DenseQRSolver::new();
367 let solver2 = solver1.clone();
368
369 assert!(solver1.factorizer.is_none());
370 assert!(solver2.factorizer.is_none());
371 }
372
373 #[test]
374 fn test_dense_qr_zero_lambda_augmented() -> TestResult {
375 let (j, r) = create_test_data();
376 let mut solver = DenseQRSolver::new();
377
378 let normal_dx = LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
379 let augmented_dx =
380 LinearSolver::<DenseMode>::solve_augmented_equation(&mut solver, &r, &j, 0.0)?;
381
382 for i in 0..normal_dx.nrows() {
383 assert!(
384 (normal_dx[(i, 0)] - augmented_dx[(i, 0)]).abs() < 1e-8,
385 "Zero-lambda augmented should match normal equation"
386 );
387 }
388
389 Ok(())
390 }
391
392 #[test]
393 fn test_dense_qr_covariance_computation() -> TestResult {
394 let (j, r) = create_test_data();
395 let mut solver = DenseQRSolver::new();
396
397 LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
398
399 let cov = LinearSolver::<DenseMode>::compute_covariance_matrix(&mut solver)
400 .ok_or("covariance should be computable")?;
401 let n = cov.nrows();
402
403 for i in 0..n {
405 for k in 0..n {
406 assert!(
407 (cov[(i, k)] - cov[(k, i)]).abs() < TOLERANCE,
408 "Covariance not symmetric at ({i}, {k})"
409 );
410 }
411 }
412
413 for i in 0..n {
415 assert!(cov[(i, i)] > 0.0, "Diagonal entry {i} should be positive");
416 }
417
418 Ok(())
419 }
420
421 #[test]
422 fn test_dense_qr_standard_errors_computation() -> TestResult {
423 let (j, r) = create_test_data();
424 let mut solver = DenseQRSolver::new();
425
426 LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
427
428 let errors = solver
430 .compute_standard_errors()
431 .ok_or("standard errors should be computable")?
432 .clone();
433 let cov = solver
434 .covariance_matrix
435 .as_ref()
436 .ok_or("covariance matrix not available")?;
437
438 assert_eq!(errors.nrows(), cov.nrows());
439 assert_eq!(errors.ncols(), 1);
440
441 for i in 0..errors.nrows() {
442 assert!(
443 errors[(i, 0)] > 0.0,
444 "Standard error at {i} should be positive"
445 );
446 let expected = cov[(i, i)].sqrt();
447 assert!(
448 (errors[(i, 0)] - expected).abs() < TOLERANCE,
449 "Standard error should equal sqrt of covariance diagonal"
450 );
451 }
452
453 Ok(())
454 }
455
456 #[test]
457 fn test_dense_qr_covariance_well_conditioned() -> TestResult {
458 let mut solver = DenseQRSolver::new();
459
460 let mut j = Mat::zeros(2, 2);
462 j[(0, 0)] = 2.0;
463 j[(1, 1)] = 3.0;
464
465 let mut r = Mat::zeros(2, 1);
466 r[(0, 0)] = 1.0;
467 r[(1, 0)] = 2.0;
468
469 LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
470
471 let cov = LinearSolver::<DenseMode>::compute_covariance_matrix(&mut solver)
472 .ok_or("covariance computation failed")?;
473 assert!(
474 (cov[(0, 0)] - 0.25).abs() < TOLERANCE,
475 "cov[0,0] should be 0.25"
476 );
477 assert!(
478 (cov[(1, 1)] - 1.0 / 9.0).abs() < TOLERANCE,
479 "cov[1,1] should be 1/9"
480 );
481 assert!(cov[(0, 1)].abs() < TOLERANCE);
482 assert!(cov[(1, 0)].abs() < TOLERANCE);
483
484 Ok(())
485 }
486
487 #[test]
488 fn test_dense_qr_covariance_caching() -> TestResult {
489 let (j, r) = create_test_data();
490 let mut solver = DenseQRSolver::new();
491
492 LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
493
494 LinearSolver::<DenseMode>::compute_covariance_matrix(&mut solver);
495 let ptr1 = solver
496 .covariance_matrix
497 .as_ref()
498 .ok_or("covariance not cached after first call")?
499 .as_ptr();
500
501 LinearSolver::<DenseMode>::compute_covariance_matrix(&mut solver);
503 let ptr2 = solver
504 .covariance_matrix
505 .as_ref()
506 .ok_or("covariance not cached after second call")?
507 .as_ptr();
508
509 assert_eq!(ptr1, ptr2, "Covariance matrix should be cached");
510
511 Ok(())
512 }
513
514 #[test]
515 fn test_dense_qr_covariance_singular_system() -> TestResult {
516 let mut solver = DenseQRSolver::new();
517
518 let mut j = Mat::zeros(2, 2);
520 j[(0, 0)] = 1.0;
521 j[(0, 1)] = 2.0;
522 j[(1, 0)] = 2.0;
523 j[(1, 1)] = 4.0;
524
525 let mut r = Mat::zeros(2, 1);
526 r[(0, 0)] = 0.0;
527 r[(1, 0)] = 1.0;
528
529 let result = LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j);
530 if result.is_ok() {
531 let cov = LinearSolver::<DenseMode>::compute_covariance_matrix(&mut solver);
532 if let Some(cov) = cov {
533 assert_eq!(cov.nrows(), 2);
534 assert_eq!(cov.ncols(), 2);
535 }
536 }
537
538 Ok(())
539 }
540
541 #[test]
542 fn test_dense_qr_reset_covariance() -> TestResult {
543 let (j, r) = create_test_data();
544 let mut solver = DenseQRSolver::new();
545
546 LinearSolver::<DenseMode>::solve_normal_equation(&mut solver, &r, &j)?;
547 LinearSolver::<DenseMode>::compute_covariance_matrix(&mut solver);
548 assert!(solver.covariance_matrix.is_some());
549
550 solver.reset_covariance();
551 assert!(solver.covariance_matrix.is_none());
552 assert!(solver.standard_errors.is_none());
553
554 Ok(())
555 }
556}