use crate::indexing::SpIndex;
use crate::sparse::{CsMatViewI, CsVecI, CsVecViewI};
use num_traits::One;
/// State of a BiCGSTAB (biconjugate gradient stabilized) iteration for the
/// sparse linear system `a * x = b`.
///
/// The system matrix and right-hand side are borrowed for `'a`; the current
/// iterate and all work vectors are owned by the solver.
#[derive(Debug)]
pub struct BiCGSTAB<'a, T, I: SpIndex, Iptr: SpIndex> {
    /// Number of calls to `step` so far.
    iteration_count: usize,
    /// A soft restart is triggered in `step` when `|rho| / err^2` falls
    /// below this value.
    soft_restart_threshold: T,
    /// Number of soft restarts; `hard_restart` cancels out the increment
    /// made by the soft restart it performs internally.
    soft_restart_count: usize,
    /// Number of calls to `hard_restart`.
    hard_restart_count: usize,
    /// L2 norm of the residual `r`.
    err: T,
    /// System matrix (borrowed view).
    a: CsMatViewI<'a, T, I, Iptr>,
    /// Right-hand side (borrowed view).
    b: CsVecViewI<'a, T, I>,
    /// Current estimate of the solution.
    x: CsVecI<T, I>,
    /// Residual: `b - a * x` at construction and after a hard restart,
    /// otherwise updated by recurrence in `step`.
    r: CsVecI<T, I>,
    /// Copy of `r` taken at construction or at the latest restart; used as
    /// the fixed vector in the `rho` and `alpha` dot products.
    rhat: CsVecI<T, I>,
    /// Direction along which `step` first moves `x`.
    p: CsVecI<T, I>,
    /// `dot(rhat, r)`; stored as `err * err` whenever `rhat` has just been
    /// reset to `r`.
    rho: T,
}
// Generates the inherent `impl` of `BiCGSTAB` for one concrete floating-point
// scalar type. A macro is used instead of a generic impl so the body can use
// the scalar's inherent float methods (e.g. `abs`) and float literals.
macro_rules! bicgstab_impl {
    ($T: ty) => {
        impl<'a, I: SpIndex, Iptr: SpIndex> BiCGSTAB<'a, $T, I, Iptr> {
            /// Creates a solver for `a * x = b` starting from the guess `x0`.
            ///
            /// The residual `r = b - a * x0` is computed directly, and both
            /// `rhat` and `p` start out as copies of it. The soft-restart
            /// threshold defaults to `0.1`; see `with_restart_threshold`.
            pub fn new(
                a: CsMatViewI<'a, $T, I, Iptr>,
                x0: CsVecViewI<'a, $T, I>,
                b: CsVecViewI<'a, $T, I>,
            ) -> Self {
                let r = &b - &(&a.view() * &x0.view()).view();
                let rhat = r.to_owned();
                let p = r.to_owned();
                let err = (&r).l2_norm();
                // `rhat` equals `r` here, so `dot(rhat, r)` is the squared norm.
                let rho = err * err;
                let x = x0.to_owned();
                Self {
                    iteration_count: 0,
                    soft_restart_threshold: 0.1 * <$T>::one(),
                    soft_restart_count: 0,
                    hard_restart_count: 0,
                    err,
                    a,
                    b,
                    x,
                    r,
                    rhat,
                    p,
                    rho,
                }
            }

            /// Iterates until the residual norm drops below `tol` or
            /// `max_iter` steps have been taken.
            ///
            /// If the initial guess `x0` already satisfies `tol`, the solver
            /// is returned immediately without taking a step.
            ///
            /// Whenever the running error estimate drops below `tol`, the
            /// residual is recomputed from scratch with `hard_restart` and the
            /// solver is only reported as converged if that recomputed error
            /// is below `tol` as well; otherwise iteration continues.
            ///
            /// # Errors
            ///
            /// Returns `Err` holding the solver in its final state when the
            /// tolerance is not reached within `max_iter` steps, so the caller
            /// can still inspect `x()`, `err()` and the counters.
            pub fn solve(
                a: CsMatViewI<'a, $T, I, Iptr>,
                x0: CsVecViewI<'a, $T, I>,
                b: CsVecViewI<'a, $T, I>,
                tol: $T,
                max_iter: usize,
            ) -> Result<
                Box<BiCGSTAB<'a, $T, I, Iptr>>,
                Box<BiCGSTAB<'a, $T, I, Iptr>>,
            > {
                let mut solver = Self::new(a, x0, b);

                // `new` computes the residual directly from `a`, `x0` and `b`,
                // so this is already a true error and needs no hard restart.
                // Stepping from a (near-)zero residual would evaluate
                // `rho / dot(rhat, a * p)` as 0 / 0 and poison `x` with NaN,
                // after which `err < tol` could never hold again.
                if solver.err() < tol {
                    return Ok(Box::new(solver));
                }

                for _ in 0..max_iter {
                    solver.step();
                    if solver.err() < tol {
                        // The running estimate may have drifted from the true
                        // residual; confirm before declaring convergence.
                        solver.hard_restart();
                        if solver.err() < tol {
                            return Ok(Box::new(solver));
                        }
                    }
                }

                // Out of iterations: report failure but hand back the state.
                Err(Box::new(solver))
            }

            /// Resets `rhat` and `p` to copies of the current residual `r`
            /// and `rho` to `err * err`, and counts one soft restart.
            ///
            /// `err` is reused instead of recomputing the norm of `r`; every
            /// method of this type that assigns `r` refreshes `err` right
            /// after, so the two stay in sync.
            pub fn soft_restart(&mut self) {
                self.soft_restart_count += 1;
                self.rhat = self.r.to_owned();
                self.rho = self.err * self.err;
                self.p = self.r.to_owned();
            }

            /// Recomputes the residual as `b - a * x`, refreshes `err`, and
            /// resets `rhat`, `p` and `rho` from the new residual.
            ///
            /// Counts as one hard restart and leaves the soft-restart counter
            /// unchanged.
            pub fn hard_restart(&mut self) {
                self.hard_restart_count += 1;
                self.r = &self.b - &(&self.a.view() * &self.x.view()).view();
                self.err = (&self.r).l2_norm();
                self.soft_restart();
                // Undo the increment done by `soft_restart`: this reset is
                // accounted for as a hard restart only.
                self.soft_restart_count -= 1;
            }

            /// Performs one iteration and returns the updated error estimate.
            ///
            /// The returned value is the norm of the residual as updated by
            /// recurrence, not of a residual recomputed from `a`, `x` and
            /// `b`; call `hard_restart` to obtain the latter.
            ///
            /// The divisions below are not guarded: if `dot(rhat, a * p)` or
            /// the squared norm of `a * s` is zero, the quotient is infinite
            /// or NaN and propagates into `x`.
            pub fn step(&mut self) -> $T {
                self.iteration_count += 1;

                // First half-step along `p`.
                let v = &self.a.view() * &self.p.view();
                let alpha = self.rho / ((&self.rhat).dot(&v));
                // Intermediate estimate of the solution.
                let h = &self.x + &self.p.map(|x| x * alpha);

                // Residual associated with `h`, then the stabilizing step
                // along it with step length `omega`.
                let s = &self.r - &v.map(|x| x * alpha);
                let t = &self.a.view() * &s.view();
                let omega = t.dot(&s) / &t.squared_l2_norm();
                self.x = &h.view() + &s.map(|x| omega * x);

                // Update the residual by recurrence and refresh its norm.
                self.r = &s - &t.map(|x| x * omega);
                self.err = (&self.r).l2_norm();

                let rho_prev = self.rho;
                self.rho = (&self.rhat).dot(&self.r);

                // A small `|rho|` relative to `err^2` means `rhat` is close to
                // orthogonal to `r`; restart from `r` rather than dividing by
                // a tiny `rho` on the next step.
                if self.rho.abs() / (self.err * self.err)
                    < self.soft_restart_threshold
                {
                    self.soft_restart();
                } else {
                    let beta = (self.rho / rho_prev) * (alpha / omega);
                    self.p = &self.r
                        + (&self.p - &v.map(|x| x * omega)).map(|x| x * beta);
                }

                self.err
            }

            /// Replaces the soft-restart threshold and returns the solver,
            /// builder-style.
            pub fn with_restart_threshold(mut self, thresh: $T) -> Self {
                self.soft_restart_threshold = thresh;
                self
            }

            /// Number of calls to `step` so far.
            pub fn iteration_count(&self) -> usize {
                self.iteration_count
            }

            /// Threshold on `|rho| / err^2` below which `step` soft-restarts.
            pub fn soft_restart_threshold(&self) -> $T {
                self.soft_restart_threshold
            }

            /// Number of soft restarts, not counting those performed as part
            /// of a hard restart.
            pub fn soft_restart_count(&self) -> usize {
                self.soft_restart_count
            }

            /// Number of calls to `hard_restart`.
            pub fn hard_restart_count(&self) -> usize {
                self.hard_restart_count
            }

            /// L2 norm of the current residual `r`.
            pub fn err(&self) -> $T {
                self.err
            }

            /// Current value of `dot(rhat, r)`.
            pub fn rho(&self) -> $T {
                self.rho
            }

            /// View of the system matrix.
            pub fn a(&self) -> CsMatViewI<'_, $T, I, Iptr> {
                self.a.view()
            }

            /// View of the current solution estimate.
            pub fn x(&self) -> CsVecViewI<'_, $T, I> {
                self.x.view()
            }

            /// View of the right-hand side.
            pub fn b(&self) -> CsVecViewI<'_, $T, I> {
                self.b.view()
            }

            /// View of the current residual.
            pub fn r(&self) -> CsVecViewI<'_, $T, I> {
                self.r.view()
            }

            /// View of `rhat`, the copy of the residual taken at the latest
            /// restart.
            pub fn rhat(&self) -> CsVecViewI<'_, $T, I> {
                self.rhat.view()
            }

            /// View of the current direction `p`.
            pub fn p(&self) -> CsVecViewI<'_, $T, I> {
                self.p.view()
            }
        }
    };
}
// Generate the `BiCGSTAB` implementation for `f64` and `f32` scalars.
bicgstab_impl!(f64);
bicgstab_impl!(f32);
#[cfg(test)]
mod test {
    use super::*;
    use crate::CsMatI;

