1use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::direct::lu::{LuFactorization, lu_factorize};
11use crate::direct::LuError;
12use crate::traits::{ComplexField, LinearOperator, Preconditioner};
13use ndarray::{Array1, Array2};
14use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
15
16use super::gmres::{GmresConfig, GmresSolution};
17
18#[derive(Debug, Clone)]
29pub struct DeflationSubspace<T: ComplexField> {
30 w_columns: Vec<Array1<T>>,
31 aw_columns: Vec<Array1<T>>,
32 e_lu: LuFactorization<T>,
33}
34
35impl<T: ComplexField> DeflationSubspace<T> {
36 pub fn new<A: LinearOperator<T>>(
44 w_columns: Vec<Array1<T>>,
45 operator: &A,
46 ) -> Result<Self, LuError> {
47 let r = w_columns.len();
48 if r == 0 {
49 return Ok(Self {
50 w_columns,
51 aw_columns: Vec::new(),
52 e_lu: LuFactorization {
53 lu: Array2::from_elem((0, 0), T::zero()),
54 pivots: Vec::new(),
55 n: 0,
56 },
57 });
58 }
59
60 let aw_columns: Vec<Array1<T>> = w_columns.iter().map(|w| operator.apply(w)).collect();
62
63 let mut e = Array2::from_elem((r, r), T::zero());
65 for i in 0..r {
66 for j in 0..r {
67 e[[i, j]] = inner_product(&w_columns[i], &aw_columns[j]);
68 }
69 }
70
71 let e_lu = lu_factorize(&e)?;
72
73 Ok(Self {
74 w_columns,
75 aw_columns,
76 e_lu,
77 })
78 }
79
80 pub fn num_vectors(&self) -> usize {
82 self.w_columns.len()
83 }
84
85 pub fn apply_left_projector(&self, v: &Array1<T>) -> Array1<T> {
87 let r = self.w_columns.len();
88 if r == 0 {
89 return v.clone();
90 }
91
92 let mut wh_v = Array1::from_elem(r, T::zero());
94 for i in 0..r {
95 wh_v[i] = inner_product(&self.w_columns[i], v);
96 }
97
98 let y = self
100 .e_lu
101 .solve(&wh_v)
102 .expect("Deflation matrix E should be non-singular");
103
104 let mut result = v.clone();
106 for i in 0..r {
107 axpy(-y[i], &self.aw_columns[i], &mut result);
108 }
109
110 result
111 }
112
113 pub fn coarse_correction(&self, b: &Array1<T>) -> Array1<T> {
115 let r = self.w_columns.len();
116 let n = b.len();
117 if r == 0 {
118 return Array1::from_elem(n, T::zero());
119 }
120
121 let mut wh_b = Array1::from_elem(r, T::zero());
123 for i in 0..r {
124 wh_b[i] = inner_product(&self.w_columns[i], b);
125 }
126
127 let y = self
129 .e_lu
130 .solve(&wh_b)
131 .expect("Deflation matrix E should be non-singular");
132
133 let mut result = Array1::from_elem(n, T::zero());
135 for i in 0..r {
136 axpy(y[i], &self.w_columns[i], &mut result);
137 }
138
139 result
140 }
141
142 pub fn apply_recovery<A: LinearOperator<T>>(&self, v: &Array1<T>, operator: &A) -> Array1<T> {
144 let r = self.w_columns.len();
145 if r == 0 {
146 return v.clone();
147 }
148
149 let av = operator.apply(v);
151
152 let mut wh_av = Array1::from_elem(r, T::zero());
154 for i in 0..r {
155 wh_av[i] = inner_product(&self.w_columns[i], &av);
156 }
157
158 let y = self
160 .e_lu
161 .solve(&wh_av)
162 .expect("Deflation matrix E should be non-singular");
163
164 let mut result = v.clone();
166 for i in 0..r {
167 axpy(-y[i], &self.w_columns[i], &mut result);
168 }
169
170 result
171 }
172}
173
174pub fn gmres_deflated<T, A>(
179 operator: &A,
180 deflation: &DeflationSubspace<T>,
181 b: &Array1<T>,
182 x0: Option<&Array1<T>>,
183 config: &GmresConfig<T::Real>,
184) -> GmresSolution<T>
185where
186 T: ComplexField,
187 A: LinearOperator<T>,
188{
189 use crate::traits::IdentityPreconditioner;
190 let precond = IdentityPreconditioner;
191 gmres_deflated_preconditioned(operator, &precond, deflation, b, x0, config)
192}
193
194pub fn gmres_deflated_preconditioned<T, A, P>(
202 operator: &A,
203 precond: &P,
204 deflation: &DeflationSubspace<T>,
205 b: &Array1<T>,
206 x0: Option<&Array1<T>>,
207 config: &GmresConfig<T::Real>,
208) -> GmresSolution<T>
209where
210 T: ComplexField,
211 A: LinearOperator<T>,
212 P: Preconditioner<T>,
213{
214 if deflation.num_vectors() == 0 {
216 return super::gmres::gmres_preconditioned_with_guess(operator, precond, b, x0, config);
217 }
218
219 let n = b.len();
220 let m = config.restart;
221
222 let x_c = deflation.coarse_correction(b);
224
225 let mut x_hat = match x0 {
227 Some(guess) => guess.clone(),
228 None => Array1::from_elem(n, T::zero()),
229 };
230
231 let pb = precond.apply(&deflation.apply_left_projector(b));
233 let b_norm = vector_norm(&pb);
234 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
235 if b_norm < tol_threshold {
236 return GmresSolution {
238 x: x_c,
239 iterations: 0,
240 restarts: 0,
241 residual: T::Real::zero(),
242 converged: true,
243 status: crate::traits::SolverStatus::Converged,
244 };
245 }
246
247 let mut total_iterations = 0;
248 let mut restarts = 0;
249
250 for _outer in 0..config.max_iterations {
252 let ax = operator.apply(&x_hat);
254 let residual: Array1<T> = b - &ax;
255 let deflated_residual = deflation.apply_left_projector(&residual);
256 let r = precond.apply(&deflated_residual);
257 let beta = vector_norm(&r);
258
259 let rel_residual = beta / b_norm;
260 if rel_residual < config.tolerance {
261 let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
263 return GmresSolution {
264 x,
265 iterations: total_iterations,
266 restarts,
267 residual: rel_residual,
268 converged: true,
269 status: crate::traits::SolverStatus::Converged,
270 };
271 }
272
273 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
275 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
276
277 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
279
280 let mut cs: Vec<T> = Vec::with_capacity(m);
282 let mut sn: Vec<T> = Vec::with_capacity(m);
283
284 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
286 g[0] = T::from_real(beta);
287
288 let mut inner_converged = false;
289
290 for j in 0..m {
292 total_iterations += 1;
293
294 let av = operator.apply(&v[j]);
296 let pav = deflation.apply_left_projector(&av);
297 let mut w = precond.apply(&pav);
298
299 for i in 0..=j {
301 h[[i, j]] = inner_product(&v[i], &w);
302 let h_ij = h[[i, j]];
303 w = &w - &v[i].mapv(|vi| vi * h_ij);
304 }
305
306 let w_norm = vector_norm(&w);
307 h[[j + 1, j]] = T::from_real(w_norm);
308
309 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
311 if w_norm < breakdown_tol {
312 inner_converged = true;
313 } else {
314 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
315 }
316
317 for i in 0..j {
319 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
320 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
321 h[[i, j]] = temp;
322 }
323
324 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
326 cs.push(c);
327 sn.push(s);
328
329 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
331 h[[j + 1, j]] = T::zero();
332
333 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
334 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
335 g[j] = temp;
336
337 let abs_residual = g[j + 1].norm();
339 let rel_residual = abs_residual / b_norm;
340 let abs_tol = T::Real::from_f64(1e-20).unwrap();
341
342 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
343 log::info!(
344 "Deflated GMRES iteration {} (restart {}): relative residual = {:.6e}",
345 total_iterations,
346 restarts,
347 rel_residual.to_f64().unwrap_or(0.0)
348 );
349 }
350
351 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
352 let y = solve_upper_triangular(&h, &g, j + 1);
354
355 for (i, &yi) in y.iter().enumerate() {
357 x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
358 }
359
360 let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
362
363 let status =
364 if inner_converged && rel_residual >= config.tolerance && abs_residual >= abs_tol
365 {
366 crate::traits::SolverStatus::Breakdown
367 } else {
368 crate::traits::SolverStatus::Converged
369 };
370
371 return GmresSolution {
372 x,
373 iterations: total_iterations,
374 restarts,
375 residual: rel_residual,
376 converged: true,
377 status,
378 };
379 }
380 }
381
382 let y = solve_upper_triangular(&h, &g, m);
384 for (i, &yi) in y.iter().enumerate() {
385 x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
386 }
387
388 restarts += 1;
389 }
390
391 let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
393
394 let ax = operator.apply(&x);
396 let r: Array1<T> = b - &ax;
397 let r_norm = vector_norm(&r);
398 let b_true_norm = vector_norm(b);
399 let rel_residual = if b_true_norm > T::Real::from_f64(1e-15).unwrap() {
400 r_norm / b_true_norm
401 } else {
402 r_norm
403 };
404
405 GmresSolution {
406 x,
407 iterations: total_iterations,
408 restarts,
409 residual: rel_residual,
410 converged: false,
411 status: crate::traits::SolverStatus::MaxIterationsReached,
412 }
413}
414
415#[inline]
417fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
418 let tol = T::Real::from_f64(1e-30).unwrap();
419 if b.norm() < tol {
420 return (T::one(), T::zero());
421 }
422 if a.norm() < tol {
423 return (T::zero(), T::one());
424 }
425
426 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
427 let c = a * T::from_real(T::Real::one() / r);
428 let s = b * T::from_real(T::Real::one() / r);
429
430 (c, s)
431}
432
433fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
435 let mut y = vec![T::zero(); k];
436 let tol = T::Real::from_f64(1e-30).unwrap();
437
438 for i in (0..k).rev() {
439 let mut sum = g[i];
440 for j in (i + 1)..k {
441 sum -= h[[i, j]] * y[j];
442 }
443 if h[[i, i]].norm() > tol {
444 y[i] = sum * h[[i, i]].inv();
445 }
446 }
447
448 y
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Shorthand for a complex number with zero imaginary part.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Euclidean norm of a complex vector.
    fn l2_norm(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Length-`n` vector with a one at position `k` and zeros elsewhere.
    fn unit_vector(n: usize, k: usize) -> Array1<Complex64> {
        let mut e = Array1::from_elem(n, Complex64::ZERO);
        e[k] = re(1.0);
        e
    }

    /// Solver settings shared by the small test problems.
    fn small_config() -> GmresConfig<f64> {
        GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        }
    }

    /// Diagonal matrix: all but the last three entries sit in a tight cluster
    /// around `cluster_center`; the last three are `10 * (i + 1)`.
    fn build_clustered_diagonal(n: usize, cluster_center: f64) -> CsrMatrix<Complex64> {
        let midpoint = (n as f64 - 3.0) / 2.0;
        let mut dense = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));
        for i in 0..n {
            let value = if i < n - 3 {
                cluster_center + 0.01 * (i as f64 - midpoint)
            } else {
                (i + 1) as f64 * 10.0
            };
            dense[[i, i]] = Complex64::new(value, 0.0);
        }
        CsrMatrix::from_dense(&dense, 1e-15)
    }

    #[test]
    fn test_deflation_subspace_construction() {
        let dense = array![
            [re(1.0), Complex64::ZERO, Complex64::ZERO],
            [Complex64::ZERO, re(2.0), Complex64::ZERO],
            [Complex64::ZERO, Complex64::ZERO, re(3.0)],
        ];
        let a = CsrMatrix::from_dense(&dense, 1e-15);

        // Deflate the first coordinate direction.
        let e1 = unit_vector(3, 0);
        let deflation = DeflationSubspace::new(vec![e1.clone()], &a).unwrap();

        assert_eq!(deflation.num_vectors(), 1);

        let pe1_norm = l2_norm(&deflation.apply_left_projector(&e1));
        assert!(pe1_norm < 1e-12, "Deflated eigenvector should be zero");

        let e2 = unit_vector(3, 1);
        let pe2_norm = l2_norm(&deflation.apply_left_projector(&e2));
        assert!(pe2_norm > 0.5, "Non-deflated vector should remain");
    }

    #[test]
    fn test_deflated_gmres_fewer_iterations() {
        let n = 20;
        let a = build_clustered_diagonal(n, 5.0);

        let b = Array1::from_iter((0..n).map(|i| re((i + 1) as f64)));

        let config = GmresConfig {
            max_iterations: 200,
            restart: 30,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);

        // Deflate every coordinate direction belonging to the cluster.
        let w_cols: Vec<Array1<Complex64>> = (0..(n - 3)).map(|i| unit_vector(n, i)).collect();
        let deflation = DeflationSubspace::new(w_cols, &a).unwrap();
        let sol_deflated = gmres_deflated(&a, &deflation, &b, None, &config);

        assert!(sol_standard.converged, "Standard GMRES should converge");
        assert!(sol_deflated.converged, "Deflated GMRES should converge");

        assert!(
            sol_deflated.iterations <= sol_standard.iterations,
            "Deflated ({}) should use <= iterations than standard ({})",
            sol_deflated.iterations,
            sol_standard.iterations
        );

        let error = l2_norm(&(&a.matvec(&sol_deflated.x) - &b));
        assert!(error < 1e-8, "Deflated solution should satisfy Ax = b");
    }

    #[test]
    fn test_deflated_gmres_with_preconditioner() {
        let dense = array![
            [re(4.0), re(1.0), Complex64::ZERO],
            [re(1.0), re(3.0), re(1.0)],
            [Complex64::ZERO, re(1.0), re(5.0)],
        ];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0), re(3.0)];

        let precond = crate::preconditioners::DiagonalPreconditioner::from_csr(&a);

        // A single normalised deflation vector.
        let w1 = array![re(0.5), re(0.7), re(0.5)];
        let w1_norm = vector_norm(&w1);
        let w1_normalized = w1.mapv(|v| v * Complex64::new(1.0 / w1_norm, 0.0));

        let deflation = DeflationSubspace::new(vec![w1_normalized], &a).unwrap();

        let config = small_config();

        let sol = gmres_deflated_preconditioned(&a, &precond, &deflation, &b, None, &config);
        assert!(sol.converged, "Deflated preconditioned GMRES should converge");

        let error = l2_norm(&(&a.matvec(&sol.x) - &b));
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_deflated_gmres_zero_vectors_fallback() {
        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let deflation = DeflationSubspace::new(Vec::new(), &a).unwrap();
        assert_eq!(deflation.num_vectors(), 0);

        let config = small_config();

        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged, "Zero-vector deflated GMRES should converge");

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);
        let error = l2_norm(&(&sol.x - &sol_standard.x));
        assert!(
            error < 1e-8,
            "Zero-deflation should match standard GMRES solution"
        );
    }

    #[test]
    fn test_coarse_correction_exact_for_deflation_space() {
        let dense = array![[re(2.0), Complex64::ZERO], [Complex64::ZERO, re(3.0)],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);

        // The two unit vectors span the whole space.
        let deflation =
            DeflationSubspace::new(vec![unit_vector(2, 0), unit_vector(2, 1)], &a).unwrap();

        let b = array![re(4.0), re(9.0)];
        let x_c = deflation.coarse_correction(&b);

        let error = l2_norm(&(&a.matvec(&x_c) - &b));
        assert!(error < 1e-10, "Coarse correction should be exact solution");
    }

    #[test]
    fn test_deflated_gmres_f64() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let w1 = array![0.7071_f64, 0.7071];
        let w1_norm = vector_norm(&w1);
        let w1_normalized = w1.mapv(|v| v / w1_norm);
        let deflation = DeflationSubspace::new(vec![w1_normalized], &a).unwrap();

        let config = small_config();

        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged);

        let ax = a.matvec(&sol.x);
        let error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}