use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, Zero};
use super::core::{compute_norm_vec, dot_vec, matvec, SolverResult};
/// Solves the linear system `A x = b` with the stabilized bi-conjugate gradient
/// method (BiCGSTAB).
///
/// Works for general (not necessarily symmetric) square matrices. Convergence is
/// measured on the relative residual `‖b − A x‖ / ‖b‖`.
///
/// # Arguments
///
/// * `a` - square coefficient matrix of shape `[n, n]`.
/// * `b` - right-hand side with `n` elements.
/// * `x0` - optional initial guess with `n` elements; defaults to the zero vector.
/// * `tol` - relative residual tolerance; defaults to `1e-6`.
/// * `max_iter` - maximum number of iterations; defaults to `n`.
///
/// # Returns
///
/// A [`SolverResult`] holding the final iterate, the number of iterations
/// performed, the last computed residual norm, and whether the tolerance was met.
/// Exhausting `max_iter`, or stagnating because `A s` vanished, yields
/// `converged: false` together with the best iterate found.
///
/// # Errors
///
/// * `DimensionMismatch` if `a` is not a square 2-D matrix.
/// * `ShapeMismatch` if `b` or `x0` does not have `n` elements.
/// * `ComputationError` on a numerical breakdown of the recurrence
///   (`r̂₀·v`, `ω`, or `ρ` numerically zero relative to the vectors involved).
/// * Any error returned by `matvec` is propagated.
pub fn bicgstab<T>(
    a: &Array<T>,
    b: &Array<T>,
    x0: Option<&Array<T>>,
    tol: Option<T>,
    max_iter: Option<usize>,
) -> Result<SolverResult<T>>
where
    T: Float + Clone + Zero,
{
    let shape = a.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(NumRs2Error::DimensionMismatch(
            "Matrix must be square".to_string(),
        ));
    }
    let n = shape[0];
    if b.size() != n {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![n],
            actual: b.shape(),
        });
    }
    // Validate the guess up front: a wrong-length x0 would otherwise panic on
    // indexing or fail deep inside matvec.
    if let Some(x) = x0 {
        if x.size() != n {
            return Err(NumRs2Error::ShapeMismatch {
                expected: vec![n],
                actual: x.shape(),
            });
        }
    }
    let tol = tol.unwrap_or_else(|| T::from(1e-6).unwrap_or(T::epsilon()));
    let max_iter = max_iter.unwrap_or(n);
    // Breakdown threshold applied to *normalized* inner products (cosines), so the
    // breakdown checks are invariant under rescaling of `a` and `b` and do not
    // trigger merely because the residual has become small.
    let breakdown_tol = T::epsilon();

    let b_vec = b.to_vec();
    let b_norm = compute_norm_vec(&b_vec);
    if b_norm.is_zero() {
        // x = 0 solves A x = 0 exactly, whatever initial guess was supplied.
        return Ok(SolverResult {
            solution: Array::from_vec(vec![T::zero(); n]),
            iterations: 0,
            residual_norm: T::zero(),
            converged: true,
        });
    }
    // An exactly-zero residual counts as converged even when `tol` is zero.
    let is_converged = |norm: T| norm.is_zero() || norm / b_norm < tol;

    let mut x_vec: Vec<T> = match x0 {
        Some(x) => x.to_vec(),
        None => vec![T::zero(); n],
    };

    let ax_vec = matvec(a, &Array::from_vec(x_vec.clone()))?.to_vec();
    let mut r_vec: Vec<T> = b_vec
        .iter()
        .zip(ax_vec.iter())
        .map(|(&bi, &axi)| bi - axi)
        .collect();
    let r_norm = compute_norm_vec(&r_vec);
    if is_converged(r_norm) {
        return Ok(SolverResult {
            solution: Array::from_vec(x_vec),
            iterations: 0,
            residual_norm: r_norm,
            converged: true,
        });
    }

    // Shadow residual r̂₀, kept fixed for the whole run.
    let r0_vec = r_vec.clone();
    let r0_norm = r_norm;
    // ρ₀ = r̂₀·r₀ = ‖r₀‖² > 0 here because r₀ is nonzero.
    let mut rho = dot_vec(&r0_vec, &r_vec);
    let mut p_vec = r_vec.clone();

    for iter in 0..max_iter {
        let v_vec = matvec(a, &Array::from_vec(p_vec.clone()))?.to_vec();
        let r0_dot_v = dot_vec(&r0_vec, &v_vec);
        // `<=` so that an all-zero v (A p = 0) is also reported as a breakdown.
        if r0_dot_v.abs() <= breakdown_tol * r0_norm * compute_norm_vec(&v_vec) {
            return Err(NumRs2Error::ComputationError(
                "BiCGSTAB breakdown: r0 dot v too small".to_string(),
            ));
        }
        let alpha = rho / r0_dot_v;

        let s_vec: Vec<T> = r_vec
            .iter()
            .zip(v_vec.iter())
            .map(|(&ri, &vi)| ri - alpha * vi)
            .collect();
        let s_norm = compute_norm_vec(&s_vec);
        if is_converged(s_norm) {
            // The half step x + αp already meets the tolerance.
            for (xi, &pi) in x_vec.iter_mut().zip(p_vec.iter()) {
                *xi = *xi + alpha * pi;
            }
            return Ok(SolverResult {
                solution: Array::from_vec(x_vec),
                iterations: iter + 1,
                residual_norm: s_norm,
                converged: true,
            });
        }

        let t_vec = matvec(a, &Array::from_vec(s_vec.clone()))?.to_vec();
        let t_dot_t = dot_vec(&t_vec, &t_vec);
        if t_dot_t <= T::min_positive_value() {
            // A s vanished (or underflowed): ω is undefined and the method cannot
            // continue. s is known to be above tolerance at this point, so return
            // the half-step iterate marked as NOT converged.
            for (xi, &pi) in x_vec.iter_mut().zip(p_vec.iter()) {
                *xi = *xi + alpha * pi;
            }
            return Ok(SolverResult {
                solution: Array::from_vec(x_vec),
                iterations: iter + 1,
                residual_norm: s_norm,
                converged: false,
            });
        }
        let t_dot_s = dot_vec(&t_vec, &s_vec);
        let omega = t_dot_s / t_dot_t;

        for ((xi, &pi), &si) in x_vec.iter_mut().zip(p_vec.iter()).zip(s_vec.iter()) {
            *xi = *xi + alpha * pi + omega * si;
        }
        for ((ri, &si), &ti) in r_vec.iter_mut().zip(s_vec.iter()).zip(t_vec.iter()) {
            *ri = si - omega * ti;
        }
        let r_norm = compute_norm_vec(&r_vec);
        if is_converged(r_norm) {
            return Ok(SolverResult {
                solution: Array::from_vec(x_vec),
                iterations: iter + 1,
                residual_norm: r_norm,
                converged: true,
            });
        }

        // ω ≈ 0 (t ⟂ s) would make β = (…)·(α/ω) blow up.
        if t_dot_s.abs() <= breakdown_tol * t_dot_t.sqrt() * s_norm {
            return Err(NumRs2Error::ComputationError(
                "BiCGSTAB breakdown: omega too small".to_string(),
            ));
        }
        // Check the *new* ρ: it becomes the divisor of the next β and the
        // numerator of the next α.
        let rho_new = dot_vec(&r0_vec, &r_vec);
        if rho_new.abs() <= breakdown_tol * r0_norm * r_norm {
            return Err(NumRs2Error::ComputationError(
                "BiCGSTAB breakdown: rho too small".to_string(),
            ));
        }
        let beta = (rho_new / rho) * (alpha / omega);
        for ((pi, &ri), &vi) in p_vec.iter_mut().zip(r_vec.iter()).zip(v_vec.iter()) {
            *pi = ri + beta * (*pi - omega * vi);
        }
        rho = rho_new;
    }

    let r_norm = compute_norm_vec(&r_vec);
    Ok(SolverResult {
        solution: Array::from_vec(x_vec),
        iterations: max_iter,
        residual_norm: r_norm,
        converged: false,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts element-wise that `a · x ≈ b` within `eps`.
    fn assert_solves(a: &Array<f64>, x: &Array<f64>, b: &Array<f64>, eps: f64) {
        let ax = matvec(a, x).expect("matvec should work");
        for i in 0..b.size() {
            assert_relative_eq!(
                ax.get(&[i]).expect("valid"),
                b.get(&[i]).expect("valid"),
                epsilon = eps
            );
        }
    }

    #[test]
    fn test_bicgstab_simple() {
        // [[3, 1], [1, 2]] x = [1, 2] has the exact solution x = [0, 1].
        let a = Array::from_vec(vec![3.0, 1.0, 1.0, 2.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![1.0, 2.0]);
        let result = bicgstab(&a, &b, None, Some(1e-6), Some(100)).expect("Should solve");
        assert!(result.converged);
        assert_relative_eq!(result.solution.get(&[0]).expect("valid"), 0.0, epsilon = 1e-6);
        assert_relative_eq!(result.solution.get(&[1]).expect("valid"), 1.0, epsilon = 1e-6);
        assert_solves(&a, &result.solution, &b, 1e-5);
    }

    #[test]
    fn test_bicgstab_identity() {
        let a = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![3.0, 4.0]);
        let result = bicgstab(&a, &b, None, Some(1e-10), Some(100)).expect("Should solve");
        assert!(result.converged);
        for i in 0..2 {
            assert_relative_eq!(
                result.solution.get(&[i]).expect("valid"),
                b.get(&[i]).expect("valid"),
                epsilon = 1e-8
            );
        }
    }

    #[test]
    fn test_bicgstab_diagonal() {
        let a = Array::from_vec(vec![2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0]).reshape(&[3, 3]);
        let b = Array::from_vec(vec![4.0, 9.0, 16.0]);
        let result = bicgstab(&a, &b, None, Some(1e-10), Some(100)).expect("Should solve");
        assert!(result.converged);
        for (i, expected) in [2.0, 3.0, 4.0].iter().enumerate() {
            assert_relative_eq!(
                result.solution.get(&[i]).expect("valid"),
                *expected,
                epsilon = 1e-6
            );
        }
    }

    #[test]
    fn test_bicgstab_non_symmetric() {
        let a = Array::from_vec(vec![4.0, 1.0, 2.0, 3.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0, 5.0]);
        let result = bicgstab(&a, &b, None, Some(1e-10), Some(100)).expect("Should solve");
        assert!(result.converged);
        assert_solves(&a, &result.solution, &b, 1e-6);
    }

    #[test]
    fn test_bicgstab_larger_system() {
        let a = Array::from_vec(vec![
            4.0, 1.0, 0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 1.0, 4.0,
        ])
        .reshape(&[4, 4]);
        let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = bicgstab(&a, &b, None, Some(1e-10), Some(100)).expect("Should solve");
        assert!(result.converged);
        assert_solves(&a, &result.solution, &b, 1e-6);
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let a = Array::from_vec(vec![4.0, 1.0, 1.0, 3.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![0.0, 0.0]);
        let result = bicgstab(&a, &b, None, Some(1e-6), Some(100)).expect("Should solve");
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_bicgstab_with_initial_guess() {
        let a = Array::from_vec(vec![4.0, 1.0, 1.0, 3.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0, 5.0]);
        let x0 = Array::from_vec(vec![1.0, 1.0]);
        let result = bicgstab(&a, &b, Some(&x0), Some(1e-10), Some(100)).expect("Should solve");
        assert!(result.converged);
        assert_solves(&a, &result.solution, &b, 1e-6);
    }

    #[test]
    fn test_bicgstab_exact_initial_guess() {
        // x0 = [1, 1] solves diag(2, 4) x = [2, 4] exactly, so no iterations run.
        let a = Array::from_vec(vec![2.0, 0.0, 0.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![2.0, 4.0]);
        let x0 = Array::from_vec(vec![1.0, 1.0]);
        let result = bicgstab(&a, &b, Some(&x0), Some(1e-10), Some(100)).expect("Should solve");
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        assert_relative_eq!(result.residual_norm, 0.0);
        for i in 0..2 {
            assert_relative_eq!(result.solution.get(&[i]).expect("valid"), 1.0);
        }
    }

    #[test]
    fn test_bicgstab_non_square_matrix_errors() {
        let a = Array::from_vec(vec![1.0; 6]).reshape(&[2, 3]);
        let b = Array::from_vec(vec![1.0, 2.0]);
        assert!(matches!(
            bicgstab(&a, &b, None, None, None),
            Err(NumRs2Error::DimensionMismatch(_))
        ));
    }

    #[test]
    fn test_bicgstab_rhs_length_mismatch_errors() {
        let a = Array::from_vec(vec![1.0, 0.0, 0.0, 1.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(matches!(
            bicgstab(&a, &b, None, None, None),
            Err(NumRs2Error::ShapeMismatch { .. })
        ));
    }
}