    // Generates one end-to-end solver test for the scalar type `$T`.
    //
    // The tolerance passed in is far below the rounding error normally
    // expected for values of order one, so `solve` can only return `Ok` after
    // the residual recomputed by `hard_restart` is also below it. This
    // exercises the path where the running error estimate is re-checked
    // against the true residual.
    macro_rules! bicgstab_test {
        ($name: ident, $T: ty, $tol: expr) => {
            #[test]
            fn $name() {
                let a = CsMatI::new_csc(
                    (4, 4),
                    vec![0, 2, 4, 6, 8],
                    vec![0, 3, 1, 2, 1, 2, 0, 3],
                    vec![1.0, 2., 21., 6., 6., 2., 2., 8.],
                );
                let tol: $T = $tol;
                let max_iter = 50;
                let b = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0; 4]);
                let x0 =
                    CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0, 1.0, 1.0, 1.0]);

                let res = BiCGSTAB::<'_, $T, _, _>::solve(
                    a.view(),
                    x0.view(),
                    b.view(),
                    tol,
                    max_iter,
                )
                .unwrap();
                let b_recovered = &a * &res.x();

                println!("Iteration count {:?}", res.iteration_count());
                println!("Soft restart count {:?}", res.soft_restart_count());
                println!("Hard restart count {:?}", res.hard_restart_count());

                // Multiplying the solution back through `a` must reproduce
                // `b` element by element.
                for (input, output) in
                    b.to_dense().iter().zip(b_recovered.to_dense().iter())
                {
                    assert!(
                        (1.0 - input / output).abs() < tol,
                        "Solved output did not match input"
                    );
                }
            }
        };
    }

    bicgstab_test!(test_bicgstab_f32, f32, 1e-18);
    bicgstab_test!(test_bicgstab_f64, f64, 1e-60);
}