math_audio_solvers/iterative/
bicgstab.rs1use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator, SolverStatus};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning knobs for [`bicgstab`].
#[derive(Clone, Debug)]
pub struct BiCgstabConfig<R> {
    /// Upper bound on the number of BiCGSTAB iterations before giving up.
    pub max_iterations: usize,
    /// Convergence threshold on the relative residual `‖r‖ / ‖b‖`; the solver
    /// stops once the residual is strictly below this value.
    pub tolerance: R,
    /// Log progress via `log::info!` every `print_interval` iterations;
    /// `0` turns progress logging off.
    pub print_interval: usize,
}
21
22impl Default for BiCgstabConfig<f64> {
23 fn default() -> Self {
24 Self {
25 max_iterations: 1000,
26 tolerance: 1e-6,
27 print_interval: 0,
28 }
29 }
30}
31
32#[derive(Debug)]
34pub struct BiCgstabSolution<T: ComplexField> {
35 pub x: Array1<T>,
37 pub iterations: usize,
39 pub residual: T::Real,
41 pub converged: bool,
43 pub status: SolverStatus,
45}
46
47pub fn bicgstab<T, A>(
49 operator: &A,
50 b: &Array1<T>,
51 config: &BiCgstabConfig<T::Real>,
52) -> BiCgstabSolution<T>
53where
54 T: ComplexField,
55 A: LinearOperator<T>,
56{
57 let n = b.len();
58 let mut x = Array1::from_elem(n, T::zero());
59
60 let b_norm = vector_norm(b);
61 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
62 if b_norm < tol_threshold {
63 return BiCgstabSolution {
64 x,
65 iterations: 0,
66 residual: T::Real::zero(),
67 converged: true,
68 status: SolverStatus::Converged,
69 };
70 }
71
72 let mut r = b.clone();
74 let r0 = r.clone(); let mut rho = T::one();
77 let mut alpha = T::one();
78 let mut omega = T::one();
79
80 let mut p = Array1::from_elem(n, T::zero());
81 let mut v = Array1::from_elem(n, T::zero());
82
83 for iter in 0..config.max_iterations {
84 let rho_new = inner_product(&r0, &r);
85
86 if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
88 return BiCgstabSolution {
89 x,
90 iterations: iter,
91 residual: vector_norm(&r) / b_norm,
92 converged: false,
93 status: SolverStatus::Breakdown,
94 };
95 }
96
97 let beta = (rho_new / rho) * (alpha / omega);
98 rho = rho_new;
99
100 p = &r + &(&p - &v.mapv(|vi| vi * omega)).mapv(|pi| pi * beta);
102
103 v = operator.apply(&p);
105
106 let r0v = inner_product(&r0, &v);
107 if r0v.norm() < T::Real::from_f64(1e-20).unwrap() {
108 return BiCgstabSolution {
109 x,
110 iterations: iter,
111 residual: vector_norm(&r) / b_norm,
112 converged: false,
113 status: SolverStatus::Breakdown,
114 };
115 }
116
117 alpha = rho / r0v;
118
119 let s = &r - &v.mapv(|vi| vi * alpha);
121
122 let s_norm = vector_norm(&s);
124 if s_norm / b_norm < config.tolerance {
125 x = &x + &p.mapv(|pi| pi * alpha);
126 return BiCgstabSolution {
127 x,
128 iterations: iter + 1,
129 residual: s_norm / b_norm,
130 converged: true,
131 status: SolverStatus::Converged,
132 };
133 }
134
135 let t = operator.apply(&s);
137
138 let tt = inner_product(&t, &t);
140 if tt.norm() < T::Real::from_f64(1e-20).unwrap() {
141 return BiCgstabSolution {
142 x,
143 iterations: iter,
144 residual: vector_norm(&r) / b_norm,
145 converged: false,
146 status: SolverStatus::Breakdown,
147 };
148 }
149 omega = inner_product(&t, &s) / tt;
150
151 x = &x + &p.mapv(|pi| pi * alpha) + &s.mapv(|si| si * omega);
153
154 r = &s - &t.mapv(|ti| ti * omega);
156
157 let rel_residual = vector_norm(&r) / b_norm;
158
159 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
160 log::info!(
161 "BiCGSTAB iteration {}: relative residual = {:.6e}",
162 iter + 1,
163 rel_residual.to_f64().unwrap_or(0.0)
164 );
165 }
166
167 if rel_residual < config.tolerance {
168 return BiCgstabSolution {
169 x,
170 iterations: iter + 1,
171 residual: rel_residual,
172 converged: true,
173 status: SolverStatus::Converged,
174 };
175 }
176
177 if omega.norm() < T::Real::from_f64(1e-20).unwrap() {
179 return BiCgstabSolution {
180 x,
181 iterations: iter + 1,
182 residual: rel_residual,
183 converged: false,
184 status: SolverStatus::Stagnated,
185 };
186 }
187 }
188
189 let rel_residual = vector_norm(&r) / b_norm;
190 BiCgstabSolution {
191 x,
192 iterations: config.max_iterations,
193 residual: rel_residual,
194 converged: false,
195 status: SolverStatus::MaxIterationsReached,
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// A small symmetric positive-definite system must be solved to tolerance.
    #[test]
    fn test_bicgstab_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);

        assert!(solution.converged, "BiCGSTAB should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// A zero right-hand side must return x = 0 immediately, reported as
    /// converged with zero residual and no iterations.
    #[test]
    fn test_bicgstab_zero_rhs() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let solution = bicgstab(&a, &b, &BiCgstabConfig::default());

        assert!(solution.converged, "zero RHS must be reported as converged");
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(
            solution.x.iter().all(|xi| xi.norm() == 0.0),
            "zero RHS must yield x = 0"
        );
    }
}