bem/core/solver/
bicgstab.rs

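//! BiCGSTAB (Bi-Conjugate Gradient Stabilized) solver
//!
//! Implementation based on H. A. van der Vorst, "Bi-CGSTAB: A fast and smoothly
//! converging variant of Bi-CG for the solution of nonsymmetric linear systems",
//! SIAM J. Sci. Stat. Comput. 13(2), 1992.
//!
//! BiCGSTAB is often more stable than CGS for non-symmetric systems: each step
//! applies a GMRES(1)-like local residual minimization, which smooths out the
//! irregular convergence behavior typical of CGS.
//!
//! ## Algorithm
//!
//! BiCGSTAB improves upon CGS by adding a stabilization step that minimizes the
//! residual in a one-dimensional subspace at each iteration: after the BiCG-style
//! update producing the intermediate residual `s`, it chooses `ω` to minimize
//! `‖s − ω t‖₂` with `t = A s` (or `t = M⁻¹ A s` when preconditioned).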
12
13use ndarray::Array1;
14use num_complex::Complex64;
15
/// Tuning parameters shared by the BiCGSTAB solvers in this module.
///
/// `BiCgstabConfig::default()` allows 1000 iterations, uses a relative
/// tolerance of `1e-6`, and prints progress every 10 iterations.
#[derive(Clone, Debug)]
pub struct BiCgstabConfig {
    /// Upper bound on the number of BiCGSTAB iterations performed.
    pub max_iterations: usize,
    /// Relative convergence tolerance: iteration stops once the residual norm
    /// divided by the initial residual norm drops below this value.
    pub tolerance: f64,
    /// Write a progress line to stderr on every `print_interval`-th iteration;
    /// a value of `0` silences progress output entirely.
    pub print_interval: usize,
}

impl Default for BiCgstabConfig {
    /// 1000 iterations, tolerance `1e-6`, progress output every 10 iterations.
    fn default() -> Self {
        let max_iterations = 1000;
        let tolerance = 1e-6;
        let print_interval = 10;
        BiCgstabConfig {
            max_iterations,
            tolerance,
            print_interval,
        }
    }
}
36
37/// BiCGSTAB solver result
38#[derive(Debug)]
39pub struct BiCgstabSolution {
40    /// Solution vector
41    pub x: Array1<Complex64>,
42    /// Number of iterations performed
43    pub iterations: usize,
44    /// Final relative residual
45    pub residual: f64,
46    /// Whether convergence was achieved
47    pub converged: bool,
48}
49
50/// Solve Ax = b using the BiCGSTAB method
51///
52/// # Arguments
53/// * `matvec` - Function to compute A*x for a given x
54/// * `b` - Right-hand side vector
55/// * `x0` - Optional initial guess (defaults to zero)
56/// * `config` - Solver configuration
57///
58/// # Returns
59/// Solution struct containing x, iteration count, and convergence info
60///
61/// # Example
62/// ```ignore
63/// let config = BiCgstabConfig::default();
64/// let matvec = |x: &Array1<Complex64>| system.matvec(x);
65/// let solution = bicgstab_solve(&matvec, &rhs, None, &config);
66/// ```
67pub fn bicgstab_solve<F>(
68    matvec: F,
69    b: &Array1<Complex64>,
70    x0: Option<&Array1<Complex64>>,
71    config: &BiCgstabConfig,
72) -> BiCgstabSolution
73where
74    F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
75{
76    let n = b.len();
77
78    // Initialize solution vector
79    let mut x = match x0 {
80        Some(x0) => x0.clone(),
81        None => Array1::zeros(n),
82    };
83
84    // Initial residual: r = b - A*x
85    let ax = matvec(&x);
86    let mut r: Array1<Complex64> = b - &ax;
87
88    // r̃₀ = r₀ (shadow residual, kept constant)
89    let r_tilde = r.clone();
90
91    // Compute initial residual norm
92    let err_ori = residual_norm(&r);
93    if err_ori < 1e-15 {
94        return BiCgstabSolution {
95            x,
96            iterations: 0,
97            residual: 0.0,
98            converged: true,
99        };
100    }
101
102    // Initialize scalars
103    let mut rho = Complex64::new(1.0, 0.0);
104    let mut alpha = Complex64::new(1.0, 0.0);
105    let mut omega = Complex64::new(1.0, 0.0);
106
107    // Initialize vectors
108    let mut v: Array1<Complex64> = Array1::zeros(n);
109    let mut p: Array1<Complex64> = Array1::zeros(n);
110
111    let mut iterations = 0;
112    let mut err_rel = 1.0;
113
114    for j in 0..config.max_iterations {
115        iterations = j + 1;
116
117        // ρ_j = (r̃₀, r_{j-1})
118        let rho_new = inner_product(&r_tilde, &r);
119
120        // Check for breakdown
121        if rho_new.norm() < 1e-30 {
122            return BiCgstabSolution {
123                x,
124                iterations,
125                residual: err_rel,
126                converged: false,
127            };
128        }
129
130        // β = (ρ_j / ρ_{j-1}) * (α / ω)
131        let beta = (rho_new / rho) * (alpha / omega);
132
133        // p_j = r_{j-1} + β * (p_{j-1} - ω * v_{j-1})
134        let p_minus_omega_v: Array1<Complex64> = &p - &(&v * omega);
135        p = &r + &(&p_minus_omega_v * beta);
136
137        // v_j = A * p_j
138        v = matvec(&p);
139
140        // α = ρ_j / (r̃₀, v_j)
141        let r_tilde_v = inner_product(&r_tilde, &v);
142        if r_tilde_v.norm() < 1e-30 {
143            return BiCgstabSolution {
144                x,
145                iterations,
146                residual: err_rel,
147                converged: false,
148            };
149        }
150        alpha = rho_new / r_tilde_v;
151
152        // s = r_{j-1} - α * v_j
153        let s: Array1<Complex64> = &r - &(&v * alpha);
154
155        // Check if s is small enough for convergence
156        let s_norm = residual_norm(&s);
157        if s_norm / err_ori < config.tolerance {
158            // Update x and return
159            x = &x + &(&p * alpha);
160            return BiCgstabSolution {
161                x,
162                iterations,
163                residual: s_norm / err_ori,
164                converged: true,
165            };
166        }
167
168        // t = A * s
169        let t = matvec(&s);
170
171        // ω = (t, s) / (t, t)
172        let t_s = inner_product(&t, &s);
173        let t_t = inner_product(&t, &t);
174        if t_t.norm() < 1e-30 {
175            return BiCgstabSolution {
176                x,
177                iterations,
178                residual: err_rel,
179                converged: false,
180            };
181        }
182        omega = t_s / t_t;
183
184        // x_j = x_{j-1} + α * p_j + ω * s
185        x = &x + &(&p * alpha) + &(&s * omega);
186
187        // r_j = s - ω * t
188        r = &s - &(&t * omega);
189
190        // Compute residual norm
191        let r_norm = residual_norm(&r);
192        err_rel = r_norm / err_ori;
193
194        // Print progress
195        if config.print_interval > 0 && j % config.print_interval == 0 {
196            eprintln!("BiCGSTAB iteration {}: relative residual = {:.6e}", j, err_rel);
197        }
198
199        // Check convergence
200        if err_rel < config.tolerance {
201            return BiCgstabSolution {
202                x,
203                iterations,
204                residual: err_rel,
205                converged: true,
206            };
207        }
208
209        // Check for breakdown in omega
210        if omega.norm() < 1e-30 {
211            return BiCgstabSolution {
212                x,
213                iterations,
214                residual: err_rel,
215                converged: false,
216            };
217        }
218
219        rho = rho_new;
220    }
221
222    BiCgstabSolution {
223        x,
224        iterations,
225        residual: err_rel,
226        converged: false,
227    }
228}
229
230/// Compute inner product (x, y) = Σ conj(x_i) * y_i
231fn inner_product(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
232    x.iter()
233        .zip(y.iter())
234        .map(|(xi, yi)| xi.conj() * yi)
235        .sum()
236}
237
238/// Compute residual norm ||r||₂
239fn residual_norm(r: &Array1<Complex64>) -> f64 {
240    r.iter().map(|ri| ri.norm_sqr()).sum::<f64>().sqrt()
241}
242
243/// BiCGSTAB solver with preconditioner
244///
245/// Solves Ax = b using left preconditioning: M⁻¹Ax = M⁻¹b
246pub fn bicgstab_solve_preconditioned<F, P>(
247    matvec: F,
248    precond_solve: P,
249    b: &Array1<Complex64>,
250    x0: Option<&Array1<Complex64>>,
251    config: &BiCgstabConfig,
252) -> BiCgstabSolution
253where
254    F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
255    P: Fn(&Array1<Complex64>) -> Array1<Complex64>,
256{
257    let n = b.len();
258
259    // Initialize solution vector
260    let mut x = match x0 {
261        Some(x0) => x0.clone(),
262        None => Array1::zeros(n),
263    };
264
265    // Initial residual: r = b - A*x
266    let ax = matvec(&x);
267    let r0: Array1<Complex64> = b - &ax;
268
269    // Apply preconditioner
270    let mut r = precond_solve(&r0);
271    let r_tilde = r.clone();
272
273    let err_ori = residual_norm(&r);
274    if err_ori < 1e-15 {
275        return BiCgstabSolution {
276            x,
277            iterations: 0,
278            residual: 0.0,
279            converged: true,
280        };
281    }
282
283    let mut rho = Complex64::new(1.0, 0.0);
284    let mut alpha = Complex64::new(1.0, 0.0);
285    let mut omega = Complex64::new(1.0, 0.0);
286
287    let mut v: Array1<Complex64> = Array1::zeros(n);
288    let mut p: Array1<Complex64> = Array1::zeros(n);
289
290    let mut iterations = 0;
291    let mut err_rel = 1.0;
292
293    for j in 0..config.max_iterations {
294        iterations = j + 1;
295
296        let rho_new = inner_product(&r_tilde, &r);
297        if rho_new.norm() < 1e-30 {
298            return BiCgstabSolution {
299                x,
300                iterations,
301                residual: err_rel,
302                converged: false,
303            };
304        }
305
306        let beta = (rho_new / rho) * (alpha / omega);
307        let p_minus_omega_v: Array1<Complex64> = &p - &(&v * omega);
308        p = &r + &(&p_minus_omega_v * beta);
309
310        // v = M⁻¹ * A * p
311        let ap = matvec(&p);
312        v = precond_solve(&ap);
313
314        let r_tilde_v = inner_product(&r_tilde, &v);
315        if r_tilde_v.norm() < 1e-30 {
316            return BiCgstabSolution {
317                x,
318                iterations,
319                residual: err_rel,
320                converged: false,
321            };
322        }
323        alpha = rho_new / r_tilde_v;
324
325        let s: Array1<Complex64> = &r - &(&v * alpha);
326        let s_norm = residual_norm(&s);
327        if s_norm / err_ori < config.tolerance {
328            x = &x + &(&p * alpha);
329            return BiCgstabSolution {
330                x,
331                iterations,
332                residual: s_norm / err_ori,
333                converged: true,
334            };
335        }
336
337        // t = M⁻¹ * A * s
338        let as_ = matvec(&s);
339        let t = precond_solve(&as_);
340
341        let t_s = inner_product(&t, &s);
342        let t_t = inner_product(&t, &t);
343        if t_t.norm() < 1e-30 {
344            return BiCgstabSolution {
345                x,
346                iterations,
347                residual: err_rel,
348                converged: false,
349            };
350        }
351        omega = t_s / t_t;
352
353        x = &x + &(&p * alpha) + &(&s * omega);
354        r = &s - &(&t * omega);
355
356        let r_norm = residual_norm(&r);
357        err_rel = r_norm / err_ori;
358
359        if config.print_interval > 0 && j % config.print_interval == 0 {
360            eprintln!(
361                "BiCGSTAB (precond) iteration {}: relative residual = {:.6e}",
362                j, err_rel
363            );
364        }
365
366        if err_rel < config.tolerance {
367            return BiCgstabSolution {
368                x,
369                iterations,
370                residual: err_rel,
371                converged: true,
372            };
373        }
374
375        if omega.norm() < 1e-30 {
376            return BiCgstabSolution {
377                x,
378                iterations,
379                residual: err_rel,
380                converged: false,
381            };
382        }
383
384        rho = rho_new;
385    }
386
387    BiCgstabSolution {
388        x,
389        iterations,
390        residual: err_rel,
391        converged: false,
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Solver configuration with progress output disabled.
    fn quiet_config(max_iterations: usize, tolerance: f64) -> BiCgstabConfig {
        BiCgstabConfig {
            max_iterations,
            tolerance,
            print_interval: 0,
        }
    }

    /// Row-major `rows x cols` complex matrix built from `(re, im)` pairs.
    fn complex_matrix(rows: usize, cols: usize, entries: &[(f64, f64)]) -> Array2<Complex64> {
        let data: Vec<Complex64> = entries
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        Array2::from_shape_vec((rows, cols), data).expect("entry count must equal rows * cols")
    }

    /// Complex vector built from `(re, im)` pairs.
    fn complex_vector(entries: &[(f64, f64)]) -> Array1<Complex64> {
        entries
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect()
    }

    /// Unpreconditioned residual `‖A x − b‖₂`, computed independently of the solver.
    fn true_residual(a: &Array2<Complex64>, x: &Array1<Complex64>, b: &Array1<Complex64>) -> f64 {
        (&a.dot(x) - b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Real non-symmetric 3x3 matrix; with `b = [5, 8, 4]` the exact solution is `[1, 1, 1]`.
    fn non_symmetric_3x3() -> Array2<Complex64> {
        complex_matrix(
            3,
            3,
            &[
                (4.0, 0.0),
                (1.0, 0.0),
                (0.0, 0.0),
                (2.0, 0.0),
                (5.0, 0.0),
                (1.0, 0.0),
                (0.0, 0.0),
                (1.0, 0.0),
                (3.0, 0.0),
            ],
        )
    }

    /// Right-hand side matching [`non_symmetric_3x3`] with solution `[1, 1, 1]`.
    fn non_symmetric_rhs() -> Array1<Complex64> {
        complex_vector(&[(5.0, 0.0), (8.0, 0.0), (4.0, 0.0)])
    }

    #[test]
    fn test_bicgstab_simple() {
        // Simple 2x2 positive definite system
        let a = complex_matrix(2, 2, &[(4.0, 0.0), (1.0, 0.0), (1.0, 0.0), (3.0, 0.0)]);
        let b = complex_vector(&[(1.0, 0.0), (2.0, 0.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge");
        assert!(
            true_residual(&a, &solution.x, &b) < 1e-8,
            "Solution should satisfy Ax = b"
        );
    }

    #[test]
    fn test_bicgstab_complex() {
        // Complex non-symmetric system
        let a = complex_matrix(2, 2, &[(2.0, 1.0), (0.0, -1.0), (0.0, 1.0), (2.0, -1.0)]);
        let b = complex_vector(&[(1.0, 1.0), (1.0, -1.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge for complex system");
        assert!(
            true_residual(&a, &solution.x, &b) < 1e-8,
            "Solution should satisfy Ax = b"
        );
    }

    #[test]
    fn test_bicgstab_identity() {
        // Identity matrix - should converge almost immediately
        let n = 5;
        let b: Array1<Complex64> = (1..=n).map(|i| Complex64::new(i as f64, 0.0)).collect();
        let matvec = |x: &Array1<Complex64>| x.clone();

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        // x should equal b for the identity matrix
        let error: f64 = (&solution.x - &b)
            .iter()
            .map(|e| e.norm_sqr())
            .sum::<f64>()
            .sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_bicgstab_non_symmetric() {
        let a = non_symmetric_3x3();
        let b = non_symmetric_rhs();
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge for non-symmetric system");
        assert!(
            true_residual(&a, &solution.x, &b) < 1e-8,
            "Solution should satisfy Ax = b"
        );
        let one = Complex64::new(1.0, 0.0);
        assert!(solution.x.iter().all(|xi| (xi - one).norm() < 1e-6));
    }

    #[test]
    fn test_bicgstab_vs_cgs_stability() {
        // System that might be challenging for CGS but stable for BiCGSTAB
        let a = complex_matrix(
            3,
            3,
            &[
                (1.0, 0.1),
                (0.5, 0.0),
                (0.0, 0.0),
                (0.5, 0.0),
                (1.0, -0.1),
                (0.5, 0.0),
                (0.0, 0.0),
                (0.5, 0.0),
                (1.0, 0.1),
            ],
        );
        let b = complex_vector(&[(1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged);
        assert!(true_residual(&a, &solution.x, &b) < 1e-8);
    }

    #[test]
    fn test_bicgstab_zero_rhs_returns_immediately() {
        let a = non_symmetric_3x3();
        let b: Array1<Complex64> = Array1::zeros(3);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }

    #[test]
    fn test_exact_initial_guess_returns_immediately() {
        let a = non_symmetric_3x3();
        let b = non_symmetric_rhs();
        let x0 = complex_vector(&[(1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let identity = |r: &Array1<Complex64>| r.clone();
        let config = quiet_config(100, 1e-10);

        let plain = bicgstab_solve(&matvec, &b, Some(&x0), &config);
        assert!(plain.converged);
        assert_eq!(plain.iterations, 0);
        assert_eq!(plain.x, x0);

        let precond = bicgstab_solve_preconditioned(&matvec, &identity, &b, Some(&x0), &config);
        assert!(precond.converged);
        assert_eq!(precond.iterations, 0);
        assert_eq!(precond.x, x0);
    }

    #[test]
    fn test_bicgstab_zero_max_iterations() {
        let a = non_symmetric_3x3();
        let b = non_symmetric_rhs();
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(0, 1e-10));

        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 1.0);
    }

    #[test]
    fn test_bicgstab_nan_rhs_is_not_reported_converged() {
        let a = complex_matrix(2, 2, &[(4.0, 0.0), (1.0, 0.0), (1.0, 0.0), (3.0, 0.0)]);
        let b = complex_vector(&[(f64::NAN, 0.0), (1.0, 0.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(20, 1e-10));

        assert!(!solution.converged, "NaN input must never be reported as converged");
    }

    #[test]
    fn test_preconditioned_identity_matches_plain() {
        // With M = I the preconditioned solver performs the same arithmetic as the plain one.
        let a = non_symmetric_3x3();
        let b = non_symmetric_rhs();
        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let identity = |r: &Array1<Complex64>| r.clone();
        let config = quiet_config(100, 1e-10);

        let plain = bicgstab_solve(&matvec, &b, None, &config);
        let precond = bicgstab_solve_preconditioned(&matvec, &identity, &b, None, &config);

        assert!(plain.converged && precond.converged);
        assert_eq!(plain.iterations, precond.iterations);
        assert!((&plain.x - &precond.x).iter().all(|d| d.norm() < 1e-12));
    }

    #[test]
    fn test_preconditioned_jacobi() {
        let a = non_symmetric_3x3();
        let b = non_symmetric_rhs();
        let diag: Array1<Complex64> = a.diag().to_owned();
        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let jacobi = |r: &Array1<Complex64>| r / &diag;

        let solution =
            bicgstab_solve_preconditioned(&matvec, &jacobi, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "Jacobi-preconditioned BiCGSTAB should converge");
        assert!(
            true_residual(&a, &solution.x, &b) < 1e-8,
            "Solution should satisfy Ax = b"
        );
    }
}