use crate::csr::CsrMatrix;
use crate::error::{SparseError, SparseResult};
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Debug;
use std::iter::Sum;
/// Outcome of a CGLS least-squares solve.
#[derive(Debug, Clone)]
pub struct CGLSResult {
/// Computed solution vector (the solver starts from the zero vector).
pub x: Vec<f64>,
/// Number of iterations the solver entered; `0` when `Aᵀ·b` is exactly zero.
pub iters: usize,
/// Relative norm of the normal-equations residual, `‖Aᵀ·r‖ / ‖Aᵀ·b‖`, as
/// reported by the solver on exit (`0.0` when `Aᵀ·b` is exactly zero).
pub rel_norm: f64,
/// `true` when the tolerance test succeeded or `Aᵀ·b` was exactly zero.
pub converged: bool,
}
/// Euclidean (2-) norm of `v`: the square root of the sum of squared entries.
#[inline]
fn vec_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().copied().map(|entry| entry * entry).sum::<f64>();
    sum_of_squares.sqrt()
}
/// Inner product of `a` and `b`.
///
/// Only the overlapping prefix is used when the slices differ in length.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f64>()
}
/// In-place update `y ← y + alpha · x`.
///
/// Only the overlapping prefix of `y` and `x` is touched when their lengths
/// differ; the remaining entries of `y` are left unchanged.
#[inline]
fn saxpy(y: &mut [f64], alpha: f64, x: &[f64]) {
    y.iter_mut()
        .zip(x)
        .for_each(|(target, &source)| *target += alpha * source);
}
/// Solves the least-squares problem `min ‖A·x − b‖₂` with the CGLS iteration
/// (conjugate gradients applied to the normal equations `AᵀA·x = Aᵀ·b`),
/// using only products with `A` and `Aᵀ`.
///
/// # Arguments
///
/// * `matvec` – computes `A·v`; receives a length-`n` slice and must return a
///   length-`m` vector.
/// * `rmatvec` – computes `Aᵀ·u`; receives a length-`m` slice and must return a
///   length-`n` vector.
/// * `b` – right-hand side of length `m`.
/// * `m`, `n` – number of rows and columns of `A`.
/// * `max_iter` – iteration cap; values below `1` are treated as `1`.
/// * `tol` – stop once `‖Aᵀ·r‖ / ‖Aᵀ·b‖ <= tol`.
///
/// # Returns
///
/// A [`CGLSResult`] holding the iterate, the number of iterations entered,
/// the relative normal-equations residual of the *returned* iterate, and a
/// convergence flag. When `Aᵀ·b` is exactly zero the zero vector is returned
/// immediately with `converged == true`.
///
/// # Errors
///
/// Returns [`SparseError::DimensionMismatch`] when `b.len() != m`, when
/// `rmatvec` does not produce `n` entries for the initial residual, or when
/// `matvec` does not produce `m` entries.
pub fn cgls<F, G>(
    matvec: F,
    rmatvec: G,
    b: &[f64],
    m: usize,
    n: usize,
    max_iter: usize,
    tol: f64,
) -> SparseResult<CGLSResult>
where
    F: Fn(&[f64]) -> Vec<f64>,
    G: Fn(&[f64]) -> Vec<f64>,
{
    if b.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: b.len(),
        });
    }
    let max_iter = max_iter.max(1);

    // Starting from x = 0, the initial residual is b itself.
    let mut x = vec![0.0f64; n];
    let mut r = b.to_vec();

    let s = rmatvec(&r);
    // A wrongly sized operator output would otherwise be silently truncated
    // by the zip-based vector helpers and yield a meaningless iterate.
    if s.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: s.len(),
        });
    }

    // gamma always holds ‖Aᵀ·r‖² for the current residual r.
    let mut gamma = dot(&s, &s);
    if gamma == 0.0 {
        // Aᵀ·b = 0: the zero vector already satisfies the normal equations.
        return Ok(CGLSResult {
            x,
            iters: 0,
            rel_norm: 0.0,
            converged: true,
        });
    }
    let atb_norm = gamma.sqrt();
    let mut p = s;
    let mut iters = 0usize;
    let mut converged = false;

    for iter in 0..max_iter {
        iters = iter + 1;

        let q = matvec(&p);
        if q.len() != m {
            return Err(SparseError::DimensionMismatch {
                expected: m,
                found: q.len(),
            });
        }
        let q_norm_sq = dot(&q, &q);
        if q_norm_sq == 0.0 {
            // A·p vanished: no step length can be computed, so stop here.
            break;
        }

        let alpha = gamma / q_norm_sq;
        saxpy(&mut x, alpha, &p);
        saxpy(&mut r, -alpha, &q);

        let s_new = rmatvec(&r);
        let gamma_prev = gamma;
        // Update gamma *before* the convergence test so that the reported
        // `rel_norm` describes the iterate that is actually returned.
        gamma = dot(&s_new, &s_new);
        if gamma.sqrt() / atb_norm <= tol {
            converged = true;
            break;
        }

        // New search direction: p ← s_new + beta · p.
        let beta = gamma / gamma_prev;
        let mut p_new = s_new;
        saxpy(&mut p_new, beta, &p);
        p = p_new;
    }

    Ok(CGLSResult {
        x,
        iters,
        rel_norm: gamma.sqrt() / atb_norm,
        converged,
    })
}
/// Solves `min ‖A·x − b‖₂` for a CSR matrix `a` by delegating to [`cgls`].
///
/// The stored values are converted to `f64` once. The forward (`A·x`) and
/// transpose (`Aᵀ·y`) products are closures that borrow the CSR arrays, so the
/// index structure is never copied.
///
/// # Arguments
///
/// * `a` – coefficient matrix in CSR form.
/// * `b` – right-hand side; its length must equal `a.rows()`.
/// * `max_iter` – iteration cap forwarded to [`cgls`].
/// * `tol` – relative tolerance forwarded to [`cgls`].
///
/// # Errors
///
/// Propagates any error returned by [`cgls`], e.g. a dimension mismatch when
/// `b.len() != a.rows()`.
///
/// # Panics
///
/// Indexing panics if the CSR arrays are inconsistent with the shape reported
/// by `a.rows()` / `a.cols()`.
pub fn cgls_sparse<F>(
    a: &CsrMatrix<F>,
    b: &[f64],
    max_iter: usize,
    tol: f64,
) -> SparseResult<CGLSResult>
where
    F: Float + NumAssign + SparseElement + Debug + Sum + Into<f64> + Copy,
{
    let m = a.rows();
    let n = a.cols();
    let indptr = &a.indptr;
    let indices = &a.indices;
    // The solver works in f64, so widen the stored values exactly once.
    let values: Vec<f64> = a.data.iter().map(|&v| v.into()).collect();

    // y = A·x, accumulated row by row.
    let matvec = |x: &[f64]| -> Vec<f64> {
        let mut y = vec![0.0f64; m];
        for i in 0..m {
            for pos in indptr[i]..indptr[i + 1] {
                y[i] += values[pos] * x[indices[pos]];
            }
        }
        y
    };

    // x = Aᵀ·y, scattering each row's contribution into its column slots.
    let rmatvec = |y: &[f64]| -> Vec<f64> {
        let mut x = vec![0.0f64; n];
        for i in 0..m {
            for pos in indptr[i]..indptr[i + 1] {
                x[indices[pos]] += values[pos] * y[i];
            }
        }
        x
    };

    cgls(matvec, rmatvec, b, m, n, max_iter, tol)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `a` and `b` differ by strictly less than `tol`.
    fn assert_close(a: f64, b: f64, tol: f64) {
        let diff = (a - b).abs();
        assert!(diff < tol, "expected ~{b}, got {a} (diff {})", diff);
    }

    /// Builds the forward (`A·x`) and transpose (`Aᵀ·y`) products for an
    /// `m × n` matrix given by its CSR arrays.
    fn make_matvec(
        indptr: Vec<usize>,
        indices: Vec<usize>,
        data: Vec<f64>,
        m: usize,
        n: usize,
    ) -> (
        impl Fn(&[f64]) -> Vec<f64>,
        impl Fn(&[f64]) -> Vec<f64>,
    ) {
        // Each closure owns its own copy of the CSR arrays.
        let (t_indptr, t_indices, t_data) = (indptr.clone(), indices.clone(), data.clone());

        let forward = move |x: &[f64]| -> Vec<f64> {
            let mut out = vec![0.0; m];
            for (row, slot) in out.iter_mut().enumerate() {
                for k in indptr[row]..indptr[row + 1] {
                    *slot += data[k] * x[indices[k]];
                }
            }
            out
        };

        let transpose = move |y: &[f64]| -> Vec<f64> {
            let mut out = vec![0.0; n];
            for row in 0..m {
                for k in t_indptr[row]..t_indptr[row + 1] {
                    out[t_indices[k]] += t_data[k] * y[row];
                }
            }
            out
        };

        (forward, transpose)
    }

    /// Tridiagonal 3×3 system whose exact solution is the all-ones vector.
    #[test]
    fn test_cgls_square_spd() {
        let (m, n) = (3usize, 3usize);
        let (mv, rmv) = make_matvec(
            vec![0, 2, 5, 7],
            vec![0, 1, 0, 1, 2, 1, 2],
            vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0],
            m,
            n,
        );
        let b = [3.0f64, 2.0, 3.0];

        let result = cgls(mv, rmv, &b, m, n, 500, 1e-10).expect("CGLS failed");

        assert!(result.converged, "CGLS should converge");
        for &xi in result.x[..3].iter() {
            assert_close(xi, 1.0, 1e-4);
        }
    }

    /// A zero right-hand side must return the zero vector immediately.
    #[test]
    fn test_cgls_zero_rhs() {
        let (m, n) = (3usize, 3usize);
        let (mv, rmv) = make_matvec(vec![0, 1, 2, 3], vec![0, 1, 2], vec![2.0, 3.0, 4.0], m, n);
        let b = vec![0.0f64; m];

        let result = cgls(mv, rmv, &b, m, n, 100, 1e-10).expect("CGLS zero rhs failed");

        assert!(result.converged);
        assert!(result.x.iter().all(|&v| v == 0.0));
    }

    /// Consistent 4×2 overdetermined system with solution `[1, 1]`.
    #[test]
    fn test_cgls_overdetermined() {
        let (m, n) = (4usize, 2usize);
        let (mv, rmv) = make_matvec(
            vec![0, 1, 2, 4, 5],
            vec![0, 1, 0, 1, 0],
            vec![1.0, 1.0, 1.0, 1.0, 2.0],
            m,
            n,
        );
        let b = [1.0f64, 1.0, 2.0, 2.0];

        let result = cgls(mv, rmv, &b, m, n, 1000, 1e-8).expect("CGLS overdetermined failed");

        for &xi in result.x[..2].iter() {
            assert_close(xi, 1.0, 1e-4);
        }
    }

    /// The CSR wrapper must reproduce the closure-based solve.
    #[test]
    fn test_cgls_sparse_wrapper() {
        let data = vec![4.0f64, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
        let rows_v = vec![0usize, 0, 1, 1, 1, 2, 2];
        let cols_v = vec![0usize, 1, 0, 1, 2, 1, 2];
        let a = CsrMatrix::new(data, rows_v, cols_v, (3, 3)).expect("CsrMatrix failed");
        let b = vec![3.0f64, 2.0, 3.0];

        let result = cgls_sparse(&a, &b, 500, 1e-10).expect("cgls_sparse failed");

        assert!(result.converged);
        assert_close(result.x[0], 1.0, 1e-4);
    }
}