math_audio_solvers/iterative/
bicgstab.rs

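//! BiCGSTAB (Bi-Conjugate Gradient Stabilized) solver.
//!
//! BiCGSTAB is a Krylov subspace method for general (non-symmetric, non-Hermitian)
//! linear systems. It uses short recurrences, so its memory footprint stays fixed
//! per iteration, and for certain problem types it converges faster than GMRES.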
5
6use crate::traits::{ComplexField, LinearOperator};
7use ndarray::Array1;
8use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
9
/// Tuning parameters for [`bicgstab`].
///
/// `R` is the real scalar type of the tolerance; the solver takes a
/// `BiCgstabConfig<T::Real>` for a system over the field `T`.
#[derive(Debug, Clone)]
pub struct BiCgstabConfig<R> {
    /// Upper bound on the number of BiCGSTAB iterations performed.
    pub max_iterations: usize,
    /// Stop once the relative residual `‖r‖ / ‖b‖` (with `r` tracked by the
    /// BiCGSTAB recurrence) falls below this value.
    pub tolerance: R,
    /// Emit a progress log line every `print_interval` iterations; `0` disables it.
    pub print_interval: usize,
}
20
21impl Default for BiCgstabConfig<f64> {
22    fn default() -> Self {
23        Self {
24            max_iterations: 1000,
25            tolerance: 1e-6,
26            print_interval: 0,
27        }
28    }
29}
30
31/// BiCGSTAB solver result
32#[derive(Debug)]
33pub struct BiCgstabSolution<T: ComplexField> {
34    /// Solution vector
35    pub x: Array1<T>,
36    /// Number of iterations
37    pub iterations: usize,
38    /// Final relative residual
39    pub residual: T::Real,
40    /// Whether convergence was achieved
41    pub converged: bool,
42}
43
44/// Solve Ax = b using the BiCGSTAB method
45pub fn bicgstab<T, A>(
46    operator: &A,
47    b: &Array1<T>,
48    config: &BiCgstabConfig<T::Real>,
49) -> BiCgstabSolution<T>
50where
51    T: ComplexField,
52    A: LinearOperator<T>,
53{
54    let n = b.len();
55    let mut x = Array1::from_elem(n, T::zero());
56
57    let b_norm = vector_norm(b);
58    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
59    if b_norm < tol_threshold {
60        return BiCgstabSolution {
61            x,
62            iterations: 0,
63            residual: T::Real::zero(),
64            converged: true,
65        };
66    }
67
68    // Initial residual
69    let mut r = b.clone();
70    let r0 = r.clone(); // Shadow residual
71
72    let mut rho = T::one();
73    let mut alpha = T::one();
74    let mut omega = T::one();
75
76    let mut p = Array1::from_elem(n, T::zero());
77    let mut v = Array1::from_elem(n, T::zero());
78
79    for iter in 0..config.max_iterations {
80        let rho_new = inner_product(&r0, &r);
81
82        // Check for breakdown
83        if rho_new.norm() < T::Real::from_f64(1e-30).unwrap() {
84            return BiCgstabSolution {
85                x,
86                iterations: iter,
87                residual: vector_norm(&r) / b_norm,
88                converged: false,
89            };
90        }
91
92        let beta = (rho_new / rho) * (alpha / omega);
93        rho = rho_new;
94
95        // p = r + beta * (p - omega * v)
96        p = &r + &(&p - &v.mapv(|vi| vi * omega)).mapv(|pi| pi * beta);
97
98        // v = A * p
99        v = operator.apply(&p);
100
101        let r0v = inner_product(&r0, &v);
102        if r0v.norm() < T::Real::from_f64(1e-30).unwrap() {
103            return BiCgstabSolution {
104                x,
105                iterations: iter,
106                residual: vector_norm(&r) / b_norm,
107                converged: false,
108            };
109        }
110
111        alpha = rho / r0v;
112
113        // s = r - alpha * v
114        let s = &r - &v.mapv(|vi| vi * alpha);
115
116        // Check for early convergence
117        let s_norm = vector_norm(&s);
118        if s_norm / b_norm < config.tolerance {
119            x = &x + &p.mapv(|pi| pi * alpha);
120            return BiCgstabSolution {
121                x,
122                iterations: iter + 1,
123                residual: s_norm / b_norm,
124                converged: true,
125            };
126        }
127
128        // t = A * s
129        let t = operator.apply(&s);
130
131        // omega = (t, s) / (t, t)
132        let tt = inner_product(&t, &t);
133        if tt.norm() < T::Real::from_f64(1e-30).unwrap() {
134            return BiCgstabSolution {
135                x,
136                iterations: iter,
137                residual: vector_norm(&r) / b_norm,
138                converged: false,
139            };
140        }
141        omega = inner_product(&t, &s) / tt;
142
143        // x = x + alpha * p + omega * s
144        x = &x + &p.mapv(|pi| pi * alpha) + &s.mapv(|si| si * omega);
145
146        // r = s - omega * t
147        r = &s - &t.mapv(|ti| ti * omega);
148
149        let rel_residual = vector_norm(&r) / b_norm;
150
151        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
152            log::info!(
153                "BiCGSTAB iteration {}: relative residual = {:.6e}",
154                iter + 1,
155                rel_residual.to_f64().unwrap_or(0.0)
156            );
157        }
158
159        if rel_residual < config.tolerance {
160            return BiCgstabSolution {
161                x,
162                iterations: iter + 1,
163                residual: rel_residual,
164                converged: true,
165            };
166        }
167
168        // Check for stagnation
169        if omega.norm() < T::Real::from_f64(1e-30).unwrap() {
170            return BiCgstabSolution {
171                x,
172                iterations: iter + 1,
173                residual: rel_residual,
174                converged: false,
175            };
176        }
177    }
178
179    let rel_residual = vector_norm(&r) / b_norm;
180    BiCgstabSolution {
181        x,
182        iterations: config.max_iterations,
183        residual: rel_residual,
184        converged: false,
185    }
186}
187
188#[inline]
189fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
190    x.iter()
191        .zip(y.iter())
192        .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
193}
194
195#[inline]
196fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
197    x.iter()
198        .map(|xi| xi.norm_sqr())
199        .fold(T::Real::zero(), |acc, v| acc + v)
200        .sqrt()
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Quiet solver configuration with the given iteration budget and tolerance.
    fn config(max_iterations: usize, tolerance: f64) -> BiCgstabConfig<f64> {
        BiCgstabConfig {
            max_iterations,
            tolerance,
            print_interval: 0,
        }
    }

    #[test]
    fn test_bicgstab_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let solution = bicgstab(&a, &b, &config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_bicgstab_complex_nonsymmetric() {
        // Non-Hermitian complex matrix: exercises the conjugated inner product.
        let dense = array![
            [Complex64::new(2.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(0.5, 0.0), Complex64::new(3.0, -1.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 1.0)];

        let solution = bicgstab(&a, &b, &config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge on a complex system");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let solution = bicgstab(&a, &b, &config(100, 1e-10));

        assert!(solution.converged, "b = 0 is solved by x = 0");
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|xi| *xi == Complex64::new(0.0, 0.0)));
    }

    #[test]
    fn test_bicgstab_zero_iteration_budget() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let solution = bicgstab(&a, &b, &config(0, 1e-10));

        // No iterations: x stays at the zero initial guess, whose relative residual is 1.
        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!((solution.residual - 1.0).abs() < 1e-12);
        assert!(solution.x.iter().all(|xi| *xi == Complex64::new(0.0, 0.0)));
    }
}