use crate::error::{SparseError, SparseResult};
use crate::linalg::interface::LinearOperator;
use crate::linalg::iterative::{dot, norm2, BiCGOptions, IterationResult};
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::iter::Sum;
/// Options accepted by [`cgs`]; an alias for [`BiCGOptions`], so CGS shares the BiCG configuration type.
pub type CGSOptions<F> = BiCGOptions<F>;
/// Value returned by [`cgs`]; an alias for the shared [`IterationResult`] type.
pub type CGSResult<F> = IterationResult<F>;
/// Solves `A x = b` with the Conjugate Gradient Squared (CGS) method.
///
/// # Arguments
///
/// * `a` - Square linear operator `A`.
/// * `b` - Right-hand side; its length must equal the dimension of `a`.
/// * `options` - Solver settings. This function reads `x0` (optional initial
///   guess, zero vector when absent), `atol`, `rtol`, `max_iter` and
///   `right_preconditioner` (applied through its `matvec`).
///
/// # Returns
///
/// A [`CGSResult`] holding the last iterate, the number of iterations
/// performed, the norm of the recurrence residual, a convergence flag and a
/// status message. Convergence means `‖r‖₂ <= max(atol, rtol * ‖b‖₂)`.
/// A numerical breakdown (`rho ≈ 0` or `sigma ≈ 0`) is not an error: the
/// current iterate is returned with `converged: false`.
///
/// # Errors
///
/// * [`SparseError::ValueError`] if `a` is not square.
/// * [`SparseError::DimensionMismatch`] if `b` or `options.x0` has the wrong
///   length.
/// * Any error returned by `matvec` of `a` or of the preconditioner.
#[allow(dead_code)]
pub fn cgs<F>(
    a: &dyn LinearOperator<F>,
    b: &[F],
    options: CGSOptions<F>,
) -> SparseResult<CGSResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + 'static,
{
    let (rows, cols) = a.shape();
    if rows != cols {
        return Err(SparseError::ValueError(
            "Matrix must be square for CGS solver".to_string(),
        ));
    }
    if b.len() != rows {
        return Err(SparseError::DimensionMismatch {
            expected: rows,
            found: b.len(),
        });
    }
    let n = rows;

    // Start from the caller's guess when one is supplied, otherwise from zero.
    let mut x: Vec<F> = match &options.x0 {
        Some(x0) => {
            if x0.len() != n {
                return Err(SparseError::DimensionMismatch {
                    expected: n,
                    found: x0.len(),
                });
            }
            x0.clone()
        }
        None => vec![F::sparse_zero(); n],
    };

    // Initial residual r = b - A x.
    let ax = a.matvec(&x)?;
    let mut r: Vec<F> = b.iter().zip(&ax).map(|(&bi, &axi)| bi - axi).collect();
    let mut rnorm = norm2(&r);
    let bnorm = norm2(b);
    let tolerance = F::max(options.atol, options.rtol * bnorm);
    if rnorm <= tolerance {
        return Ok(CGSResult {
            x,
            iterations: 0,
            residual_norm: rnorm,
            converged: true,
            message: "Converged with initial guess".to_string(),
        });
    }

    // Fixed shadow residual used for both inner products of the recurrence.
    let r_tilde = r.clone();
    let mut u = vec![F::sparse_zero(); n];
    let mut p = vec![F::sparse_zero(); n];
    let mut q = vec![F::sparse_zero(); n];
    let mut rho = F::sparse_one();
    let mut iterations = 0;

    // Loop-invariant threshold below which rho / sigma count as a breakdown.
    let breakdown_tol =
        F::epsilon() * F::from(10).expect("Failed to convert constant to float");

    while iterations < options.max_iter {
        let rho_new = dot(&r_tilde, &r);
        if rho_new.abs() < breakdown_tol {
            return Ok(CGSResult {
                x,
                iterations,
                residual_norm: rnorm,
                converged: false,
                message: "CGS breakdown: rho ≈ 0".to_string(),
            });
        }

        // beta = 0 on the first pass makes u = p = r, the standard CGS start.
        let beta = if iterations == 0 {
            F::sparse_zero()
        } else {
            rho_new / rho
        };
        for ((ui, pi), (&ri, &qi)) in u
            .iter_mut()
            .zip(p.iter_mut())
            .zip(r.iter().zip(q.iter()))
        {
            *ui = ri + beta * qi;
            *pi = *ui + beta * (qi + beta * *pi);
        }

        // v = A * M^-1 * p (M^-1 is the identity without a preconditioner).
        let v = match &options.right_preconditioner {
            Some(m) => a.matvec(&m.matvec(&p)?)?,
            None => a.matvec(&p)?,
        };
        let sigma = dot(&r_tilde, &v);
        if sigma.abs() < breakdown_tol {
            return Ok(CGSResult {
                x,
                iterations,
                residual_norm: rnorm,
                converged: false,
                message: "CGS breakdown: sigma ≈ 0".to_string(),
            });
        }
        let alpha = rho_new / sigma;

        for ((qi, &ui), &vi) in q.iter_mut().zip(&u).zip(&v) {
            *qi = ui - alpha * vi;
        }

        // Update direction u_hat = M^-1 * (u + q).
        let u_plus_q: Vec<F> = u.iter().zip(&q).map(|(&ui, &qi)| ui + qi).collect();
        let u_hat = match &options.right_preconditioner {
            Some(m) => m.matvec(&u_plus_q)?,
            None => u_plus_q,
        };

        // The residual recurrence must use A applied to the very direction
        // that is added to x, i.e. A * M^-1 * (u + q). Using v = A * M^-1 * p
        // in its place is only equivalent on the first iteration (where
        // p == u) and lets r drift away from b - A x afterwards.
        let a_u_hat = a.matvec(&u_hat)?;
        for (xi, &hi) in x.iter_mut().zip(&u_hat) {
            *xi += alpha * hi;
        }
        for (ri, &wi) in r.iter_mut().zip(&a_u_hat) {
            *ri -= alpha * wi;
        }

        rho = rho_new;
        iterations += 1;
        rnorm = norm2(&r);
        if rnorm <= tolerance {
            break;
        }
    }

    let converged = rnorm <= tolerance;
    Ok(CGSResult {
        x,
        iterations,
        residual_norm: rnorm,
        converged,
        message: if converged {
            "Converged".to_string()
        } else {
            "Maximum iterations reached".to_string()
        },
    })
}