1use crate::blas_helpers::{axpy, inner_product, vector_norm};
10use crate::direct::LuError;
11use crate::direct::lu::{LuFactorization, lu_factorize};
12use crate::traits::{ComplexField, LinearOperator, Preconditioner};
13use ndarray::{Array1, Array2};
14use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
15
16use super::gmres::{GmresConfig, GmresSolution};
17
18#[derive(Debug, Clone)]
29pub struct DeflationSubspace<T: ComplexField> {
30 w_columns: Vec<Array1<T>>,
31 aw_columns: Vec<Array1<T>>,
32 e_lu: LuFactorization<T>,
33}
34
35impl<T: ComplexField> DeflationSubspace<T> {
36 pub fn new<A: LinearOperator<T>>(
44 w_columns: Vec<Array1<T>>,
45 operator: &A,
46 ) -> Result<Self, LuError> {
47 let r = w_columns.len();
48 if r == 0 {
49 return Ok(Self {
50 w_columns,
51 aw_columns: Vec::new(),
52 e_lu: LuFactorization {
53 lu: Array2::from_elem((0, 0), T::zero()),
54 pivots: Vec::new(),
55 n: 0,
56 },
57 });
58 }
59
60 let aw_columns: Vec<Array1<T>> = w_columns.iter().map(|w| operator.apply(w)).collect();
62
63 let mut e = Array2::from_elem((r, r), T::zero());
65 for i in 0..r {
66 for j in 0..r {
67 e[[i, j]] = inner_product(&w_columns[i], &aw_columns[j]);
68 }
69 }
70
71 let e_lu = lu_factorize(&e)?;
72
73 Ok(Self {
74 w_columns,
75 aw_columns,
76 e_lu,
77 })
78 }
79
80 pub fn num_vectors(&self) -> usize {
82 self.w_columns.len()
83 }
84
85 pub fn apply_left_projector(&self, v: &Array1<T>) -> Array1<T> {
87 let r = self.w_columns.len();
88 if r == 0 {
89 return v.clone();
90 }
91
92 let mut wh_v = Array1::from_elem(r, T::zero());
94 for i in 0..r {
95 wh_v[i] = inner_product(&self.w_columns[i], v);
96 }
97
98 let y = self
100 .e_lu
101 .solve(&wh_v)
102 .expect("Deflation matrix E should be non-singular");
103
104 let mut result = v.clone();
106 for i in 0..r {
107 axpy(-y[i], &self.aw_columns[i], &mut result);
108 }
109
110 result
111 }
112
113 pub fn coarse_correction(&self, b: &Array1<T>) -> Array1<T> {
115 let r = self.w_columns.len();
116 let n = b.len();
117 if r == 0 {
118 return Array1::from_elem(n, T::zero());
119 }
120
121 let mut wh_b = Array1::from_elem(r, T::zero());
123 for i in 0..r {
124 wh_b[i] = inner_product(&self.w_columns[i], b);
125 }
126
127 let y = self
129 .e_lu
130 .solve(&wh_b)
131 .expect("Deflation matrix E should be non-singular");
132
133 let mut result = Array1::from_elem(n, T::zero());
135 for i in 0..r {
136 axpy(y[i], &self.w_columns[i], &mut result);
137 }
138
139 result
140 }
141
142 pub fn apply_recovery<A: LinearOperator<T>>(&self, v: &Array1<T>, operator: &A) -> Array1<T> {
144 let r = self.w_columns.len();
145 if r == 0 {
146 return v.clone();
147 }
148
149 let av = operator.apply(v);
151
152 let mut wh_av = Array1::from_elem(r, T::zero());
154 for i in 0..r {
155 wh_av[i] = inner_product(&self.w_columns[i], &av);
156 }
157
158 let y = self
160 .e_lu
161 .solve(&wh_av)
162 .expect("Deflation matrix E should be non-singular");
163
164 let mut result = v.clone();
166 for i in 0..r {
167 axpy(-y[i], &self.w_columns[i], &mut result);
168 }
169
170 result
171 }
172}
173
174pub fn gmres_deflated<T, A>(
179 operator: &A,
180 deflation: &DeflationSubspace<T>,
181 b: &Array1<T>,
182 x0: Option<&Array1<T>>,
183 config: &GmresConfig<T::Real>,
184) -> GmresSolution<T>
185where
186 T: ComplexField,
187 A: LinearOperator<T>,
188{
189 use crate::traits::IdentityPreconditioner;
190 let precond = IdentityPreconditioner;
191 gmres_deflated_preconditioned(operator, &precond, deflation, b, x0, config)
192}
193
194pub fn gmres_deflated_preconditioned<T, A, P>(
202 operator: &A,
203 precond: &P,
204 deflation: &DeflationSubspace<T>,
205 b: &Array1<T>,
206 x0: Option<&Array1<T>>,
207 config: &GmresConfig<T::Real>,
208) -> GmresSolution<T>
209where
210 T: ComplexField,
211 A: LinearOperator<T>,
212 P: Preconditioner<T>,
213{
214 if deflation.num_vectors() == 0 {
216 return super::gmres::gmres_preconditioned_with_guess(operator, precond, b, x0, config);
217 }
218
219 let n = b.len();
220 let m = config.restart;
221
222 let x_c = deflation.coarse_correction(b);
224
225 let mut x_hat = match x0 {
227 Some(guess) => guess.clone(),
228 None => Array1::from_elem(n, T::zero()),
229 };
230
231 let pb = precond.apply(&deflation.apply_left_projector(b));
233 let b_norm = vector_norm(&pb);
234 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
235 if b_norm < tol_threshold {
236 return GmresSolution {
238 x: x_c,
239 iterations: 0,
240 restarts: 0,
241 residual: T::Real::zero(),
242 converged: true,
243 status: crate::traits::SolverStatus::Converged,
244 };
245 }
246
247 let mut total_iterations = 0;
248 let mut restarts = 0;
249
250 for _outer in 0..config.max_iterations {
252 let ax = operator.apply(&x_hat);
254 let residual: Array1<T> = b - &ax;
255 let deflated_residual = deflation.apply_left_projector(&residual);
256 let r = precond.apply(&deflated_residual);
257 let beta = vector_norm(&r);
258
259 let rel_residual = beta / b_norm;
260 if rel_residual < config.tolerance {
261 let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
263 return GmresSolution {
264 x,
265 iterations: total_iterations,
266 restarts,
267 residual: rel_residual,
268 converged: true,
269 status: crate::traits::SolverStatus::Converged,
270 };
271 }
272
273 let mut v: Vec<Array1<T>> = Vec::with_capacity(m + 1);
275 v.push(r.mapv(|ri| ri * T::from_real(T::Real::one() / beta)));
276
277 let mut h: Array2<T> = Array2::from_elem((m + 1, m), T::zero());
279
280 let mut cs: Vec<T> = Vec::with_capacity(m);
282 let mut sn: Vec<T> = Vec::with_capacity(m);
283
284 let mut g: Array1<T> = Array1::from_elem(m + 1, T::zero());
286 g[0] = T::from_real(beta);
287
288 let mut inner_converged = false;
289
290 for j in 0..m {
292 total_iterations += 1;
293
294 let av = operator.apply(&v[j]);
296 let pav = deflation.apply_left_projector(&av);
297 let mut w = precond.apply(&pav);
298
299 for i in 0..=j {
301 h[[i, j]] = inner_product(&v[i], &w);
302 let h_ij = h[[i, j]];
303 w = &w - &v[i].mapv(|vi| vi * h_ij);
304 }
305
306 let w_norm = vector_norm(&w);
307 h[[j + 1, j]] = T::from_real(w_norm);
308
309 let breakdown_tol = T::Real::from_f64(1e-20).unwrap();
311 if w_norm < breakdown_tol {
312 inner_converged = true;
313 } else {
314 v.push(w.mapv(|wi| wi * T::from_real(T::Real::one() / w_norm)));
315 }
316
317 for i in 0..j {
319 let temp = cs[i].conj() * h[[i, j]] + sn[i].conj() * h[[i + 1, j]];
320 h[[i + 1, j]] = T::zero() - sn[i] * h[[i, j]] + cs[i] * h[[i + 1, j]];
321 h[[i, j]] = temp;
322 }
323
324 let (c, s) = givens_rotation(h[[j, j]], h[[j + 1, j]]);
326 cs.push(c);
327 sn.push(s);
328
329 h[[j, j]] = c.conj() * h[[j, j]] + s.conj() * h[[j + 1, j]];
331 h[[j + 1, j]] = T::zero();
332
333 let temp = c.conj() * g[j] + s.conj() * g[j + 1];
334 g[j + 1] = T::zero() - s * g[j] + c * g[j + 1];
335 g[j] = temp;
336
337 let abs_residual = g[j + 1].norm();
339 let rel_residual = abs_residual / b_norm;
340 let abs_tol = T::Real::from_f64(1e-20).unwrap();
341
342 if config.print_interval > 0 && total_iterations % config.print_interval == 0 {
343 log::info!(
344 "Deflated GMRES iteration {} (restart {}): relative residual = {:.6e}",
345 total_iterations,
346 restarts,
347 rel_residual.to_f64().unwrap_or(0.0)
348 );
349 }
350
351 if rel_residual < config.tolerance || abs_residual < abs_tol || inner_converged {
352 let y = solve_upper_triangular(&h, &g, j + 1);
354
355 for (i, &yi) in y.iter().enumerate() {
357 x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
358 }
359
360 let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
362
363 let status = if inner_converged
364 && rel_residual >= config.tolerance
365 && abs_residual >= abs_tol
366 {
367 crate::traits::SolverStatus::Breakdown
368 } else {
369 crate::traits::SolverStatus::Converged
370 };
371
372 return GmresSolution {
373 x,
374 iterations: total_iterations,
375 restarts,
376 residual: rel_residual,
377 converged: true,
378 status,
379 };
380 }
381 }
382
383 let y = solve_upper_triangular(&h, &g, m);
385 for (i, &yi) in y.iter().enumerate() {
386 x_hat = &x_hat + &v[i].mapv(|vi| vi * yi);
387 }
388
389 restarts += 1;
390 }
391
392 let x = deflation.apply_recovery(&x_hat, operator) + &x_c;
394
395 let ax = operator.apply(&x);
397 let r: Array1<T> = b - &ax;
398 let r_norm = vector_norm(&r);
399 let b_true_norm = vector_norm(b);
400 let rel_residual = if b_true_norm > T::Real::from_f64(1e-15).unwrap() {
401 r_norm / b_true_norm
402 } else {
403 r_norm
404 };
405
406 GmresSolution {
407 x,
408 iterations: total_iterations,
409 restarts,
410 residual: rel_residual,
411 converged: false,
412 status: crate::traits::SolverStatus::MaxIterationsReached,
413 }
414}
415
416#[inline]
418fn givens_rotation<T: ComplexField>(a: T, b: T) -> (T, T) {
419 let tol = T::Real::from_f64(1e-30).unwrap();
420 if b.norm() < tol {
421 return (T::one(), T::zero());
422 }
423 if a.norm() < tol {
424 return (T::zero(), T::one());
425 }
426
427 let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
428 let c = a * T::from_real(T::Real::one() / r);
429 let s = b * T::from_real(T::Real::one() / r);
430
431 (c, s)
432}
433
434fn solve_upper_triangular<T: ComplexField>(h: &Array2<T>, g: &Array1<T>, k: usize) -> Vec<T> {
436 let mut y = vec![T::zero(); k];
437 let tol = T::Real::from_f64(1e-30).unwrap();
438
439 for i in (0..k).rev() {
440 let mut sum = g[i];
441 for j in (i + 1)..k {
442 sum -= h[[i, j]] * y[j];
443 }
444 if h[[i, i]].norm() > tol {
445 y[i] = sum * h[[i, i]].inv();
446 }
447 }
448
449 y
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Shorthand for a complex number with zero imaginary part.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Euclidean norm of a complex vector.
    fn l2_norm(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Vector of length `n` that is one at `index` and zero elsewhere.
    fn unit_vector(n: usize, index: usize) -> Array1<Complex64> {
        let mut e = Array1::from_elem(n, Complex64::ZERO);
        e[index] = re(1.0);
        e
    }

    /// Sparse diagonal matrix with the given real diagonal entries.
    fn diagonal_matrix(entries: &[f64]) -> CsrMatrix<Complex64> {
        let size = entries.len();
        let mut dense = Array2::from_elem((size, size), Complex64::ZERO);
        for (k, &value) in entries.iter().enumerate() {
            dense[[k, k]] = re(value);
        }
        CsrMatrix::from_dense(&dense, 1e-15)
    }

    /// Diagonal matrix whose first `n - 3` entries cluster around `cluster_center`
    /// (spacing 0.01) and whose last three entries are `10 * (i + 1)`.
    fn build_clustered_diagonal(n: usize, cluster_center: f64) -> CsrMatrix<Complex64> {
        let mut dense = Array2::from_elem((n, n), Complex64::new(0.0, 0.0));
        for i in 0..n {
            let diagonal = if i < n - 3 {
                let offset = 0.01 * (i as f64 - (n as f64 - 3.0) / 2.0);
                cluster_center + offset
            } else {
                (i + 1) as f64 * 10.0
            };
            dense[[i, i]] = Complex64::new(diagonal, 0.0);
        }
        CsrMatrix::from_dense(&dense, 1e-15)
    }

    /// Deflating the first coordinate direction of a diagonal matrix annihilates
    /// that direction and leaves another coordinate direction non-zero.
    #[test]
    fn test_deflation_subspace_construction() {
        let a = diagonal_matrix(&[1.0, 2.0, 3.0]);

        let deflation = DeflationSubspace::new(vec![unit_vector(3, 0)], &a).unwrap();

        assert_eq!(deflation.num_vectors(), 1);

        let pe1 = deflation.apply_left_projector(&unit_vector(3, 0));
        let pe1_norm = l2_norm(&pe1);
        assert!(pe1_norm < 1e-12, "Deflated eigenvector should be zero");

        let pe2 = deflation.apply_left_projector(&unit_vector(3, 1));
        let pe2_norm = l2_norm(&pe2);
        assert!(pe2_norm > 0.5, "Non-deflated vector should remain");
    }

    /// Deflating the clustered directions must not cost more iterations than
    /// standard GMRES and must still solve the system.
    #[test]
    fn test_deflated_gmres_fewer_iterations() {
        let n = 20;
        let cluster = 5.0;
        let a = build_clustered_diagonal(n, cluster);

        let b = Array1::from_iter((0..n).map(|i| Complex64::new((i + 1) as f64, 0.0)));

        let config = GmresConfig {
            max_iterations: 200,
            restart: 30,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);

        // One coordinate direction per clustered diagonal entry.
        let w_cols: Vec<Array1<Complex64>> = (0..(n - 3)).map(|i| unit_vector(n, i)).collect();
        let deflation = DeflationSubspace::new(w_cols, &a).unwrap();
        let sol_deflated = gmres_deflated(&a, &deflation, &b, None, &config);

        assert!(sol_standard.converged, "Standard GMRES should converge");
        assert!(sol_deflated.converged, "Deflated GMRES should converge");

        assert!(
            sol_deflated.iterations <= sol_standard.iterations,
            "Deflated ({}) should use <= iterations than standard ({})",
            sol_deflated.iterations,
            sol_standard.iterations
        );

        let ax = a.matvec(&sol_deflated.x);
        let error = l2_norm(&(&ax - &b));
        assert!(error < 1e-8, "Deflated solution should satisfy Ax = b");
    }

    /// Deflation combined with a diagonal preconditioner on a small tridiagonal system.
    #[test]
    fn test_deflated_gmres_with_preconditioner() {
        let dense = array![
            [re(4.0), re(1.0), Complex64::ZERO],
            [re(1.0), re(3.0), re(1.0)],
            [Complex64::ZERO, re(1.0), re(5.0)],
        ];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0), re(3.0)];

        let precond = crate::preconditioners::DiagonalPreconditioner::from_csr(&a);

        // A single normalized deflation vector.
        let w1 = array![re(0.5), re(0.7), re(0.5)];
        let w1_norm = vector_norm(&w1);
        let inverse_norm = Complex64::new(1.0 / w1_norm, 0.0);
        let w1_normalized = w1.mapv(|v| v * inverse_norm);

        let deflation = DeflationSubspace::new(vec![w1_normalized], &a).unwrap();

        let config = GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol = gmres_deflated_preconditioned(&a, &precond, &deflation, &b, None, &config);
        assert!(
            sol.converged,
            "Deflated preconditioned GMRES should converge"
        );

        let ax = a.matvec(&sol.x);
        let error = l2_norm(&(&ax - &b));
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// An empty deflation subspace must reproduce the standard GMRES solution.
    #[test]
    fn test_deflated_gmres_zero_vectors_fallback() {
        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let deflation = DeflationSubspace::new(Vec::new(), &a).unwrap();
        assert_eq!(deflation.num_vectors(), 0);

        let config = GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged, "Zero-vector deflated GMRES should converge");

        let sol_standard = super::super::gmres::gmres(&a, &b, &config);
        let error = l2_norm(&(&sol.x - &sol_standard.x));
        assert!(
            error < 1e-8,
            "Zero-deflation should match standard GMRES solution"
        );
    }

    /// When the deflation vectors span the whole space, the coarse correction
    /// alone solves the system.
    #[test]
    fn test_coarse_correction_exact_for_deflation_space() {
        let a = diagonal_matrix(&[2.0, 3.0]);

        let deflation =
            DeflationSubspace::new(vec![unit_vector(2, 0), unit_vector(2, 1)], &a).unwrap();

        let b = array![re(4.0), re(9.0)];
        let x_c = deflation.coarse_correction(&b);

        let ax_c = a.matvec(&x_c);
        let error = l2_norm(&(&ax_c - &b));
        assert!(error < 1e-10, "Coarse correction should be exact solution");
    }

    /// The solver also works for a real scalar type.
    #[test]
    fn test_deflated_gmres_f64() {
        use std::f64::consts::FRAC_1_SQRT_2;

        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let w1 = array![FRAC_1_SQRT_2, FRAC_1_SQRT_2];
        let w1_norm = vector_norm(&w1);
        let w1_normalized = w1.mapv(|v| v / w1_norm);
        let deflation = DeflationSubspace::new(vec![w1_normalized], &a).unwrap();

        let config = GmresConfig {
            max_iterations: 50,
            restart: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let sol = gmres_deflated(&a, &deflation, &b, None, &config);
        assert!(sol.converged);

        let ax = a.matvec(&sol.x);
        let difference = &ax - &b;
        let error: f64 = difference.iter().map(|e| e * e).sum::<f64>().sqrt();
        assert_relative_eq!(error, 0.0, epsilon = 1e-8);
    }
}