use crate::error::{LinalgError, LinalgResult};
use crate::matrixfree::MatrixFreeOp;
use crate::norm::vector_norm;
use crate::quantization::quantized_matrixfree::QuantizedMatrixFreeOp;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{AsPrimitive, Float, FromPrimitive, NumAssign, One, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// Solves the symmetric linear system `A x = b` with the conjugate gradient
/// method, where `A` is a quantized matrix-free operator.
///
/// # Arguments
/// * `a` - Square operator; must be flagged symmetric (`a.is_symmetric()`).
/// * `b` - Right-hand-side vector of length `a.nrows()`.
/// * `max_iter` - Maximum number of CG iterations.
/// * `tol` - Relative tolerance; converged when `||r|| < tol * ||b||`.
/// * `adaptive_precision` - When true, the true residual `b - A x` is
///   recomputed after five consecutive slow-progress iterations to discard
///   drift accumulated in the residual recurrence (e.g. from quantization).
///
/// # Errors
/// * `ShapeError` if the operator is not square or does not match `b`.
/// * `ValueError` if the operator is not symmetric.
/// * `ComputationError` if zero curvature is detected on the first iteration.
#[allow(dead_code)]
pub fn quantized_conjugate_gradient<F>(
    a: &QuantizedMatrixFreeOp<F>,
    b: &Array1<F>,
    max_iter: usize,
    tol: F,
    adaptive_precision: bool,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Expected square operator, got shape {}x{}",
            a.nrows(),
            a.ncols()
        )));
    }
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch: operator shape {}x{}, vector shape {}",
            a.nrows(),
            a.ncols(),
            b.len()
        )));
    }
    if !a.is_symmetric() {
        return Err(LinalgError::ValueError(
            "Quantized conjugate gradient requires a symmetric operator".to_string(),
        ));
    }
    let n = a.nrows();
    let mut x = Array1::zeros(n);
    let b_norm = vector_norm(&b.view(), 2)?;
    // Zero right-hand side: x = 0 is the exact solution.
    if b_norm < F::epsilon() {
        return Ok(x);
    }
    // Initial residual r = b - A x. x starts at zero, but A is applied anyway
    // so the residual reflects the operator's quantized arithmetic.
    let ax = a.apply(&x.view())?;
    let mut r = b.clone();
    r -= &ax;
    let mut p = r.clone();
    let mut rsold = r.dot(&r);
    if rsold.sqrt() < tol * b_norm {
        return Ok(x);
    }
    let mut successive_slow_progress = 0;
    let mut previous_residual = rsold;
    for iteration in 0..max_iter {
        let ap = a.apply(&p.view())?;
        let pap = p.dot(&ap);
        // Zero curvature along the search direction: the operator is
        // numerically singular or indefinite in that direction.
        if pap.abs() < F::epsilon() {
            if iteration == 0 {
                return Err(LinalgError::ComputationError(
                    "Zero curvature detected in first iteration".to_string(),
                ));
            }
            break;
        }
        let alpha = rsold / pap;
        x = &x + &(&p * alpha);
        r = &r - &(&ap * alpha);
        let mut rsnew = r.dot(&r);
        if rsnew.sqrt() < tol * b_norm {
            break;
        }
        if adaptive_precision {
            // A squared-residual ratio close to 1 indicates stagnation,
            // which the recurrence can suffer when quantization error builds up.
            let ratio = rsnew / previous_residual;
            if ratio > F::from(0.9).expect("Operation failed") {
                successive_slow_progress += 1;
            } else {
                successive_slow_progress = 0;
            }
            if successive_slow_progress >= 5 {
                // Refresh the true residual r = b - A x to discard drift.
                let ax = a.apply(&x.view())?;
                r = b.clone();
                r -= &ax;
                successive_slow_progress = 0;
                rsnew = r.dot(&r);
                if rsnew.sqrt() < tol * b_norm {
                    break;
                }
            }
            previous_residual = rsnew;
        }
        let beta = rsnew / rsold;
        p = &r + &(&p * beta);
        rsold = rsnew;
    }
    Ok(x)
}
/// Restarted GMRES for `A x = b` on a quantized matrix-free operator.
///
/// * `restart` - Krylov subspace size per restart cycle; defaults to `min(n, 30)`.
/// * `adaptive_precision` - when true, periodically re-orthogonalizes the
///   Arnoldi basis (every `reorth_step` columns) and adapts that cadence to
///   the observed convergence rate.
///
/// # Errors
/// Returns `ShapeError` when the operator and `b` dimensions disagree.
#[allow(dead_code)]
pub fn quantized_gmres<F>(
    a: &QuantizedMatrixFreeOp<F>,
    b: &Array1<F>,
    max_iter: usize,
    tol: F,
    restart: Option<usize>,
    adaptive_precision: bool,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
{
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch: operator shape {}x{}, vector shape {}",
            a.nrows(),
            a.ncols(),
            b.len()
        )));
    }
    let n = a.nrows();
    let restart_iter = restart.unwrap_or(n.min(30));
    let mut x = Array1::zeros(n);
    let b_norm = vector_norm(&b.view(), 2)?;
    // Zero right-hand side: x = 0 is the exact solution.
    if b_norm < F::epsilon() {
        return Ok(x);
    }
    // With adaptive precision, start by re-orthogonalizing every column;
    // the step is widened/narrowed below based on convergence progress.
    let mut reorth_step = if adaptive_precision { 1 } else { restart_iter };
    for _outer in 0..max_iter {
        // Residual for this restart cycle: r = b - A x.
        let ax = a.apply(&x.view())?;
        let mut r = b.clone();
        r -= &ax;
        let r_norm = vector_norm(&r.view(), 2)?;
        if r_norm < tol * b_norm {
            return Ok(x);
        }
        let beta = r_norm;
        // First Arnoldi basis vector: v = r / ||r||.
        let mut v = Array1::zeros(n);
        for i in 0..n {
            v[i] = r[i] / beta;
        }
        // h is the (restart_iter+1) x restart_iter upper-Hessenberg matrix,
        // triangularized in place column-by-column via Givens rotations.
        let mut h = Array2::zeros((restart_iter + 1, restart_iter));
        let mut v_basis = Vec::with_capacity(restart_iter + 1);
        v_basis.push(v);
        // cs/sn hold the (cosine, sine) of each applied Givens rotation.
        let mut cs: Vec<F> = Vec::with_capacity(restart_iter);
        let mut sn: Vec<F> = Vec::with_capacity(restart_iter);
        // g is the rotated right-hand side of the least-squares problem;
        // |g[i+1]| tracks the current residual norm.
        let mut g = Array1::zeros(restart_iter + 1);
        g[0] = beta;
        let mut i = 0;
        while i < restart_iter {
            // Arnoldi step: expand the Krylov basis with A v_i.
            let av = a.apply(&v_basis[i].view())?;
            let mut w = av;
            let reorth_needed = adaptive_precision && (i % reorth_step == 0);
            // Modified Gram-Schmidt against all previous basis vectors.
            for j in 0..=i {
                h[[j, i]] = w.dot(&v_basis[j]);
                w = &w - &(&v_basis[j] * h[[j, i]]);
                if reorth_needed {
                    // Second projection pass to absorb the orthogonality loss
                    // left by the first one (rounding/quantization error).
                    let h_correction = w.dot(&v_basis[j]);
                    h[[j, i]] += h_correction;
                    w = &w - &(&v_basis[j] * h_correction);
                }
            }
            h[[i + 1, i]] = vector_norm(&w.view(), 2)?;
            // Happy breakdown: the Krylov subspace became invariant.
            // NOTE(review): this exits before the previous Givens rotations
            // are applied to column i, so the back-substitution below uses an
            // untriangularized final column in this (rare) case — confirm.
            if h[[i + 1, i]] < F::epsilon() {
                i += 1;
                break;
            }
            // Normalize w into the next basis vector.
            let mut new_v = Array1::zeros(n);
            for j in 0..n {
                new_v[j] = w[j] / h[[i + 1, i]];
            }
            v_basis.push(new_v);
            // Apply all previously computed rotations to the new column.
            for j in 0..i {
                let temp = h[[j, i]];
                h[[j, i]] = cs[j] * temp + sn[j] * h[[j + 1, i]];
                h[[j + 1, i]] = -sn[j] * temp + cs[j] * h[[j + 1, i]];
            }
            // New rotation that annihilates the subdiagonal entry h[i+1, i].
            let (c, s) = givens_rotation(h[[i, i]], h[[i + 1, i]]);
            cs.push(c);
            sn.push(s);
            h[[i, i]] = c * h[[i, i]] + s * h[[i + 1, i]];
            h[[i + 1, i]] = F::zero();
            // Rotate the right-hand side; |g[i+1]| is the new residual norm.
            let temp = g[i];
            g[i] = c * temp + s * g[i + 1];
            g[i + 1] = -s * temp + c * g[i + 1];
            let residual = g[i + 1].abs();
            if residual < tol * b_norm {
                i += 1;
                break;
            }
            if adaptive_precision && i > 2 {
                // Adapt the re-orthogonalization cadence: slow progress ->
                // re-orthogonalize more often; fast progress -> less often.
                let progress_ratio = residual / g[i].abs();
                if progress_ratio > F::from(0.8).expect("Operation failed") && reorth_step > 1 {
                    reorth_step = reorth_step.max(1) / 2;
                } else if progress_ratio < F::from(0.5).expect("Operation failed")
                    && reorth_step < restart_iter
                {
                    reorth_step = (reorth_step * 2).min(restart_iter);
                }
            }
            i += 1;
        }
        // Solve the i x i upper-triangular system h y = g by back substitution.
        let mut y = Array1::zeros(i);
        for j in (0..i).rev() {
            let mut sum = g[j];
            for k in (j + 1)..i {
                sum -= h[[j, k]] * y[k];
            }
            y[j] = sum / h[[j, j]];
        }
        // Update the solution: x += V y.
        for j in 0..i {
            x = &x + &(&v_basis[j] * y[j]);
        }
        // Check the true residual before deciding on another restart cycle.
        // An incomplete inner loop (i < restart_iter) means we broke out early,
        // so a further restart cannot improve the result.
        let ax = a.apply(&x.view())?;
        let mut r = b.clone();
        r -= &ax;
        let r_norm = vector_norm(&r.view(), 2)?;
        if r_norm < tol * b_norm || i < restart_iter {
            return Ok(x);
        }
        if adaptive_precision {
            // NOTE(review): intentionally empty? No adaptive action is taken
            // between restart cycles — looks like a leftover stub.
        }
    }
    Ok(x)
}
/// Computes the Givens rotation `(c, s)` that zeroes the second component of
/// the vector `(a, b)`, i.e. `[-s a + c b] == 0` after rotation.
///
/// Returns `(1, 0)` when `b` is already zero. Otherwise the ratio is formed
/// with the larger-magnitude entry in the denominator for numerical stability.
#[allow(dead_code)]
fn givens_rotation<F>(a: F, b: F) -> (F, F)
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync,
{
    if b == F::zero() {
        // Nothing to eliminate: identity rotation.
        return (F::one(), F::zero());
    }
    if a.abs() < b.abs() {
        // |b| dominates: divide by b to keep the ratio <= 1 in magnitude.
        let ratio = a / b;
        let sine = (F::one() + ratio * ratio).sqrt().recip();
        (sine * ratio, sine)
    } else {
        // |a| dominates: divide by a instead.
        let ratio = b / a;
        let cosine = (F::one() + ratio * ratio).sqrt().recip();
        (cosine, cosine * ratio)
    }
}
/// Builds a Jacobi (diagonal) preconditioner `M ≈ diag(A)^-1` for a square
/// quantized matrix-free operator.
///
/// The diagonal of `A` is recovered by probing with unit basis vectors
/// (`n` operator applications), then inverted entry-wise. The returned
/// operator multiplies its input by the inverted diagonal.
///
/// # Errors
/// * `ShapeError` if `a` is not square.
/// * `SingularMatrixError` if any diagonal entry is numerically zero.
#[allow(dead_code)]
pub fn quantized_jacobi_preconditioner<F>(
    a: &QuantizedMatrixFreeOp<F>,
) -> LinalgResult<QuantizedMatrixFreeOp<F>>
where
    F: Float
        + NumAssign
        + Zero
        + Sum
        + One
        + ScalarOperand
        + Clone
        + Debug
        + Send
        + Sync
        + FromPrimitive
        + AsPrimitive<f32>
        + 'static,
    f32: AsPrimitive<F>,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(
            "Jacobi preconditioner requires a square operator".to_string(),
        ));
    }
    let dim = a.nrows();
    // Probe the operator with unit vectors e_i to read off diag(A).
    let mut diagonal = Array1::zeros(dim);
    for idx in 0..dim {
        let mut basis = Array1::zeros(dim);
        basis[idx] = F::one();
        let column = a.apply(&basis.view())?;
        diagonal[idx] = column[idx];
    }
    // Invert each entry in place, rejecting numerically zero pivots.
    for idx in 0..dim {
        if diagonal[idx].abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(
                "Jacobi preconditioner encountered zero on diagonal".to_string(),
            ));
        }
        diagonal[idx] = F::one() / diagonal[idx];
    }
    // Move a copy of the inverted diagonal into the closure.
    let inverse_diagonal = diagonal.clone();
    QuantizedMatrixFreeOp::new(
        dim,
        dim,
        a.params().bits,
        a.params().method,
        move |input: &ArrayView1<F>| -> LinalgResult<Array1<F>> {
            if input.len() != dim {
                return Err(LinalgError::ShapeError(format!(
                    "Expected vector of length {}, got {}",
                    dim,
                    input.len()
                )));
            }
            // Element-wise scaling by diag(A)^-1.
            let mut scaled = Array1::zeros(dim);
            for idx in 0..dim {
                scaled[idx] = inverse_diagonal[idx] * input[idx];
            }
            Ok(scaled)
        },
    )
}
/// Preconditioned conjugate gradient for `A x = b` with a quantized
/// matrix-free operator `a` and preconditioner `m` (applied as `z = M r`).
///
/// # Arguments
/// * `a` - Square operator matching `b`.
/// * `m` - Preconditioner with the same shape as `a`.
/// * `b` - Right-hand-side vector.
/// * `max_iter` - Maximum number of iterations.
/// * `tol` - Relative tolerance; converged when `||r|| < tol * ||b||`.
/// * `adaptive_precision` - When true, recomputes the true residual
///   `b - A x` after five consecutive slow-progress iterations.
///
/// # Errors
/// * `ShapeError` on any operator/preconditioner/vector shape mismatch.
/// * `ComputationError` if zero curvature is detected on the first iteration.
#[allow(dead_code)]
pub fn quantized_preconditioned_conjugate_gradient<F>(
    a: &QuantizedMatrixFreeOp<F>,
    m: &QuantizedMatrixFreeOp<F>,
    b: &Array1<F>,
    max_iter: usize,
    tol: F,
    adaptive_precision: bool,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Expected square operator, got shape {}x{}",
            a.nrows(),
            a.ncols()
        )));
    }
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch: operator shape {}x{}, vector shape {}",
            a.nrows(),
            a.ncols(),
            b.len()
        )));
    }
    if m.nrows() != a.nrows() || m.ncols() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Preconditioner shape {}x{} doesn't match operator shape {}x{}",
            m.nrows(),
            m.ncols(),
            a.nrows(),
            a.ncols()
        )));
    }
    let dim = a.nrows();
    let mut solution = Array1::zeros(dim);
    let rhs_norm = vector_norm(&b.view(), 2)?;
    // Zero right-hand side: x = 0 is the exact solution.
    if rhs_norm < F::epsilon() {
        return Ok(solution);
    }
    // Initial residual r = b - A x (x starts at zero).
    let initial_ax = a.apply(&solution.view())?;
    let mut residual = b.clone();
    residual -= &initial_ax;
    // Preconditioned residual z = M r and first search direction p = z.
    let mut precond_res = m.apply(&residual.view())?;
    let mut direction = precond_res.clone();
    let mut rz_current = residual.dot(&precond_res);
    if vector_norm(&residual.view(), 2)? < tol * rhs_norm {
        return Ok(solution);
    }
    let mut slow_steps = 0;
    let mut last_res_sq = residual.dot(&residual);
    for step in 0..max_iter {
        let a_dir = a.apply(&direction.view())?;
        let curvature = direction.dot(&a_dir);
        // Zero curvature along the search direction: singular/indefinite.
        if curvature.abs() < F::epsilon() {
            if step == 0 {
                return Err(LinalgError::ComputationError(
                    "Zero curvature detected in first iteration".to_string(),
                ));
            }
            break;
        }
        let alpha = rz_current / curvature;
        solution = &solution + &(&direction * alpha);
        residual = &residual - &(&a_dir * alpha);
        if vector_norm(&residual.view(), 2)? < tol * rhs_norm {
            break;
        }
        if adaptive_precision {
            // A squared-residual ratio near 1 signals stagnation, typically
            // from error accumulated in the residual recurrence.
            let res_sq = residual.dot(&residual);
            if res_sq / last_res_sq > F::from(0.9).expect("Operation failed") {
                slow_steps += 1;
            } else {
                slow_steps = 0;
            }
            if slow_steps >= 5 {
                // Refresh the true residual r = b - A x to discard drift.
                let refreshed_ax = a.apply(&solution.view())?;
                residual = b.clone();
                residual -= &refreshed_ax;
                slow_steps = 0;
                if vector_norm(&residual.view(), 2)? < tol * rhs_norm {
                    break;
                }
            }
            last_res_sq = res_sq;
        }
        // Re-precondition and update the search direction.
        precond_res = m.apply(&residual.view())?;
        let rz_next = residual.dot(&precond_res);
        let beta = rz_next / rz_current;
        direction = &precond_res + &(&direction * beta);
        rz_current = rz_next;
    }
    Ok(solution)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::quantized_matrixfree::QuantizedMatrixFreeOp;
    use crate::quantization::QuantizationMethod;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    // CG on a small SPD system; the wide epsilon absorbs 8-bit quantization error.
    #[test]
    fn test_quantized_conjugate_gradient_smallmatrix() {
        let matrix = array![[4.0f32, 1.0], [1.0, 3.0]];
        let op =
            QuantizedMatrixFreeOp::frommatrix(&matrix.view(), 8, QuantizationMethod::Symmetric)
                .expect("Operation failed")
                .symmetric()
                .positive_definite();
        let b = array![1.0f32, 2.0];
        let x = quantized_conjugate_gradient(&op, &b, 10, 1e-6, false).expect("Operation failed");
        let expected = array![0.181818f32, 0.636364];
        assert_eq!(x.len(), expected.len());
        for (&got, &want) in x.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 0.15);
        }
    }

    // GMRES on a small nonsingular system; exact solution is (1, 1).
    #[test]
    fn test_quantized_gmres_smallmatrix() {
        let matrix = array![[3.0f32, 1.0], [1.0, 2.0]];
        let op =
            QuantizedMatrixFreeOp::frommatrix(&matrix.view(), 8, QuantizationMethod::Symmetric)
                .expect("Operation failed");
        let b = array![4.0f32, 3.0];
        let x = quantized_gmres(&op, &b, 10, 1e-6, None, false).expect("Operation failed");
        let expected = array![1.0f32, 1.0];
        assert_eq!(x.len(), expected.len());
        for (&got, &want) in x.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 0.15);
        }
    }

    // PCG with a Jacobi preconditioner must match the plain-CG solution.
    #[test]
    fn test_quantized_preconditioned_conjugate_gradient() {
        let matrix = array![[4.0f32, 1.0], [1.0, 3.0]];
        let op =
            QuantizedMatrixFreeOp::frommatrix(&matrix.view(), 8, QuantizationMethod::Symmetric)
                .expect("Operation failed")
                .symmetric()
                .positive_definite();
        let precond = quantized_jacobi_preconditioner(&op).expect("Operation failed");
        let b = array![1.0f32, 2.0];
        let x = quantized_preconditioned_conjugate_gradient(&op, &precond, &b, 10, 1e-6, false)
            .expect("Operation failed");
        let expected = array![0.181818f32, 0.636364];
        assert_eq!(x.len(), expected.len());
        for (&got, &want) in x.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 0.15);
        }
    }

    // The preconditioner must scale inputs by the inverse diagonal (1/4, 1/3).
    #[test]
    fn test_quantized_jacobi_preconditioner() {
        let matrix = array![[4.0f32, 1.0], [1.0, 3.0]];
        let op =
            QuantizedMatrixFreeOp::frommatrix(&matrix.view(), 8, QuantizationMethod::Symmetric)
                .expect("Operation failed");
        let precond = quantized_jacobi_preconditioner(&op).expect("Operation failed");
        let x = array![1.0f32, 2.0];
        let y = precond.apply(&x.view()).expect("Operation failed");
        let expected = array![0.25f32, 2.0 / 3.0];
        assert_eq!(y.len(), expected.len());
        for (&got, &want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 0.01);
        }
    }

    // Adaptive-precision CG must agree with the standard variant on an
    // easy SPD system.
    #[test]
    fn test_adaptive_precision_conjugate_gradient() {
        let matrix = array![[4.0f32, 1.0], [1.0, 3.0]];
        let op =
            QuantizedMatrixFreeOp::frommatrix(&matrix.view(), 8, QuantizationMethod::Symmetric)
                .expect("Operation failed")
                .symmetric()
                .positive_definite();
        let b = array![1.0f32, 2.0];
        let x_adaptive =
            quantized_conjugate_gradient(&op, &b, 10, 1e-6, true).expect("Operation failed");
        let x_standard =
            quantized_conjugate_gradient(&op, &b, 10, 1e-6, false).expect("Operation failed");
        let expected = array![0.181818f32, 0.636364];
        assert_eq!(x_adaptive.len(), expected.len());
        for (&got, &want) in x_adaptive.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 0.15);
        }
        assert_eq!(x_standard.len(), expected.len());
        for (&got, &want) in x_standard.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 0.15);
        }
    }
}