Skip to main content

math_audio_solvers/iterative/
bicgstab.rs

1//! BiCGSTAB (Bi-Conjugate Gradient Stabilized) solver
2//!
3//! BiCGSTAB is a Krylov subspace method for non-symmetric systems.
4//! It often converges faster than GMRES for certain problem types.
5
6use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator, SolverStatus};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning parameters for the [`bicgstab`] solver.
///
/// `R` is the real scalar type of the tolerance; `bicgstab` expects it to be
/// the `Real` type associated with the system's field type.
#[derive(Clone, Debug)]
pub struct BiCgstabConfig<R> {
    /// Upper bound on the number of BiCGSTAB iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the residual norm relative to `‖b‖`.
    pub tolerance: R,
    /// Emit a progress log line every `print_interval` iterations (`0` disables it).
    pub print_interval: usize,
}
21
22impl Default for BiCgstabConfig<f64> {
23    fn default() -> Self {
24        Self {
25            max_iterations: 1000,
26            tolerance: 1e-6,
27            print_interval: 0,
28        }
29    }
30}
31
32/// BiCGSTAB solver result
33#[derive(Debug)]
34pub struct BiCgstabSolution<T: ComplexField> {
35    /// Solution vector
36    pub x: Array1<T>,
37    /// Number of iterations
38    pub iterations: usize,
39    /// Final relative residual
40    pub residual: T::Real,
41    /// Whether convergence was achieved
42    pub converged: bool,
43    /// Solver status
44    pub status: SolverStatus,
45}
46
47/// Solve Ax = b using the BiCGSTAB method
48pub fn bicgstab<T, A>(
49    operator: &A,
50    b: &Array1<T>,
51    config: &BiCgstabConfig<T::Real>,
52) -> BiCgstabSolution<T>
53where
54    T: ComplexField,
55    A: LinearOperator<T>,
56{
57    let n = b.len();
58    let mut x = Array1::from_elem(n, T::zero());
59
60    let b_norm = vector_norm(b);
61    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
62    if b_norm < tol_threshold {
63        return BiCgstabSolution {
64            x,
65            iterations: 0,
66            residual: T::Real::zero(),
67            converged: true,
68            status: SolverStatus::Converged,
69        };
70    }
71
72    // Initial residual
73    let mut r = b.clone();
74    let r0 = r.clone(); // Shadow residual
75
76    let mut rho = T::one();
77    let mut alpha = T::one();
78    let mut omega = T::one();
79
80    let mut p = Array1::from_elem(n, T::zero());
81    let mut v = Array1::from_elem(n, T::zero());
82
83    for iter in 0..config.max_iterations {
84        let rho_new = inner_product(&r0, &r);
85
86        // Check for breakdown
87        if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
88            return BiCgstabSolution {
89                x,
90                iterations: iter,
91                residual: vector_norm(&r) / b_norm,
92                converged: false,
93                status: SolverStatus::Breakdown,
94            };
95        }
96
97        let beta = (rho_new / rho) * (alpha / omega);
98        rho = rho_new;
99
100        // p = r + beta * (p - omega * v)
101        p = &r + &(&p - &v.mapv(|vi| vi * omega)).mapv(|pi| pi * beta);
102
103        // v = A * p
104        v = operator.apply(&p);
105
106        let r0v = inner_product(&r0, &v);
107        if r0v.norm() < T::Real::from_f64(1e-20).unwrap() {
108            return BiCgstabSolution {
109                x,
110                iterations: iter,
111                residual: vector_norm(&r) / b_norm,
112                converged: false,
113                status: SolverStatus::Breakdown,
114            };
115        }
116
117        alpha = rho / r0v;
118
119        // s = r - alpha * v
120        let s = &r - &v.mapv(|vi| vi * alpha);
121
122        // Check for early convergence
123        let s_norm = vector_norm(&s);
124        if s_norm / b_norm < config.tolerance {
125            x = &x + &p.mapv(|pi| pi * alpha);
126            return BiCgstabSolution {
127                x,
128                iterations: iter + 1,
129                residual: s_norm / b_norm,
130                converged: true,
131                status: SolverStatus::Converged,
132            };
133        }
134
135        // t = A * s
136        let t = operator.apply(&s);
137
138        // omega = (t, s) / (t, t)
139        let tt = inner_product(&t, &t);
140        if tt.norm() < T::Real::from_f64(1e-20).unwrap() {
141            return BiCgstabSolution {
142                x,
143                iterations: iter,
144                residual: vector_norm(&r) / b_norm,
145                converged: false,
146                status: SolverStatus::Breakdown,
147            };
148        }
149        omega = inner_product(&t, &s) / tt;
150
151        // x = x + alpha * p + omega * s
152        x = &x + &p.mapv(|pi| pi * alpha) + &s.mapv(|si| si * omega);
153
154        // r = s - omega * t
155        r = &s - &t.mapv(|ti| ti * omega);
156
157        let rel_residual = vector_norm(&r) / b_norm;
158
159        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
160            log::info!(
161                "BiCGSTAB iteration {}: relative residual = {:.6e}",
162                iter + 1,
163                rel_residual.to_f64().unwrap_or(0.0)
164            );
165        }
166
167        if rel_residual < config.tolerance {
168            return BiCgstabSolution {
169                x,
170                iterations: iter + 1,
171                residual: rel_residual,
172                converged: true,
173                status: SolverStatus::Converged,
174            };
175        }
176
177        // Check for stagnation
178        if omega.norm() < T::Real::from_f64(1e-20).unwrap() {
179            return BiCgstabSolution {
180                x,
181                iterations: iter + 1,
182                residual: rel_residual,
183                converged: false,
184                status: SolverStatus::Stagnated,
185            };
186        }
187    }
188
189    let rel_residual = vector_norm(&r) / b_norm;
190    BiCgstabSolution {
191        x,
192        iterations: config.max_iterations,
193        residual: rel_residual,
194        converged: false,
195        status: SolverStatus::MaxIterationsReached,
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Shorthand for a purely real complex number.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    #[test]
    fn test_bicgstab_simple() {
        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)]];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);

        assert!(solution.converged, "BiCGSTAB should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// BiCGSTAB targets non-symmetric systems, so exercise one. The matrix has a
    /// positive-definite symmetric part, and the exact solution is (0.22, 0.12, 0.96).
    #[test]
    fn test_bicgstab_nonsymmetric() {
        let dense = array![
            [re(4.0), re(1.0), re(0.0)],
            [re(2.0), re(5.0), re(1.0)],
            [re(0.0), re(1.0), re(3.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0), re(3.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);
        assert!(solution.converged, "BiCGSTAB should converge on a non-symmetric system");

        let expected = [0.22, 0.12, 0.96];
        for (xi, ei) in solution.x.iter().zip(expected.iter()) {
            assert!((xi - re(*ei)).norm_sqr().sqrt() < 1e-8, "x = {:?}", solution.x);
        }
    }

    /// A zero right-hand side must return the exact solution x = 0 immediately.
    #[test]
    fn test_bicgstab_zero_rhs() {
        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)]];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(0.0), re(0.0)];

        let config = BiCgstabConfig::default();
        let solution = bicgstab(&a, &b, &config);

        assert!(solution.converged);
        assert!(matches!(solution.status, SolverStatus::Converged));
        assert_eq!(solution.iterations, 0);
        assert!(solution.x.iter().all(|xi| *xi == re(0.0)));
    }
}