1use std::ops::Mul;
2
3use faer;
4use faer::{linalg::solvers::Solve, sparse, sparse::linalg::solvers};
5
6use crate::linalg::{LinAlgError, LinAlgResult, SparseLinearSolver};
7
/// Sparse linear solver built on faer's sparse QR factorization.
///
/// Besides producing a step, the solver keeps the matrices from the most
/// recent successful solve so that the covariance matrix and standard errors
/// can be derived from them afterwards.
#[derive(Debug, Clone)]
pub struct SparseQRSolver {
    /// Numeric QR factorization stored by the most recent successful solve.
    factorizer: Option<solvers::Qr<usize, f64>>,

    /// Symbolic QR analysis; created on the first solve and reused afterwards.
    symbolic_factorization: Option<solvers::SymbolicQr<usize>>,

    /// `JᵀJ` from the most recent successful solve (stored without any
    /// `lambda * I` damping term, even after an augmented solve).
    hessian: Option<sparse::SparseColMat<usize, f64>>,

    /// `Jᵀr` from the most recent successful solve.
    gradient: Option<faer::Mat<f64>>,

    /// Cached result of `compute_covariance_matrix`.
    covariance_matrix: Option<faer::Mat<f64>>,
    /// Cached result of `compute_standard_errors`: the square roots of the
    /// covariance diagonal, stored as an `n x 1` column.
    standard_errors: Option<faer::Mat<f64>>,
}
41
42impl SparseQRSolver {
43 pub fn new() -> Self {
44 SparseQRSolver {
45 factorizer: None,
46 symbolic_factorization: None,
47 hessian: None,
48 gradient: None,
49 covariance_matrix: None,
50 standard_errors: None,
51 }
52 }
53
54 pub fn hessian(&self) -> Option<&sparse::SparseColMat<usize, f64>> {
55 self.hessian.as_ref()
56 }
57
58 pub fn gradient(&self) -> Option<&faer::Mat<f64>> {
59 self.gradient.as_ref()
60 }
61
62 pub fn compute_standard_errors(&mut self) -> Option<&faer::Mat<f64>> {
63 if self.covariance_matrix.is_none() {
65 self.compute_covariance_matrix();
66 }
67
68 let n = self.hessian.as_ref().unwrap().ncols();
69 if let Some(cov) = &self.covariance_matrix {
71 let mut std_errors = faer::Mat::zeros(n, 1);
72 for i in 0..n {
73 let diag_val = cov[(i, i)];
74 if diag_val >= 0.0 {
75 std_errors[(i, 0)] = diag_val.sqrt();
76 } else {
77 return None;
79 }
80 }
81 self.standard_errors = Some(std_errors);
82 }
83 self.standard_errors.as_ref()
84 }
85
86 pub fn reset_covariance(&mut self) {
88 self.covariance_matrix = None;
89 self.standard_errors = None;
90 }
91}
92
93impl Default for SparseQRSolver {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl SparseLinearSolver for SparseQRSolver {
100 fn solve_normal_equation(
101 &mut self,
102 residuals: &faer::Mat<f64>,
103 jacobians: &sparse::SparseColMat<usize, f64>,
104 ) -> LinAlgResult<faer::Mat<f64>> {
105 let jt = jacobians.as_ref().transpose();
107 let hessian = jt
108 .to_col_major()
109 .map_err(|_| {
110 LinAlgError::MatrixConversion(
111 "Failed to convert transposed Jacobian to column-major format".to_string(),
112 )
113 })?
114 .mul(jacobians.as_ref());
115
116 let gradient = jacobians.as_ref().transpose().mul(residuals);
118
119 let sym = if let Some(ref cached_sym) = self.symbolic_factorization {
122 cached_sym.clone()
127 } else {
128 let new_sym = solvers::SymbolicQr::try_new(hessian.symbolic()).map_err(|_| {
130 LinAlgError::FactorizationFailed("Symbolic QR decomposition failed".to_string())
131 })?;
132 self.symbolic_factorization = Some(new_sym.clone());
134 new_sym
135 };
136
137 let qr = solvers::Qr::try_new_with_symbolic(sym, hessian.as_ref())
139 .map_err(|_| LinAlgError::SingularMatrix)?;
140
141 let dx = qr.solve(-&gradient);
143 self.hessian = Some(hessian);
144 self.gradient = Some(gradient);
145 self.factorizer = Some(qr);
146
147 Ok(dx)
148 }
149
150 fn solve_augmented_equation(
151 &mut self,
152 residuals: &faer::Mat<f64>,
153 jacobians: &sparse::SparseColMat<usize, f64>,
154 lambda: f64,
155 ) -> LinAlgResult<faer::Mat<f64>> {
156 let n = jacobians.ncols();
157
158 let jt = jacobians.as_ref().transpose();
160 let hessian = jt
161 .to_col_major()
162 .map_err(|_| {
163 LinAlgError::MatrixConversion(
164 "Failed to convert transposed Jacobian to column-major format".to_string(),
165 )
166 })?
167 .mul(jacobians.as_ref());
168
169 let gradient = jacobians.as_ref().transpose().mul(residuals);
171
172 let mut lambda_i_triplets = Vec::with_capacity(n);
174 for i in 0..n {
175 lambda_i_triplets.push(faer::sparse::Triplet::new(i, i, lambda));
176 }
177 let lambda_i = sparse::SparseColMat::try_new_from_triplets(n, n, &lambda_i_triplets)
178 .map_err(|e| {
179 LinAlgError::SparseMatrixCreation(format!(
180 "Failed to create lambda*I matrix: {:?}",
181 e
182 ))
183 })?;
184
185 let augmented_hessian = hessian.as_ref() + lambda_i;
186
187 let sym = if let Some(ref cached_sym) = self.symbolic_factorization {
192 cached_sym.clone()
193 } else {
194 let new_sym =
196 solvers::SymbolicQr::try_new(augmented_hessian.symbolic()).map_err(|_| {
197 LinAlgError::FactorizationFailed(
198 "Symbolic QR decomposition failed for augmented system".to_string(),
199 )
200 })?;
201 self.symbolic_factorization = Some(new_sym.clone());
203 new_sym
204 };
205
206 let qr = solvers::Qr::try_new_with_symbolic(sym, augmented_hessian.as_ref())
208 .map_err(|_| LinAlgError::SingularMatrix)?;
209
210 let dx = qr.solve(-&gradient);
211 self.hessian = Some(hessian);
212 self.gradient = Some(gradient);
213 self.factorizer = Some(qr);
214
215 Ok(dx)
216 }
217
218 fn get_hessian(&self) -> Option<&sparse::SparseColMat<usize, f64>> {
219 self.hessian.as_ref()
220 }
221
222 fn get_gradient(&self) -> Option<&faer::Mat<f64>> {
223 self.gradient.as_ref()
224 }
225
226 fn compute_covariance_matrix(&mut self) -> Option<&faer::Mat<f64>> {
227 if self.factorizer.is_some()
229 && self.hessian.is_some()
230 && self.covariance_matrix.is_none()
231 && let (Some(factorizer), Some(hessian)) = (&self.factorizer, &self.hessian)
232 {
233 let n = hessian.ncols();
234 let identity = faer::Mat::identity(n, n);
236
237 let cov_matrix = factorizer.solve(&identity);
239 self.covariance_matrix = Some(cov_matrix);
240 }
241 self.covariance_matrix.as_ref()
242 }
243
244 fn get_covariance_matrix(&self) -> Option<&faer::Mat<f64>> {
245 self.covariance_matrix.as_ref()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use faer::Mat;
    use faer::sparse::SparseColMat;

    /// Absolute tolerance used for floating-point comparisons.
    const TOLERANCE: f64 = 1e-10;

    /// Builds a sparse matrix from `(row, col, value)` entries.
    ///
    /// Entries are passed through as triplets unchanged, so explicit zeros
    /// stay part of the stored pattern.
    fn sparse_from_entries(
        nrows: usize,
        ncols: usize,
        entries: &[(usize, usize, f64)],
    ) -> SparseColMat<usize, f64> {
        let triplets: Vec<_> = entries
            .iter()
            .map(|&(row, col, value)| faer::sparse::Triplet::new(row, col, value))
            .collect();
        SparseColMat::try_new_from_triplets(nrows, ncols, &triplets).unwrap()
    }

    /// A 4x3 overdetermined Jacobian together with residuals `[1, 2, 3, 4]`.
    fn create_test_data() -> (SparseColMat<usize, f64>, Mat<f64>) {
        let jacobian = sparse_from_entries(
            4,
            3,
            &[
                (0, 0, 1.0),
                (0, 1, 0.0),
                (0, 2, 1.0),
                (1, 0, 0.0),
                (1, 1, 1.0),
                (1, 2, 1.0),
                (2, 0, 1.0),
                (2, 1, 1.0),
                (2, 2, 0.0),
                (3, 0, 1.0),
                (3, 1, 0.0),
                (3, 2, 0.0),
            ],
        );

        let residuals = Mat::from_fn(4, 1, |row, _| (row + 1) as f64);

        (jacobian, residuals)
    }

    /// A new or default solver starts without a factorization.
    #[test]
    fn test_qr_solver_creation() {
        let fresh = SparseQRSolver::new();
        assert!(fresh.factorizer.is_none());

        let defaulted = SparseQRSolver::default();
        assert!(defaulted.factorizer.is_none());
    }

    /// The normal-equation solve returns an `n x 1` step and stores the factorization.
    #[test]
    fn test_qr_solve_normal_equation() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());

        let step = outcome.unwrap();
        assert_eq!(step.nrows(), 3);
        assert_eq!(step.ncols(), 1);

        assert!(solver.factorizer.is_some());
    }

    /// Solving the same system twice gives the same answer.
    #[test]
    fn test_qr_factorizer_caching() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let first = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(first.is_ok());
        assert!(solver.factorizer.is_some());

        let second = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(second.is_ok());

        let first_step = first.unwrap();
        let second_step = second.unwrap();
        for row in 0..first_step.nrows() {
            assert!((first_step[(row, 0)] - second_step[(row, 0)]).abs() < TOLERANCE);
        }
    }

    /// The damped solve returns an `n x 1` step.
    #[test]
    fn test_qr_solve_augmented_equation() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();
        let lambda = 0.1;

        let outcome = solver.solve_augmented_equation(&residuals, &jacobian, lambda);
        assert!(outcome.is_ok());

        let step = outcome.unwrap();
        assert_eq!(step.nrows(), 3);
        assert_eq!(step.ncols(), 1);
    }

    /// Different damping values must yield different steps.
    #[test]
    fn test_qr_augmented_different_lambdas() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let small_lambda = 0.01;
        let large_lambda = 1.0;

        let small_outcome = solver.solve_augmented_equation(&residuals, &jacobian, small_lambda);
        let large_outcome = solver.solve_augmented_equation(&residuals, &jacobian, large_lambda);

        assert!(small_outcome.is_ok());
        assert!(large_outcome.is_ok());

        let small_step = small_outcome.unwrap();
        let large_step = large_outcome.unwrap();
        let different = (0..small_step.nrows())
            .any(|row| (small_step[(row, 0)] - large_step[(row, 0)]).abs() > TOLERANCE);
        assert!(
            different,
            "Solutions should differ with different lambda values"
        );
    }

    /// A Jacobian with linearly dependent rows is still expected to solve.
    #[test]
    fn test_qr_rank_deficient_matrix() {
        let mut solver = SparseQRSolver::new();

        let jacobian = sparse_from_entries(
            3,
            3,
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 0, 2.0),
                (1, 1, 4.0),
                (1, 2, 6.0),
                (2, 0, 0.0),
                (2, 1, 0.0),
                (2, 2, 1.0),
            ],
        );
        let residuals = Mat::from_fn(3, 1, |row, _| row as f64);

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());
    }

    /// The damped solve on a 2x2 identity-like Jacobian returns a 2x1 step.
    #[test]
    fn test_qr_augmented_system_structure() {
        let mut solver = SparseQRSolver::new();

        let jacobian = sparse_from_entries(
            2,
            2,
            &[(0, 0, 1.0), (0, 1, 0.0), (1, 0, 0.0), (1, 1, 1.0)],
        );
        let residuals = Mat::from_fn(2, 1, |row, _| (row + 1) as f64);
        let lambda = 0.5;

        let outcome = solver.solve_augmented_equation(&residuals, &jacobian, lambda);
        assert!(outcome.is_ok());

        let step = outcome.unwrap();
        assert_eq!(step.nrows(), 2);
        assert_eq!(step.ncols(), 1);
    }

    /// With an identity Jacobian the step is the negated residual vector.
    #[test]
    fn test_qr_numerical_accuracy() {
        let mut solver = SparseQRSolver::new();

        let jacobian = sparse_from_entries(3, 3, &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)]);
        let residuals = Mat::from_fn(3, 1, |row, _| -((row + 1) as f64));

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());

        let step = outcome.unwrap();
        for row in 0..3 {
            let expected = (row + 1) as f64;
            assert!(
                (step[(row, 0)] - expected).abs() < TOLERANCE,
                "Expected {}, got {}",
                expected,
                step[(row, 0)]
            );
        }
    }

    /// Cloning an empty solver yields another empty solver.
    #[test]
    fn test_qr_solver_clone() {
        let original = SparseQRSolver::new();
        let copy = original.clone();

        assert!(original.factorizer.is_none());
        assert!(copy.factorizer.is_none());
    }

    /// With `lambda = 0` the damped solve matches the normal-equation solve.
    #[test]
    fn test_qr_zero_lambda_augmented() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let normal_outcome = solver.solve_normal_equation(&residuals, &jacobian);
        let augmented_outcome = solver.solve_augmented_equation(&residuals, &jacobian, 0.0);

        assert!(normal_outcome.is_ok());
        assert!(augmented_outcome.is_ok());

        let normal_step = normal_outcome.unwrap();
        let augmented_step = augmented_outcome.unwrap();

        for row in 0..normal_step.nrows() {
            assert!(
                (normal_step[(row, 0)] - augmented_step[(row, 0)]).abs() < 1e-8,
                "Zero lambda augmented should match normal equation"
            );
        }
    }

    /// The covariance matrix is square, symmetric and has a positive diagonal.
    #[test]
    fn test_qr_covariance_computation() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());

        let maybe_cov = solver.compute_covariance_matrix();
        assert!(maybe_cov.is_some());

        let cov = maybe_cov.unwrap();
        assert_eq!(cov.nrows(), 3);
        assert_eq!(cov.ncols(), 3);

        for row in 0..3 {
            for col in 0..3 {
                assert!(
                    (cov[(row, col)] - cov[(col, row)]).abs() < TOLERANCE,
                    "Covariance matrix should be symmetric"
                );
            }
        }

        for idx in 0..3 {
            assert!(
                cov[(idx, idx)] > 0.0,
                "Diagonal elements (variances) should be positive"
            );
        }
    }

    /// Standard errors are the positive square roots of the covariance diagonal.
    #[test]
    fn test_qr_standard_errors_computation() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());

        solver.compute_standard_errors();

        assert!(solver.covariance_matrix.is_some());
        assert!(solver.standard_errors.is_some());

        let cov = solver.covariance_matrix.as_ref().unwrap();
        let errors = solver.standard_errors.as_ref().unwrap();

        assert_eq!(errors.nrows(), 3);
        assert_eq!(errors.ncols(), 1);

        for idx in 0..3 {
            assert!(errors[(idx, 0)] > 0.0, "Standard errors should be positive");
        }

        for idx in 0..3 {
            let expected_std_error = cov[(idx, idx)].sqrt();
            assert!(
                (errors[(idx, 0)] - expected_std_error).abs() < TOLERANCE,
                "Standard error should equal sqrt of covariance diagonal"
            );
        }
    }

    /// For J = diag(2, 3) the covariance is diag(1/4, 1/9).
    #[test]
    fn test_qr_covariance_well_conditioned() {
        let mut solver = SparseQRSolver::new();

        let jacobian = sparse_from_entries(
            2,
            2,
            &[(0, 0, 2.0), (0, 1, 0.0), (1, 0, 0.0), (1, 1, 3.0)],
        );
        let residuals = Mat::from_fn(2, 1, |row, _| (row + 1) as f64);

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());

        let maybe_cov = solver.compute_covariance_matrix();
        assert!(maybe_cov.is_some());

        let cov = maybe_cov.unwrap();
        assert!((cov[(0, 0)] - 0.25).abs() < TOLERANCE);
        assert!((cov[(1, 1)] - 1.0 / 9.0).abs() < TOLERANCE);
        assert!(cov[(0, 1)].abs() < TOLERANCE);
        assert!(cov[(1, 0)].abs() < TOLERANCE);
    }

    /// A second covariance request reuses the matrix computed by the first.
    #[test]
    fn test_qr_covariance_caching() {
        let (jacobian, residuals) = create_test_data();
        let mut solver = SparseQRSolver::new();

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        assert!(outcome.is_ok());

        solver.compute_covariance_matrix();
        assert!(solver.covariance_matrix.is_some());

        let first_ptr = solver.covariance_matrix.as_ref().unwrap().as_ptr();

        solver.compute_covariance_matrix();
        assert!(solver.covariance_matrix.is_some());

        let second_ptr = solver.covariance_matrix.as_ref().unwrap().as_ptr();

        assert_eq!(first_ptr, second_ptr, "Covariance matrix should be cached");
    }

    /// A singular system may fail to solve; if it does solve, any covariance
    /// produced must still be 2x2.
    #[test]
    fn test_qr_covariance_singular_system() {
        let mut solver = SparseQRSolver::new();

        let jacobian = sparse_from_entries(
            2,
            2,
            &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 2.0), (1, 1, 4.0)],
        );
        let residuals = Mat::from_fn(2, 1, |row, _| row as f64);

        let outcome = solver.solve_normal_equation(&residuals, &jacobian);
        if outcome.is_ok() {
            if let Some(cov) = solver.compute_covariance_matrix() {
                assert!(cov.nrows() == 2);
                assert!(cov.ncols() == 2);
            }
        }
    }
}