1use ndarray::Array1;
14use num_complex::Complex64;
15
/// Tuning parameters shared by [`bicgstab_solve`] and [`bicgstab_solve_preconditioned`].
#[derive(Clone, Debug)]
pub struct BiCgstabConfig {
    /// Maximum number of BiCGSTAB iterations before the solver gives up.
    pub max_iterations: usize,
    /// Convergence threshold on the residual norm divided by the initial residual norm
    /// (the preconditioned residual for the preconditioned solver).
    pub tolerance: f64,
    /// A progress line is written to stderr on every iteration index that is a multiple
    /// of this value (starting at index 0); `0` turns progress output off.
    pub print_interval: usize,
}
26
27impl Default for BiCgstabConfig {
28 fn default() -> Self {
29 Self {
30 max_iterations: 1000,
31 tolerance: 1e-6,
32 print_interval: 10,
33 }
34 }
35}
36
37#[derive(Debug)]
39pub struct BiCgstabSolution {
40 pub x: Array1<Complex64>,
42 pub iterations: usize,
44 pub residual: f64,
46 pub converged: bool,
48}
49
50pub fn bicgstab_solve<F>(
68 matvec: F,
69 b: &Array1<Complex64>,
70 x0: Option<&Array1<Complex64>>,
71 config: &BiCgstabConfig,
72) -> BiCgstabSolution
73where
74 F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
75{
76 let n = b.len();
77
78 let mut x = match x0 {
80 Some(x0) => x0.clone(),
81 None => Array1::zeros(n),
82 };
83
84 let ax = matvec(&x);
86 let mut r: Array1<Complex64> = b - &ax;
87
88 let r_tilde = r.clone();
90
91 let err_ori = residual_norm(&r);
93 if err_ori < 1e-15 {
94 return BiCgstabSolution {
95 x,
96 iterations: 0,
97 residual: 0.0,
98 converged: true,
99 };
100 }
101
102 let mut rho = Complex64::new(1.0, 0.0);
104 let mut alpha = Complex64::new(1.0, 0.0);
105 let mut omega = Complex64::new(1.0, 0.0);
106
107 let mut v: Array1<Complex64> = Array1::zeros(n);
109 let mut p: Array1<Complex64> = Array1::zeros(n);
110
111 let mut iterations = 0;
112 let mut err_rel = 1.0;
113
114 for j in 0..config.max_iterations {
115 iterations = j + 1;
116
117 let rho_new = inner_product(&r_tilde, &r);
119
120 if rho_new.norm() < 1e-30 {
122 return BiCgstabSolution {
123 x,
124 iterations,
125 residual: err_rel,
126 converged: false,
127 };
128 }
129
130 let beta = (rho_new / rho) * (alpha / omega);
132
133 let p_minus_omega_v: Array1<Complex64> = &p - &(&v * omega);
135 p = &r + &(&p_minus_omega_v * beta);
136
137 v = matvec(&p);
139
140 let r_tilde_v = inner_product(&r_tilde, &v);
142 if r_tilde_v.norm() < 1e-30 {
143 return BiCgstabSolution {
144 x,
145 iterations,
146 residual: err_rel,
147 converged: false,
148 };
149 }
150 alpha = rho_new / r_tilde_v;
151
152 let s: Array1<Complex64> = &r - &(&v * alpha);
154
155 let s_norm = residual_norm(&s);
157 if s_norm / err_ori < config.tolerance {
158 x = &x + &(&p * alpha);
160 return BiCgstabSolution {
161 x,
162 iterations,
163 residual: s_norm / err_ori,
164 converged: true,
165 };
166 }
167
168 let t = matvec(&s);
170
171 let t_s = inner_product(&t, &s);
173 let t_t = inner_product(&t, &t);
174 if t_t.norm() < 1e-30 {
175 return BiCgstabSolution {
176 x,
177 iterations,
178 residual: err_rel,
179 converged: false,
180 };
181 }
182 omega = t_s / t_t;
183
184 x = &x + &(&p * alpha) + &(&s * omega);
186
187 r = &s - &(&t * omega);
189
190 let r_norm = residual_norm(&r);
192 err_rel = r_norm / err_ori;
193
194 if config.print_interval > 0 && j % config.print_interval == 0 {
196 eprintln!("BiCGSTAB iteration {}: relative residual = {:.6e}", j, err_rel);
197 }
198
199 if err_rel < config.tolerance {
201 return BiCgstabSolution {
202 x,
203 iterations,
204 residual: err_rel,
205 converged: true,
206 };
207 }
208
209 if omega.norm() < 1e-30 {
211 return BiCgstabSolution {
212 x,
213 iterations,
214 residual: err_rel,
215 converged: false,
216 };
217 }
218
219 rho = rho_new;
220 }
221
222 BiCgstabSolution {
223 x,
224 iterations,
225 residual: err_rel,
226 converged: false,
227 }
228}
229
230fn inner_product(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
232 x.iter()
233 .zip(y.iter())
234 .map(|(xi, yi)| xi.conj() * yi)
235 .sum()
236}
237
238fn residual_norm(r: &Array1<Complex64>) -> f64 {
240 r.iter().map(|ri| ri.norm_sqr()).sum::<f64>().sqrt()
241}
242
243pub fn bicgstab_solve_preconditioned<F, P>(
247 matvec: F,
248 precond_solve: P,
249 b: &Array1<Complex64>,
250 x0: Option<&Array1<Complex64>>,
251 config: &BiCgstabConfig,
252) -> BiCgstabSolution
253where
254 F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
255 P: Fn(&Array1<Complex64>) -> Array1<Complex64>,
256{
257 let n = b.len();
258
259 let mut x = match x0 {
261 Some(x0) => x0.clone(),
262 None => Array1::zeros(n),
263 };
264
265 let ax = matvec(&x);
267 let r0: Array1<Complex64> = b - &ax;
268
269 let mut r = precond_solve(&r0);
271 let r_tilde = r.clone();
272
273 let err_ori = residual_norm(&r);
274 if err_ori < 1e-15 {
275 return BiCgstabSolution {
276 x,
277 iterations: 0,
278 residual: 0.0,
279 converged: true,
280 };
281 }
282
283 let mut rho = Complex64::new(1.0, 0.0);
284 let mut alpha = Complex64::new(1.0, 0.0);
285 let mut omega = Complex64::new(1.0, 0.0);
286
287 let mut v: Array1<Complex64> = Array1::zeros(n);
288 let mut p: Array1<Complex64> = Array1::zeros(n);
289
290 let mut iterations = 0;
291 let mut err_rel = 1.0;
292
293 for j in 0..config.max_iterations {
294 iterations = j + 1;
295
296 let rho_new = inner_product(&r_tilde, &r);
297 if rho_new.norm() < 1e-30 {
298 return BiCgstabSolution {
299 x,
300 iterations,
301 residual: err_rel,
302 converged: false,
303 };
304 }
305
306 let beta = (rho_new / rho) * (alpha / omega);
307 let p_minus_omega_v: Array1<Complex64> = &p - &(&v * omega);
308 p = &r + &(&p_minus_omega_v * beta);
309
310 let ap = matvec(&p);
312 v = precond_solve(&ap);
313
314 let r_tilde_v = inner_product(&r_tilde, &v);
315 if r_tilde_v.norm() < 1e-30 {
316 return BiCgstabSolution {
317 x,
318 iterations,
319 residual: err_rel,
320 converged: false,
321 };
322 }
323 alpha = rho_new / r_tilde_v;
324
325 let s: Array1<Complex64> = &r - &(&v * alpha);
326 let s_norm = residual_norm(&s);
327 if s_norm / err_ori < config.tolerance {
328 x = &x + &(&p * alpha);
329 return BiCgstabSolution {
330 x,
331 iterations,
332 residual: s_norm / err_ori,
333 converged: true,
334 };
335 }
336
337 let as_ = matvec(&s);
339 let t = precond_solve(&as_);
340
341 let t_s = inner_product(&t, &s);
342 let t_t = inner_product(&t, &t);
343 if t_t.norm() < 1e-30 {
344 return BiCgstabSolution {
345 x,
346 iterations,
347 residual: err_rel,
348 converged: false,
349 };
350 }
351 omega = t_s / t_t;
352
353 x = &x + &(&p * alpha) + &(&s * omega);
354 r = &s - &(&t * omega);
355
356 let r_norm = residual_norm(&r);
357 err_rel = r_norm / err_ori;
358
359 if config.print_interval > 0 && j % config.print_interval == 0 {
360 eprintln!(
361 "BiCGSTAB (precond) iteration {}: relative residual = {:.6e}",
362 j, err_rel
363 );
364 }
365
366 if err_rel < config.tolerance {
367 return BiCgstabSolution {
368 x,
369 iterations,
370 residual: err_rel,
371 converged: true,
372 };
373 }
374
375 if omega.norm() < 1e-30 {
376 return BiCgstabSolution {
377 x,
378 iterations,
379 residual: err_rel,
380 converged: false,
381 };
382 }
383
384 rho = rho_new;
385 }
386
387 BiCgstabSolution {
388 x,
389 iterations,
390 residual: err_rel,
391 converged: false,
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Test configuration with progress printing disabled.
    fn quiet_config(max_iterations: usize, tolerance: f64) -> BiCgstabConfig {
        BiCgstabConfig {
            max_iterations,
            tolerance,
            print_interval: 0,
        }
    }

    /// Euclidean norm of `a·x − b`.
    fn residual_error(a: &Array2<Complex64>, x: &Array1<Complex64>, b: &Array1<Complex64>) -> f64 {
        (&a.dot(x) - b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    #[test]
    fn test_bicgstab_simple() {
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(4.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(3.0, 0.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge");
        let error = residual_error(&a, &solution.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_bicgstab_complex() {
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(2.0, 1.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(0.0, 1.0),
                Complex64::new(2.0, -1.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![Complex64::new(1.0, 1.0), Complex64::new(1.0, -1.0)]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge for complex system");
        let error = residual_error(&a, &solution.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_bicgstab_identity() {
        let n = 5;
        let b = Array1::from_vec(
            (1..=n)
                .map(|i| Complex64::new(i as f64, 0.0))
                .collect::<Vec<_>>(),
        );
        let matvec = |x: &Array1<Complex64>| x.clone();

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(10, 1e-12));

        assert!(solution.converged);
        assert!(solution.iterations <= 2);
        // With A = I the solution must equal b itself.
        let error: f64 = (&solution.x - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_bicgstab_non_symmetric() {
        let a = Array2::from_shape_vec(
            (3, 3),
            vec![
                Complex64::new(4.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(5.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(3.0, 0.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![
            Complex64::new(5.0, 0.0),
            Complex64::new(8.0, 0.0),
            Complex64::new(4.0, 0.0),
        ]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "BiCGSTAB should converge for non-symmetric system");
        let error = residual_error(&a, &solution.x, &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// Symmetric-but-not-Hermitian complex tridiagonal system. The name is historical;
    /// no CGS comparison is performed here.
    #[test]
    fn test_bicgstab_vs_cgs_stability() {
        let a = Array2::from_shape_vec(
            (3, 3),
            vec![
                Complex64::new(1.0, 0.1),
                Complex64::new(0.5, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.5, 0.0),
                Complex64::new(1.0, -0.1),
                Complex64::new(0.5, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.5, 0.0),
                Complex64::new(1.0, 0.1),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
        ]);
        let matvec = |x: &Array1<Complex64>| a.dot(x);

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged);
        assert!(residual_error(&a, &solution.x, &b) < 1e-8);
    }

    /// The preconditioned entry point with a Jacobi (diagonal) preconditioner.
    /// The exact solution of this system is [1, 1, 1].
    #[test]
    fn test_bicgstab_preconditioned_jacobi() {
        let a = Array2::from_shape_vec(
            (3, 3),
            vec![
                Complex64::new(4.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0),
                Complex64::new(5.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(1.0, 0.0),
                Complex64::new(3.0, 0.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![
            Complex64::new(5.0, 0.0),
            Complex64::new(8.0, 0.0),
            Complex64::new(4.0, 0.0),
        ]);
        let diag: Array1<Complex64> = a.diag().to_owned();
        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let jacobi = |r: &Array1<Complex64>| r / &diag;

        let solution =
            bicgstab_solve_preconditioned(&matvec, &jacobi, &b, None, &quiet_config(100, 1e-10));

        assert!(solution.converged, "Preconditioned BiCGSTAB should converge");
        assert!(residual_error(&a, &solution.x, &b) < 1e-8);
        for xi in solution.x.iter() {
            assert!((*xi - Complex64::new(1.0, 0.0)).norm() < 1e-8);
        }
    }

    /// A zero right-hand side with the default zero guess is already solved.
    #[test]
    fn test_bicgstab_zero_rhs() {
        let b: Array1<Complex64> = Array1::zeros(4);
        let matvec = |x: &Array1<Complex64>| x.clone();

        let solution = bicgstab_solve(&matvec, &b, None, &quiet_config(10, 1e-10));

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }
}