math_audio_solvers/iterative/
bicgstab.rs1use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning knobs for [`bicgstab`].
///
/// `R` is the real scalar type used for the tolerance (for complex problems
/// this is the real type associated with the complex scalar).
#[derive(Clone, Debug)]
pub struct BiCgstabConfig<R> {
    /// Upper bound on the number of BiCGSTAB iterations.
    pub max_iterations: usize,
    /// Target relative residual `||b - A x|| / ||b||`.
    pub tolerance: R,
    /// Log progress every `print_interval` iterations; `0` disables logging.
    pub print_interval: usize,
}
21
22impl Default for BiCgstabConfig<f64> {
23 fn default() -> Self {
24 Self {
25 max_iterations: 1000,
26 tolerance: 1e-6,
27 print_interval: 0,
28 }
29 }
30}
31
32#[derive(Debug)]
34pub struct BiCgstabSolution<T: ComplexField> {
35 pub x: Array1<T>,
37 pub iterations: usize,
39 pub residual: T::Real,
41 pub converged: bool,
43}
44
45pub fn bicgstab<T, A>(
47 operator: &A,
48 b: &Array1<T>,
49 config: &BiCgstabConfig<T::Real>,
50) -> BiCgstabSolution<T>
51where
52 T: ComplexField,
53 A: LinearOperator<T>,
54{
55 let n = b.len();
56 let mut x = Array1::from_elem(n, T::zero());
57
58 let b_norm = vector_norm(b);
59 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
60 if b_norm < tol_threshold {
61 return BiCgstabSolution {
62 x,
63 iterations: 0,
64 residual: T::Real::zero(),
65 converged: true,
66 };
67 }
68
69 let mut r = b.clone();
71 let r0 = r.clone(); let mut rho = T::one();
74 let mut alpha = T::one();
75 let mut omega = T::one();
76
77 let mut p = Array1::from_elem(n, T::zero());
78 let mut v = Array1::from_elem(n, T::zero());
79
80 for iter in 0..config.max_iterations {
81 let rho_new = inner_product(&r0, &r);
82
83 if rho_new.norm() < T::Real::from_f64(1e-30).unwrap() {
85 return BiCgstabSolution {
86 x,
87 iterations: iter,
88 residual: vector_norm(&r) / b_norm,
89 converged: false,
90 };
91 }
92
93 let beta = (rho_new / rho) * (alpha / omega);
94 rho = rho_new;
95
96 p = &r + &(&p - &v.mapv(|vi| vi * omega)).mapv(|pi| pi * beta);
98
99 v = operator.apply(&p);
101
102 let r0v = inner_product(&r0, &v);
103 if r0v.norm() < T::Real::from_f64(1e-30).unwrap() {
104 return BiCgstabSolution {
105 x,
106 iterations: iter,
107 residual: vector_norm(&r) / b_norm,
108 converged: false,
109 };
110 }
111
112 alpha = rho / r0v;
113
114 let s = &r - &v.mapv(|vi| vi * alpha);
116
117 let s_norm = vector_norm(&s);
119 if s_norm / b_norm < config.tolerance {
120 x = &x + &p.mapv(|pi| pi * alpha);
121 return BiCgstabSolution {
122 x,
123 iterations: iter + 1,
124 residual: s_norm / b_norm,
125 converged: true,
126 };
127 }
128
129 let t = operator.apply(&s);
131
132 let tt = inner_product(&t, &t);
134 if tt.norm() < T::Real::from_f64(1e-30).unwrap() {
135 return BiCgstabSolution {
136 x,
137 iterations: iter,
138 residual: vector_norm(&r) / b_norm,
139 converged: false,
140 };
141 }
142 omega = inner_product(&t, &s) / tt;
143
144 x = &x + &p.mapv(|pi| pi * alpha) + &s.mapv(|si| si * omega);
146
147 r = &s - &t.mapv(|ti| ti * omega);
149
150 let rel_residual = vector_norm(&r) / b_norm;
151
152 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
153 log::info!(
154 "BiCGSTAB iteration {}: relative residual = {:.6e}",
155 iter + 1,
156 rel_residual.to_f64().unwrap_or(0.0)
157 );
158 }
159
160 if rel_residual < config.tolerance {
161 return BiCgstabSolution {
162 x,
163 iterations: iter + 1,
164 residual: rel_residual,
165 converged: true,
166 };
167 }
168
169 if omega.norm() < T::Real::from_f64(1e-30).unwrap() {
171 return BiCgstabSolution {
172 x,
173 iterations: iter + 1,
174 residual: rel_residual,
175 converged: false,
176 };
177 }
178 }
179
180 let rel_residual = vector_norm(&r) / b_norm;
181 BiCgstabSolution {
182 x,
183 iterations: config.max_iterations,
184 residual: rel_residual,
185 converged: false,
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Dense 2x2 symmetric positive-definite matrix [[4, 1], [1, 3]] used by all tests.
    fn spd_2x2() -> ndarray::Array2<Complex64> {
        array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ]
    }

    #[test]
    fn test_bicgstab_simple() {
        let a = CsrMatrix::from_dense(&spd_2x2(), 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = bicgstab(&a, &b, &config);

        assert!(solution.converged, "BiCGSTAB should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// A zero right-hand side must short-circuit to x = 0 without iterating.
    #[test]
    fn test_bicgstab_zero_rhs() {
        let a = CsrMatrix::from_dense(&spd_2x2(), 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let solution = bicgstab(&a, &b, &BiCgstabConfig::default());

        assert!(solution.converged, "zero right-hand side is trivially solved");
        assert_eq!(solution.iterations, 0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }

    /// With no iterations allowed the initial guess x = 0 is returned, whose
    /// relative residual is exactly ||b|| / ||b|| = 1.
    #[test]
    fn test_bicgstab_zero_iterations() {
        let a = CsrMatrix::from_dense(&spd_2x2(), 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = BiCgstabConfig {
            max_iterations: 0,
            ..Default::default()
        };
        let solution = bicgstab(&a, &b, &config);

        assert!(!solution.converged, "no iterations cannot reach the tolerance");
        assert_eq!(solution.iterations, 0);
        assert!((solution.residual - 1.0).abs() < 1e-12);
    }
}