1use crate::traits::{ComplexField, LinearOperator};
7use ndarray::Array1;
8use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
9
/// Tuning parameters for [`bicgstab`].
///
/// `R` is the real scalar type used for the tolerance (e.g. `f64`).
#[derive(Clone, Debug)]
pub struct BiCgstabConfig<R> {
    /// Maximum number of BiCGSTAB iterations before giving up.
    pub max_iterations: usize,
    /// Target relative residual `‖r‖ / ‖b‖`; iteration stops once the
    /// residual is strictly below this value.
    pub tolerance: R,
    /// Emit a `log::info!` progress line every `print_interval` iterations;
    /// `0` disables progress logging.
    pub print_interval: usize,
}
20
21impl Default for BiCgstabConfig<f64> {
22 fn default() -> Self {
23 Self {
24 max_iterations: 1000,
25 tolerance: 1e-6,
26 print_interval: 0,
27 }
28 }
29}
30
31#[derive(Debug)]
33pub struct BiCgstabSolution<T: ComplexField> {
34 pub x: Array1<T>,
36 pub iterations: usize,
38 pub residual: T::Real,
40 pub converged: bool,
42}
43
44pub fn bicgstab<T, A>(
46 operator: &A,
47 b: &Array1<T>,
48 config: &BiCgstabConfig<T::Real>,
49) -> BiCgstabSolution<T>
50where
51 T: ComplexField,
52 A: LinearOperator<T>,
53{
54 let n = b.len();
55 let mut x = Array1::from_elem(n, T::zero());
56
57 let b_norm = vector_norm(b);
58 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
59 if b_norm < tol_threshold {
60 return BiCgstabSolution {
61 x,
62 iterations: 0,
63 residual: T::Real::zero(),
64 converged: true,
65 };
66 }
67
68 let mut r = b.clone();
70 let r0 = r.clone(); let mut rho = T::one();
73 let mut alpha = T::one();
74 let mut omega = T::one();
75
76 let mut p = Array1::from_elem(n, T::zero());
77 let mut v = Array1::from_elem(n, T::zero());
78
79 for iter in 0..config.max_iterations {
80 let rho_new = inner_product(&r0, &r);
81
82 if rho_new.norm() < T::Real::from_f64(1e-30).unwrap() {
84 return BiCgstabSolution {
85 x,
86 iterations: iter,
87 residual: vector_norm(&r) / b_norm,
88 converged: false,
89 };
90 }
91
92 let beta = (rho_new / rho) * (alpha / omega);
93 rho = rho_new;
94
95 p = &r + &(&p - &v.mapv(|vi| vi * omega)).mapv(|pi| pi * beta);
97
98 v = operator.apply(&p);
100
101 let r0v = inner_product(&r0, &v);
102 if r0v.norm() < T::Real::from_f64(1e-30).unwrap() {
103 return BiCgstabSolution {
104 x,
105 iterations: iter,
106 residual: vector_norm(&r) / b_norm,
107 converged: false,
108 };
109 }
110
111 alpha = rho / r0v;
112
113 let s = &r - &v.mapv(|vi| vi * alpha);
115
116 let s_norm = vector_norm(&s);
118 if s_norm / b_norm < config.tolerance {
119 x = &x + &p.mapv(|pi| pi * alpha);
120 return BiCgstabSolution {
121 x,
122 iterations: iter + 1,
123 residual: s_norm / b_norm,
124 converged: true,
125 };
126 }
127
128 let t = operator.apply(&s);
130
131 let tt = inner_product(&t, &t);
133 if tt.norm() < T::Real::from_f64(1e-30).unwrap() {
134 return BiCgstabSolution {
135 x,
136 iterations: iter,
137 residual: vector_norm(&r) / b_norm,
138 converged: false,
139 };
140 }
141 omega = inner_product(&t, &s) / tt;
142
143 x = &x + &p.mapv(|pi| pi * alpha) + &s.mapv(|si| si * omega);
145
146 r = &s - &t.mapv(|ti| ti * omega);
148
149 let rel_residual = vector_norm(&r) / b_norm;
150
151 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
152 log::info!(
153 "BiCGSTAB iteration {}: relative residual = {:.6e}",
154 iter + 1,
155 rel_residual.to_f64().unwrap_or(0.0)
156 );
157 }
158
159 if rel_residual < config.tolerance {
160 return BiCgstabSolution {
161 x,
162 iterations: iter + 1,
163 residual: rel_residual,
164 converged: true,
165 };
166 }
167
168 if omega.norm() < T::Real::from_f64(1e-30).unwrap() {
170 return BiCgstabSolution {
171 x,
172 iterations: iter + 1,
173 residual: rel_residual,
174 converged: false,
175 };
176 }
177 }
178
179 let rel_residual = vector_norm(&r) / b_norm;
180 BiCgstabSolution {
181 x,
182 iterations: config.max_iterations,
183 residual: rel_residual,
184 converged: false,
185 }
186}
187
188#[inline]
189fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
190 x.iter()
191 .zip(y.iter())
192 .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
193}
194
195#[inline]
196fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
197 x.iter()
198 .map(|xi| xi.norm_sqr())
199 .fold(T::Real::zero(), |acc, v| acc + v)
200 .sqrt()
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Solves a small symmetric positive-definite 2×2 system and checks that
    /// the returned `x` satisfies `A x ≈ b`.
    #[test]
    fn test_bicgstab_simple() {
        // Shorthand for a purely real complex entry.
        let re = |value: f64| Complex64::new(value, 0.0);

        let matrix = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)]];
        let a = CsrMatrix::from_dense(&matrix, 1e-15);
        let rhs = array![re(1.0), re(2.0)];

        let config = BiCgstabConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };
        let solution = bicgstab(&a, &rhs, &config);

        assert!(solution.converged, "BiCGSTAB should converge");

        let product = a.matvec(&solution.x);
        let difference = &product - &rhs;
        let error = difference
            .iter()
            .map(|entry| entry.norm_sqr())
            .sum::<f64>()
            .sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }
}