math_audio_solvers/iterative/
bicgstab.rs

1//! BiCGSTAB (Bi-Conjugate Gradient Stabilized) solver
2//!
3//! BiCGSTAB is a Krylov subspace method for non-symmetric systems.
4//! It often converges faster than GMRES for certain problem types.
5
6use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning parameters for the BiCGSTAB solver.
///
/// `bicgstab` takes a `BiCgstabConfig<T::Real>`, i.e. `R` is the real scalar
/// type of the system's field.
#[derive(Debug, Clone)]
pub struct BiCgstabConfig<R> {
    /// Upper bound on the number of BiCGSTAB iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the relative residual `||r|| / ||b||`.
    pub tolerance: R,
    /// Log progress every `print_interval` iterations; `0` disables logging.
    pub print_interval: usize,
}

impl Default for BiCgstabConfig<f64> {
    /// 1000 iterations, relative tolerance `1e-6`, no progress logging.
    fn default() -> Self {
        let max_iterations = 1_000;
        let tolerance = 1.0e-6;
        let print_interval = 0;
        BiCgstabConfig {
            max_iterations,
            tolerance,
            print_interval,
        }
    }
}
31
32/// BiCGSTAB solver result
33#[derive(Debug)]
34pub struct BiCgstabSolution<T: ComplexField> {
35    /// Solution vector
36    pub x: Array1<T>,
37    /// Number of iterations
38    pub iterations: usize,
39    /// Final relative residual
40    pub residual: T::Real,
41    /// Whether convergence was achieved
42    pub converged: bool,
43}
44
45/// Solve Ax = b using the BiCGSTAB method
46pub fn bicgstab<T, A>(
47    operator: &A,
48    b: &Array1<T>,
49    config: &BiCgstabConfig<T::Real>,
50) -> BiCgstabSolution<T>
51where
52    T: ComplexField,
53    A: LinearOperator<T>,
54{
55    let n = b.len();
56    let mut x = Array1::from_elem(n, T::zero());
57
58    let b_norm = vector_norm(b);
59    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
60    if b_norm < tol_threshold {
61        return BiCgstabSolution {
62            x,
63            iterations: 0,
64            residual: T::Real::zero(),
65            converged: true,
66        };
67    }
68
69    // Initial residual
70    let mut r = b.clone();
71    let r0 = r.clone(); // Shadow residual
72
73    let mut rho = T::one();
74    let mut alpha = T::one();
75    let mut omega = T::one();
76
77    let mut p = Array1::from_elem(n, T::zero());
78    let mut v = Array1::from_elem(n, T::zero());
79
80    for iter in 0..config.max_iterations {
81        let rho_new = inner_product(&r0, &r);
82
83        // Check for breakdown
84        if rho_new.norm() < T::Real::from_f64(1e-30).unwrap() {
85            return BiCgstabSolution {
86                x,
87                iterations: iter,
88                residual: vector_norm(&r) / b_norm,
89                converged: false,
90            };
91        }
92
93        let beta = (rho_new / rho) * (alpha / omega);
94        rho = rho_new;
95
96        // p = r + beta * (p - omega * v)
97        p = &r + &(&p - &v.mapv(|vi| vi * omega)).mapv(|pi| pi * beta);
98
99        // v = A * p
100        v = operator.apply(&p);
101
102        let r0v = inner_product(&r0, &v);
103        if r0v.norm() < T::Real::from_f64(1e-30).unwrap() {
104            return BiCgstabSolution {
105                x,
106                iterations: iter,
107                residual: vector_norm(&r) / b_norm,
108                converged: false,
109            };
110        }
111
112        alpha = rho / r0v;
113
114        // s = r - alpha * v
115        let s = &r - &v.mapv(|vi| vi * alpha);
116
117        // Check for early convergence
118        let s_norm = vector_norm(&s);
119        if s_norm / b_norm < config.tolerance {
120            x = &x + &p.mapv(|pi| pi * alpha);
121            return BiCgstabSolution {
122                x,
123                iterations: iter + 1,
124                residual: s_norm / b_norm,
125                converged: true,
126            };
127        }
128
129        // t = A * s
130        let t = operator.apply(&s);
131
132        // omega = (t, s) / (t, t)
133        let tt = inner_product(&t, &t);
134        if tt.norm() < T::Real::from_f64(1e-30).unwrap() {
135            return BiCgstabSolution {
136                x,
137                iterations: iter,
138                residual: vector_norm(&r) / b_norm,
139                converged: false,
140            };
141        }
142        omega = inner_product(&t, &s) / tt;
143
144        // x = x + alpha * p + omega * s
145        x = &x + &p.mapv(|pi| pi * alpha) + &s.mapv(|si| si * omega);
146
147        // r = s - omega * t
148        r = &s - &t.mapv(|ti| ti * omega);
149
150        let rel_residual = vector_norm(&r) / b_norm;
151
152        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
153            log::info!(
154                "BiCGSTAB iteration {}: relative residual = {:.6e}",
155                iter + 1,
156                rel_residual.to_f64().unwrap_or(0.0)
157            );
158        }
159
160        if rel_residual < config.tolerance {
161            return BiCgstabSolution {
162                x,
163                iterations: iter + 1,
164                residual: rel_residual,
165                converged: true,
166            };
167        }
168
169        // Check for stagnation
170        if omega.norm() < T::Real::from_f64(1e-30).unwrap() {
171            return BiCgstabSolution {
172                x,
173                iterations: iter + 1,
174                residual: rel_residual,
175                converged: false,
176            };
177        }
178    }
179
180    let rel_residual = vector_norm(&r) / b_norm;
181    BiCgstabSolution {
182        x,
183        iterations: config.max_iterations,
184        residual: rel_residual,
185        converged: false,
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::{array, Array1};
    use num_complex::Complex64;

    /// Euclidean norm of `ax - b`, i.e. the true (not recurrence-updated) residual.
    fn residual_norm(ax: &Array1<Complex64>, b: &Array1<Complex64>) -> f64 {
        (ax - b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    #[test]
    fn test_bicgstab_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);

        assert!(solution.converged, "BiCGSTAB should converge");

        let ax = a.matvec(&solution.x);
        assert!(
            residual_norm(&ax, &b) < 1e-8,
            "Solution should satisfy Ax = b"
        );
    }

    #[test]
    fn test_bicgstab_nonsymmetric_complex() {
        // Non-symmetric, non-Hermitian system: the case BiCGSTAB is meant for.
        let dense = array![
            [Complex64::new(3.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(-2.0, 0.0), Complex64::new(4.0, -1.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 1.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);

        assert!(solution.converged, "BiCGSTAB should converge");

        let ax = a.matvec(&solution.x);
        assert!(
            residual_norm(&ax, &b) < 1e-8,
            "Solution should satisfy Ax = b"
        );
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = Array1::from_elem(2, Complex64::new(0.0, 0.0));

        let solution = bicgstab(&a, &b, &BiCgstabConfig::default());

        assert!(solution.converged, "Zero RHS is trivially solved");
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }

    #[test]
    fn test_bicgstab_zero_iterations() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = BiCgstabConfig {
            max_iterations: 0,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);

        // No iterations: x stays at the zero initial guess, so ||r|| / ||b|| == 1.
        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 1.0);
    }
}