use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::Float;
use scirs2_core::SparseElement;
use std::fmt::Debug;
/// Result of one inner GCROT step: `(delta_x, new_c, new_u, stop)`.
///
/// * `delta_x` — correction to add to the current iterate.
/// * `new_c` / `new_u` — the search direction that was used and its image under
///   the matrix, to be added to the retained set; `None` when no step was taken.
/// * `stop` — `true` when the inner step reports that it can make no further
///   progress (residual already within tolerance, or a breakdown).
type GCROTInnerResult<T> = SparseResult<(
    Array1<T>,
    Option<Array1<T>>,
    Option<Array1<T>>,
    bool,
)>;
/// Configuration for [`gcrot`].
#[derive(Debug, Clone)]
pub struct GCROTOptions {
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Relative tolerance: the target residual is `tol * ||b||_2`.
    pub tol: f64,
    /// Maximum number of previous search directions (together with their
    /// images under the matrix) carried from one iteration to the next.
    pub truncation_size: usize,
    /// Record the residual norm of every iteration in
    /// [`GCROTResult::residual_history`].
    pub store_residual_history: bool,
}

impl Default for GCROTOptions {
    /// 1000 iterations, `tol = 1e-6`, up to 20 retained directions, history on.
    fn default() -> Self {
        let (max_iter, tol) = (1000, 1e-6);
        GCROTOptions {
            store_residual_history: true,
            truncation_size: 20,
            tol,
            max_iter,
        }
    }
}
/// Outcome of a [`gcrot`] solve.
#[derive(Clone, Debug)]
pub struct GCROTResult<T> {
    /// Approximate solution of `A x = b`.
    pub x: Array1<T>,
    /// Number of outer iterations performed.
    pub iterations: usize,
    /// 2-norm of `b - A x` for the returned `x`.
    pub residual_norm: T,
    /// Whether the solver reported convergence; `residual_norm` gives the
    /// residual that was actually achieved.
    pub converged: bool,
    /// Residual norms, present only when `store_residual_history` was
    /// requested: the initial residual followed by one entry per iteration.
    pub residual_history: Option<Vec<T>>,
}
#[allow(dead_code)]
pub fn gcrot<T, S>(
matrix: &S,
b: &ArrayView1<T>,
x0: Option<&ArrayView1<T>>,
options: GCROTOptions,
) -> SparseResult<GCROTResult<T>>
where
T: Float + SparseElement + Debug + Copy + 'static,
S: SparseArray<T>,
{
let n = b.len();
let (rows, cols) = matrix.shape();
if rows != cols || rows != n {
return Err(SparseError::DimensionMismatch {
expected: n,
found: rows,
});
}
let mut x = match x0 {
Some(x0_val) => x0_val.to_owned(),
None => Array1::zeros(n),
};
let ax = matrix_vector_multiply(matrix, &x.view())?;
let mut r = b - &ax;
let initial_residual_norm = l2_norm(&r.view());
let b_norm = l2_norm(b);
let tolerance = T::from(options.tol).expect("Operation failed") * b_norm;
if initial_residual_norm <= tolerance {
return Ok(GCROTResult {
x,
iterations: 0,
residual_norm: initial_residual_norm,
converged: true,
residual_history: if options.store_residual_history {
Some(vec![initial_residual_norm])
} else {
None
},
});
}
let m = options.truncation_size;
let mut c_vectors = Array2::zeros((n, 0)); let mut u_vectors = Array2::zeros((n, 0));
let mut residual_history = if options.store_residual_history {
Some(vec![initial_residual_norm])
} else {
None
};
let mut converged = false;
let mut iter = 0;
for k in 0..options.max_iter {
iter = k + 1;
let (delta_x, new_c, new_u, inner_converged) = gcrot_inner_iteration(
matrix,
&r.view(),
&c_vectors.view(),
&u_vectors.view(),
tolerance,
)?;
x = &x + &delta_x;
let ax = matrix_vector_multiply(matrix, &x.view())?;
r = b - &ax;
let residual_norm = l2_norm(&r.view());
if let Some(ref mut history) = residual_history {
history.push(residual_norm);
}
if residual_norm <= tolerance || inner_converged {
converged = true;
break;
}
if let (Some(c), Some(u)) = (new_c, new_u) {
if c_vectors.ncols() >= m {
let mut new_c_vectors = Array2::zeros((n, m));
let mut new_u_vectors = Array2::zeros((n, m));
for j in 1..c_vectors.ncols() {
for i in 0..n {
new_c_vectors[[i, j - 1]] = c_vectors[[i, j]];
new_u_vectors[[i, j - 1]] = u_vectors[[i, j]];
}
}
for i in 0..n {
new_c_vectors[[i, m - 1]] = c[i];
new_u_vectors[[i, m - 1]] = u[i];
}
c_vectors = new_c_vectors;
u_vectors = new_u_vectors;
} else {
let old_cols = c_vectors.ncols();
let mut new_c_vectors = Array2::zeros((n, old_cols + 1));
let mut new_u_vectors = Array2::zeros((n, old_cols + 1));
for j in 0..old_cols {
for i in 0..n {
new_c_vectors[[i, j]] = c_vectors[[i, j]];
new_u_vectors[[i, j]] = u_vectors[[i, j]];
}
}
for i in 0..n {
new_c_vectors[[i, old_cols]] = c[i];
new_u_vectors[[i, old_cols]] = u[i];
}
c_vectors = new_c_vectors;
u_vectors = new_u_vectors;
}
}
}
let ax_final = matrix_vector_multiply(matrix, &x.view())?;
let final_residual = b - &ax_final;
let final_residual_norm = l2_norm(&final_residual.view());
Ok(GCROTResult {
x,
iterations: iter,
residual_norm: final_residual_norm,
converged,
residual_history,
})
}
/// Computes one truncated-GCR (ORTHOMIN-style) correction for the residual `r`.
///
/// The new search direction starts as `p = r`, with image `q = A p`. `q` is
/// orthogonalised by modified Gram-Schmidt against each stored image (column
/// `j` of `u_vectors`). The same multiple of the matching direction (column `j`
/// of `c_vectors`) is subtracted from `p`, so `q = A p` is preserved. After
/// scaling so that `||q||_2 = 1`, the step `alpha = q . r` minimises
/// `||r - alpha q||_2`, and the correction is `delta_x = alpha p`.
///
/// Assumes the stored pairs are the `(p, q)` vectors returned by earlier calls,
/// i.e. each `u_vectors` column has unit norm and equals `A` times the
/// corresponding `c_vectors` column.
///
/// Returns `(delta_x, Some(p), Some(q), false)` on success. Returns
/// `(0, None, None, true)` when `||r|| <= tolerance`, or on breakdown: the
/// orthogonalised image is NaN or no larger than `sqrt(eps) * ||A r||`, so no
/// reliable new direction is available.
///
/// # Errors
/// Propagates errors from the matrix-vector product.
#[allow(dead_code)]
fn gcrot_inner_iteration<T, S>(
    matrix: &S,
    r: &ArrayView1<T>,
    c_vectors: &scirs2_core::ndarray::ArrayView2<T>,
    u_vectors: &scirs2_core::ndarray::ArrayView2<T>,
    tolerance: T,
) -> GCROTInnerResult<T>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let n = r.len();
    if l2_norm(r) <= tolerance {
        return Ok((Array1::zeros(n), None, None, true));
    }

    // Start from the residual direction; `q` tracks A * p throughout.
    let mut p = r.to_owned();
    let mut q = matrix_vector_multiply(matrix, &p.view())?;
    let q_initial_norm = l2_norm(&q.view());

    // Remove from `q` its component along each stored unit image, and apply
    // the identical combination to `p` so that `q == A p` keeps holding.
    for j in 0..c_vectors.ncols() {
        let u_j = u_vectors.column(j);
        let c_j = c_vectors.column(j);
        let beta = dot_product(&u_j, &q.view());
        for (qi, &uij) in q.iter_mut().zip(u_j.iter()) {
            *qi = *qi - beta * uij;
        }
        for (pi, &cij) in p.iter_mut().zip(c_j.iter()) {
            *pi = *pi - beta * cij;
        }
    }

    // Relative breakdown test: independent of the overall scaling of A.
    let q_norm = l2_norm(&q.view());
    let breakdown_threshold = <T as Float>::epsilon().sqrt() * q_initial_norm;
    if <T as Float>::is_nan(q_norm) || q_norm <= breakdown_threshold {
        return Ok((Array1::zeros(n), None, None, true));
    }
    p.mapv_inplace(|pi| pi / q_norm);
    q.mapv_inplace(|qi| qi / q_norm);

    let alpha = dot_product(&q.view(), r);
    let delta_x = p.mapv(|pi| alpha * pi);
    Ok((delta_x, Some(p), Some(q), false))
}
/// Computes `y = A x` by scattering every stored `(row, col, value)` triplet
/// reported by `matrix.find()` into a zero-initialised output.
///
/// # Errors
/// Returns `SparseError::DimensionMismatch` when `x.len()` differs from the
/// number of columns of `matrix`.
#[allow(dead_code)]
fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (n_rows, n_cols) = matrix.shape();
    if x.len() != n_cols {
        return Err(SparseError::DimensionMismatch {
            expected: n_cols,
            found: x.len(),
        });
    }

    let mut y = Array1::zeros(n_rows);
    let (row_idx, col_idx, vals) = matrix.find();
    // y[i] += a_ij * x[j] for each stored entry, in the order `find` yields them.
    for (pos, (&i, &j)) in row_idx.iter().zip(col_idx.iter()).enumerate() {
        let contribution = vals[pos] * x[j];
        y[i] = y[i] + contribution;
    }
    Ok(y)
}
/// Euclidean norm `sqrt(sum x_i^2)`, accumulated left to right without any
/// rescaling.
#[allow(dead_code)]
fn l2_norm<T>(x: &ArrayView1<T>) -> T
where
    T: Float + Debug + Copy + SparseElement,
{
    let mut sum_of_squares = T::sparse_zero();
    for &component in x.iter() {
        sum_of_squares = sum_of_squares + component * component;
    }
    sum_of_squares.sqrt()
}
/// Inner product `sum x_i * y_i`, accumulated left to right. Pairs are formed
/// by zipping, so only the overlapping prefix is used if the lengths differ.
#[allow(dead_code)]
fn dot_product<T>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T
where
    T: Float + Debug + Copy + SparseElement,
{
    let mut total = T::sparse_zero();
    for (&a, &b) in x.iter().zip(y.iter()) {
        total = total + a * b;
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Returns `||b - A x||_2`, computed independently of the solver's report.
    fn residual_norm_of<S: SparseArray<f64>>(
        matrix: &S,
        b: &Array1<f64>,
        x: &Array1<f64>,
    ) -> f64 {
        let ax = matrix_vector_multiply(matrix, &x.view()).expect("Operation failed");
        let residual = b - &ax;
        l2_norm(&residual.view())
    }

    #[test]
    fn test_gcrot_simple_system() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let result =
            gcrot(&matrix, &b.view(), None, GCROTOptions::default()).expect("Operation failed");
        assert!(result.converged);
        assert!(residual_norm_of(&matrix, &b, &result.x) < 1e-6);
        // One history entry for the initial residual plus one per iteration.
        assert_eq!(
            result.residual_history.as_ref().map(Vec::len),
            Some(result.iterations + 1)
        );
    }

    #[test]
    fn test_gcrot_diagonal_system() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![5.0, 5.0, 5.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
        let b = Array1::from_vec(vec![5.0, 10.0, 15.0]);
        let result =
            gcrot(&matrix, &b.view(), None, GCROTOptions::default()).expect("Operation failed");
        assert!(result.converged);
        assert!(residual_norm_of(&matrix, &b, &result.x) < 1e-6);
    }

    #[test]
    fn test_gcrot_truncation() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let options = GCROTOptions {
            truncation_size: 2,
            ..Default::default()
        };
        let result = gcrot(&matrix, &b.view(), None, options).expect("Operation failed");
        assert!(result.converged);
    }

    #[test]
    fn test_gcrot_rejects_mismatched_rhs() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![5.0, 5.0, 5.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let outcome = gcrot(&matrix, &b.view(), None, GCROTOptions::default());
        assert!(matches!(
            outcome,
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 3
            })
        ));
    }

    #[test]
    fn test_gcrot_exact_initial_guess() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![5.0, 5.0, 5.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
        let b = Array1::from_vec(vec![5.0, 10.0, 15.0]);
        let x0 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = gcrot(&matrix, &b.view(), Some(&x0.view()), GCROTOptions::default())
            .expect("Operation failed");
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        assert_eq!(result.x, x0);
        assert_eq!(result.residual_norm, 0.0);
        assert_eq!(result.residual_history, Some(vec![0.0]));
    }

    #[test]
    fn test_gcrot_zero_iterations_reports_initial_state() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix =
            CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let options = GCROTOptions {
            max_iter: 0,
            store_residual_history: false,
            ..Default::default()
        };
        let result = gcrot(&matrix, &b.view(), None, options).expect("Operation failed");
        assert!(!result.converged);
        assert_eq!(result.iterations, 0);
        assert!(result.residual_history.is_none());
        assert!(result.x.iter().all(|&xi| xi == 0.0));
        assert!((result.residual_norm - 2.0_f64.sqrt()).abs() < 1e-12);
    }
}