use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::{fmt::Debug, iter::Sum, sync::Arc};
use crate::error::{LinalgError, LinalgResult};
use crate::norm::vector_norm;
/// Abstraction over a linear operator `A` whose action `A x` is available
/// without materializing the matrix (matrix-free).
///
/// The symmetry / positive-definiteness flags are advisory: implementations
/// report what they *know*, and solvers use them only for warnings and
/// algorithm selection — they are never verified numerically.
pub trait MatrixFreeOp<F>
where
F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync,
{
/// Computes `A x`. Returns a shape error if `x.len() != self.ncols()`.
fn apply(&self, x: &ArrayView1<F>) -> LinalgResult<Array1<F>>;
/// Number of rows of the operator (length of `apply`'s output).
fn nrows(&self) -> usize;
/// Number of columns of the operator (expected length of `apply`'s input).
fn ncols(&self) -> usize;
/// Whether the operator is known to be symmetric. Defaults to `false`.
fn is_symmetric(&self) -> bool {
false }
/// Whether the operator is known to be symmetric positive definite.
/// Defaults to `false`.
fn is_positive_definite(&self) -> bool {
false }
}
/// Shared, thread-safe closure implementing the action `x -> A x`.
pub type LinearOperatorFn<F> = Arc<dyn Fn(&ArrayView1<F>) -> Array1<F> + Send + Sync>;
/// Matrix-free linear operator backed by a user-supplied closure.
///
/// Construct with [`LinearOperator::new`] (square) or
/// [`LinearOperator::new_rectangular`], then optionally assert structure via
/// the `symmetric()` / `positive_definite()` builder methods.
pub struct LinearOperator<F>
where
F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync,
{
// Number of rows (output length of `op`).
dim_rows: usize,
// Number of columns (expected input length of `op`).
dim_cols: usize,
// The action `x -> A x`.
op: LinearOperatorFn<F>,
// Caller-asserted flags; never verified numerically.
symmetric: bool,
positive_definite: bool,
}
impl<F> LinearOperator<F>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync,
{
    /// Creates a square operator of the given dimension from its action `x -> A x`.
    pub fn new<O>(dimension: usize, op: O) -> Self
    where
        O: Fn(&ArrayView1<F>) -> Array1<F> + Send + Sync + 'static,
    {
        LinearOperator {
            dim_rows: dimension,
            dim_cols: dimension,
            op: Arc::new(op),
            symmetric: false,
            positive_definite: false,
        }
    }

    /// Creates a rectangular `rows x cols` operator from its action `x -> A x`.
    pub fn new_rectangular<O>(rows: usize, cols: usize, op: O) -> Self
    where
        O: Fn(&ArrayView1<F>) -> Array1<F> + Send + Sync + 'static,
    {
        LinearOperator {
            dim_rows: rows,
            dim_cols: cols,
            op: Arc::new(op),
            symmetric: false,
            positive_definite: false,
        }
    }

    /// Marks the operator as symmetric (caller-asserted, not verified).
    ///
    /// # Panics
    /// Panics if the operator is not square.
    pub fn symmetric(mut self) -> Self {
        if self.dim_rows != self.dim_cols {
            panic!("Only square operators can be symmetric");
        }
        self.symmetric = true;
        self
    }

    /// Marks the operator as symmetric positive definite (caller-asserted).
    ///
    /// # Panics
    /// Panics if `symmetric()` has not been called first.
    pub fn positive_definite(mut self) -> Self {
        if !self.symmetric {
            panic!("Only symmetric operators can be positive definite");
        }
        self.positive_definite = true;
        self
    }

    /// Returns the transpose operator `A^T`.
    ///
    /// For symmetric operators this cheaply reuses the original closure.
    /// Otherwise `A^T x` is formed column by column via
    /// `(A^T x)[i] = (A e_i) · x`, which costs one application of `A` per
    /// column — expensive for large operators, but the only generic option
    /// when only the forward action is available.
    pub fn transpose(&self) -> Self
    where
        F: 'static,
    {
        let op_arc = Arc::clone(&self.op);
        let rows = self.dim_rows;
        let cols = self.dim_cols;
        if self.symmetric && rows == cols {
            // A^T == A for symmetric operators; reuse the closure as-is.
            return LinearOperator {
                dim_rows: rows,
                dim_cols: cols,
                op: op_arc,
                symmetric: true,
                positive_definite: self.positive_definite,
            };
        }
        LinearOperator {
            // The transpose swaps the dimensions.
            dim_rows: cols,
            dim_cols: rows,
            op: Arc::new(move |x: &ArrayView1<F>| {
                // BUG FIX: the previous implementation accumulated
                // `result[j] += col[j] * x[i]` into a length-`rows` vector,
                // which computes `A x` rather than `A^T x` (and indexed
                // `x[i]` for `i in 0..cols` although `x` has length `rows`,
                // out of bounds for rectangular operators). The correct
                // transpose action is `(A^T x)[i] = (A e_i) · x`, a
                // length-`cols` vector.
                let mut result = Array1::zeros(cols);
                for i in 0..cols {
                    let mut unit = Array1::zeros(cols);
                    unit[i] = F::one();
                    // `col` is the i-th column of A, length `rows`.
                    let col = (op_arc)(&unit.view());
                    let mut acc = F::zero();
                    for j in 0..rows {
                        acc += col[j] * x[j];
                    }
                    result[i] = acc;
                }
                result
            }),
            symmetric: false,
            positive_definite: false,
        }
    }
}
impl<F> MatrixFreeOp<F> for LinearOperator<F>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync,
{
    /// Validates the input length, then delegates to the wrapped closure.
    fn apply(&self, x: &ArrayView1<F>) -> LinalgResult<Array1<F>> {
        if x.len() == self.dim_cols {
            Ok((self.op)(x))
        } else {
            Err(LinalgError::ShapeError(format!(
                "Input vector has wrong length: expected {}, got {}",
                self.dim_cols,
                x.len()
            )))
        }
    }

    fn nrows(&self) -> usize {
        self.dim_rows
    }

    fn ncols(&self) -> usize {
        self.dim_cols
    }

    fn is_symmetric(&self) -> bool {
        self.symmetric
    }

    fn is_positive_definite(&self) -> bool {
        self.positive_definite
    }
}
/// Builds a diagonal operator `D` from the given diagonal entries.
///
/// The operator is always marked symmetric; it is marked positive definite
/// exactly when every diagonal entry is strictly positive.
#[allow(dead_code)]
pub fn diagonal_operator<F>(diag: &ArrayView1<F>) -> LinearOperator<F>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Clone + Debug + Send + Sync + 'static,
{
    let entries = diag.to_owned();
    let n = entries.len();
    let all_positive = diag.iter().all(|&d| d > F::zero());
    LinearOperator {
        dim_rows: n,
        dim_cols: n,
        op: Arc::new(move |x: &ArrayView1<F>| {
            // Elementwise product: y[i] = d[i] * x[i].
            let mut y = Array1::zeros(n);
            for (i, yi) in y.iter_mut().enumerate() {
                *yi = entries[i] * x[i];
            }
            y
        }),
        symmetric: true,
        positive_definite: all_positive,
    }
}
/// Builds a block-diagonal operator from a list of sub-operators.
///
/// The result is symmetric iff every block is symmetric, and positive
/// definite iff additionally every block is positive definite.
#[allow(dead_code)]
pub fn block_diagonal_operator<F>(blocks: Vec<LinearOperator<F>>) -> LinearOperator<F>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Clone + Debug + Send + Sync + 'static,
{
    // Total dimensions are the sums of the per-block dimensions.
    let total_rows: usize = blocks.iter().map(|b| b.nrows()).sum();
    let total_cols: usize = blocks.iter().map(|b| b.ncols()).sum();
    let symmetric = blocks.iter().all(|b| b.is_symmetric());
    let positive_definite = symmetric && blocks.iter().all(|b| b.is_positive_definite());
    LinearOperator {
        dim_rows: total_rows,
        dim_cols: total_cols,
        op: Arc::new(move |x: &ArrayView1<F>| {
            let mut result = Array1::zeros(total_rows);
            let mut row_start = 0;
            let mut col_start = 0;
            // Apply each block to its slice of x and scatter into result.
            for block in &blocks {
                let rows = block.nrows();
                let cols = block.ncols();
                let x_block = x.slice(s![col_start..col_start + cols]);
                let y_block = block.apply(&x_block.view()).expect("Operation failed");
                for (i, &value) in y_block.iter().enumerate() {
                    result[row_start + i] = value;
                }
                row_start += rows;
                col_start += cols;
            }
            result
        }),
        symmetric,
        positive_definite,
    }
}
/// Solves `A x = b` with the (unpreconditioned) conjugate gradient method.
///
/// `a` should be symmetric positive definite; since this cannot be verified
/// for a matrix-free operator, unflagged operators only trigger a warning.
///
/// # Errors
/// Returns `ShapeError` for a non-square operator or mismatched dimensions,
/// and `SingularMatrixError` if the iteration breaks down (`p^T A p ≈ 0`,
/// which indicates `a` is not positive definite).
#[allow(dead_code)]
pub fn conjugate_gradient<F, A>(
    a: &A,
    b: &Array1<F>,
    max_iter: usize,
    tol: F,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
    A: MatrixFreeOp<F>,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Expected square operator, got shape {}x{}",
            a.nrows(),
            a.ncols()
        )));
    }
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch: operator shape {}x{}, vector shape {}",
            a.nrows(),
            a.ncols(),
            b.len()
        )));
    }
    // Symmetry/PD cannot be checked cheaply matrix-free, so only warn.
    if !a.is_symmetric() {
        eprintln!("Warning: Operator might not be symmetric");
    }
    if !a.is_positive_definite() {
        eprintln!("Warning: Operator might not be positive definite");
    }
    let n = a.nrows();
    let mut x = Array1::zeros(n);
    let b_norm = vector_norm(&b.view(), 2)?;
    if b_norm < F::epsilon() {
        // b ≈ 0, so x = 0 is the solution.
        return Ok(x);
    }
    // Initial residual r = b - A x (x starts at zero; keep the general form).
    let ax = a.apply(&x.view())?;
    let mut r = b.clone();
    r -= &ax;
    let mut p = r.clone();
    let mut rsold = r.dot(&r);
    if rsold.sqrt() < tol * b_norm {
        return Ok(x);
    }
    for _iter in 0..max_iter {
        let ap = a.apply(&p.view())?;
        let pap = p.dot(&ap);
        // Breakdown guard: p^T A p ≈ 0 would make alpha NaN/Inf and silently
        // poison the iterate. Fail loudly instead.
        if pap.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(
                "Conjugate gradient breakdown: p^T A p is numerically zero".to_string(),
            ));
        }
        let alpha = rsold / pap;
        x = &x + &(&p * alpha);
        r = &r - &(&ap * alpha);
        let rsnew = r.dot(&r);
        // Convergence test on the residual 2-norm relative to ||b||.
        if rsnew.sqrt() < tol * b_norm {
            return Ok(x);
        }
        let beta = rsnew / rsold;
        p = &r + &(&p * beta);
        rsold = rsnew;
    }
    // max_iter exhausted: return the best iterate found.
    Ok(x)
}
/// Solves `A x = b` with restarted GMRES (GMRES(m)).
///
/// `restart` is the Krylov subspace dimension per outer cycle (defaults to
/// `n`, i.e. full GMRES); `max_iter` bounds the number of outer restarts.
/// The Hessenberg least-squares problem is solved incrementally with Givens
/// rotations, so the residual norm is available for free each inner step.
///
/// # Errors
/// Returns `ShapeError` if the operator and `b` dimensions mismatch, or
/// propagates failures from the operator / norm computations.
#[allow(dead_code)]
pub fn gmres<F, A>(
    a: &A,
    b: &Array1<F>,
    max_iter: usize,
    tol: F,
    restart: Option<usize>,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
    A: MatrixFreeOp<F>,
{
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch: operator shape {}x{}, vector shape {}",
            a.nrows(),
            a.ncols(),
            b.len()
        )));
    }
    let n = a.nrows();
    let restart_iter = restart.unwrap_or(n);
    let mut x = Array1::zeros(n);
    let b_norm = vector_norm(&b.view(), 2)?;
    if b_norm < F::epsilon() {
        // b ≈ 0, so x = 0 is the solution.
        return Ok(x);
    }
    for _outer in 0..max_iter {
        // Residual at the start of this cycle: r = b - A x.
        let ax = a.apply(&x.view())?;
        let mut r = b.clone();
        r -= &ax;
        let r_norm = vector_norm(&r.view(), 2)?;
        if r_norm < tol * b_norm {
            return Ok(x);
        }
        // First Krylov basis vector v_0 = r / ||r||.
        let beta = r_norm;
        let mut v = Array1::zeros(n);
        for i in 0..n {
            v[i] = r[i] / beta;
        }
        let mut h = Array2::zeros((restart_iter + 1, restart_iter));
        let mut v_basis = Vec::with_capacity(restart_iter + 1);
        v_basis.push(v);
        // Stored Givens rotation coefficients.
        let mut cs: Vec<F> = Vec::with_capacity(restart_iter);
        let mut sn: Vec<F> = Vec::with_capacity(restart_iter);
        // Right-hand side of the least-squares problem, rotated in step.
        let mut g = Array1::zeros(restart_iter + 1);
        g[0] = beta;
        let mut i = 0;
        while i < restart_iter {
            // Arnoldi step: expand the basis with modified Gram-Schmidt.
            let av = a.apply(&v_basis[i].view())?;
            let mut w = av;
            for j in 0..=i {
                h[[j, i]] = w.dot(&v_basis[j]);
                w = &w - &(&v_basis[j] * h[[j, i]]);
            }
            h[[i + 1, i]] = vector_norm(&w.view(), 2)?;
            // "Lucky breakdown": the Krylov space is A-invariant and contains
            // the exact solution. BUG FIX: the previous code broke out here
            // *before* applying the accumulated Givens rotations to column i
            // and before updating g, so the triangular solve below used an
            // un-rotated column and produced a wrong answer in exactly the
            // case where the exact solution had been found. We now only skip
            // extending the basis and still rotate the new column.
            let breakdown = h[[i + 1, i]] < F::epsilon();
            if !breakdown {
                let mut new_v = Array1::zeros(n);
                for j in 0..n {
                    new_v[j] = w[j] / h[[i + 1, i]];
                }
                v_basis.push(new_v);
            }
            // Apply the previously accumulated rotations to the new column.
            for j in 0..i {
                let temp = h[[j, i]];
                h[[j, i]] = cs[j] * temp + sn[j] * h[[j + 1, i]];
                h[[j + 1, i]] = -sn[j] * temp + cs[j] * h[[j + 1, i]];
            }
            // New rotation zeroing the subdiagonal entry h[i+1, i].
            let (c, s) = givens_rotation(h[[i, i]], h[[i + 1, i]]);
            cs.push(c);
            sn.push(s);
            h[[i, i]] = c * h[[i, i]] + s * h[[i + 1, i]];
            h[[i + 1, i]] = F::zero();
            let temp = g[i];
            g[i] = c * temp + s * g[i + 1];
            g[i + 1] = -s * temp + c * g[i + 1];
            i += 1;
            // |g[i]| is the residual norm of the least-squares problem.
            if breakdown || g[i].abs() < tol * b_norm {
                break;
            }
        }
        // Back-substitution: solve the i x i upper-triangular system H y = g.
        let mut y = Array1::zeros(i);
        for j in (0..i).rev() {
            let mut sum = g[j];
            for k in (j + 1)..i {
                sum -= h[[j, k]] * y[k];
            }
            y[j] = sum / h[[j, j]];
        }
        // Update the iterate: x += V y.
        for j in 0..i {
            x = &x + &(&v_basis[j] * y[j]);
        }
        // True residual check (the rotated estimate can drift numerically).
        let ax = a.apply(&x.view())?;
        let mut r = b.clone();
        r -= &ax;
        let r_norm = vector_norm(&r.view(), 2)?;
        // Early inner-loop exit (i < restart_iter) means breakdown or
        // convergence: restarting cannot improve the iterate further.
        if r_norm < tol * b_norm || i < restart_iter {
            return Ok(x);
        }
    }
    Ok(x)
}
/// Computes the Givens rotation `(c, s)` such that
/// `[c s; -s c] * [a; b] = [r; 0]`.
///
/// The ratio of the smaller over the larger entry is used so that `t*t`
/// cannot overflow, which keeps the computation numerically stable.
#[allow(dead_code)]
fn givens_rotation<F>(a: F, b: F) -> (F, F)
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync,
{
    // Nothing to rotate.
    if b == F::zero() {
        return (F::one(), F::zero());
    }
    if a.abs() < b.abs() {
        let t = a / b;
        let s = F::one() / (F::one() + t * t).sqrt();
        (s * t, s)
    } else {
        let t = b / a;
        let c = F::one() / (F::one() + t * t).sqrt();
        (c, c * t)
    }
}
/// Builds a Jacobi (diagonal) preconditioner `M = diag(A)^{-1}` for a
/// square matrix-free operator.
///
/// The diagonal of `A` is probed with one operator application per unit
/// vector, i.e. `n` applications in total.
///
/// # Errors
/// Returns `ShapeError` for a non-square operator and `SingularMatrixError`
/// if any diagonal entry is numerically zero.
#[allow(dead_code)]
pub fn jacobi_preconditioner<F, A>(a: &A) -> LinalgResult<LinearOperator<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Clone + Debug + Send + Sync + 'static,
    A: MatrixFreeOp<F>,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(
            "Jacobi preconditioner requires a square operator".to_string(),
        ));
    }
    let n = a.nrows();
    // Probe the diagonal: diag[i] = (A e_i)[i].
    let mut diag = Array1::zeros(n);
    for i in 0..n {
        let mut basis = Array1::zeros(n);
        basis[i] = F::one();
        let image = a.apply(&basis.view())?;
        diag[i] = image[i];
    }
    // Invert each entry, rejecting numerically-zero pivots.
    for entry in diag.iter_mut() {
        if entry.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(
                "Jacobi preconditioner encountered zero on diagonal".to_string(),
            ));
        }
        *entry = F::one() / *entry;
    }
    Ok(diagonal_operator(&diag.view()))
}
/// Solves `A x = b` with the preconditioned conjugate gradient method,
/// where `m` applies the preconditioner `M^{-1}` (e.g. from
/// [`jacobi_preconditioner`]).
///
/// Both `a` and `m` should be symmetric positive definite.
///
/// # Errors
/// Returns `ShapeError` for non-square / mismatched dimensions, and
/// `SingularMatrixError` if the iteration breaks down (`p^T A p ≈ 0`,
/// indicating `a` is not positive definite).
#[allow(dead_code)]
pub fn preconditioned_conjugate_gradient<F, A, M>(
    a: &A,
    m: &M,
    b: &Array1<F>,
    max_iter: usize,
    tol: F,
) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
    A: MatrixFreeOp<F>,
    M: MatrixFreeOp<F>,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Expected square operator, got shape {}x{}",
            a.nrows(),
            a.ncols()
        )));
    }
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch: operator shape {}x{}, vector shape {}",
            a.nrows(),
            a.ncols(),
            b.len()
        )));
    }
    if m.nrows() != a.nrows() || m.ncols() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Preconditioner shape {}x{} doesn't match operator shape {}x{}",
            m.nrows(),
            m.ncols(),
            a.nrows(),
            a.ncols()
        )));
    }
    let n = a.nrows();
    let mut x = Array1::zeros(n);
    let b_norm = vector_norm(&b.view(), 2)?;
    if b_norm < F::epsilon() {
        // b ≈ 0, so x = 0 is the solution.
        return Ok(x);
    }
    // Initial residual r = b - A x and preconditioned residual z = M^{-1} r.
    let ax = a.apply(&x.view())?;
    let mut r = b.clone();
    r -= &ax;
    let mut z = m.apply(&r.view())?;
    let mut p = z.clone();
    let mut rz_old = r.dot(&z);
    if vector_norm(&r.view(), 2)? < tol * b_norm {
        return Ok(x);
    }
    for _iter in 0..max_iter {
        let ap = a.apply(&p.view())?;
        let pap = p.dot(&ap);
        // Breakdown guard: p^T A p ≈ 0 would make alpha NaN/Inf and silently
        // poison the iterate. Fail loudly instead.
        if pap.abs() < F::epsilon() {
            return Err(LinalgError::SingularMatrixError(
                "Preconditioned conjugate gradient breakdown: p^T A p is numerically zero"
                    .to_string(),
            ));
        }
        let alpha = rz_old / pap;
        x = &x + &(&p * alpha);
        r = &r - &(&ap * alpha);
        // Convergence test on the true residual norm relative to ||b||.
        if vector_norm(&r.view(), 2)? < tol * b_norm {
            return Ok(x);
        }
        z = m.apply(&r.view())?;
        let rz_new = r.dot(&z);
        let beta = rz_new / rz_old;
        p = &z + &(&p * beta);
        rz_old = rz_new;
    }
    // max_iter exhausted: return the best iterate found.
    Ok(x)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Returns true when `A x ≈ b` within `tol`, relative to `max(||b||, 1)`.
    fn check_solution<F, A>(a: &A, x: &ArrayView1<F>, b: &ArrayView1<F>, tol: F) -> bool
    where
        F: Float + NumAssign + Zero + Sum + One + ScalarOperand + Debug + Send + Sync,
        A: MatrixFreeOp<F>,
    {
        let ax = a.apply(x).expect("Operation failed");
        let mut diff = Array1::zeros(x.len());
        for i in 0..x.len() {
            diff[i] = ax[i] - b[i];
        }
        let diff_norm = vector_norm(&diff.view(), 2).expect("Operation failed");
        let b_norm = vector_norm(b, 2).expect("Operation failed");
        diff_norm < tol * b_norm.max(F::one())
    }

    #[test]
    fn test_linear_operator_apply() {
        // The identity operator must return its input unchanged.
        let identity = LinearOperator::new(2, |v: &ArrayView1<f64>| v.to_owned());
        let x = array![1.0, 2.0];
        let y = identity.apply(&x.view()).expect("Operation failed");
        assert_relative_eq!(y[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_diagonal_operator() {
        // diag(2, 3) * [1, 2] = [2, 6].
        let diag = array![2.0, 3.0];
        let diag_op = diagonal_operator(&diag.view());
        let x = array![1.0, 2.0];
        let y = diag_op.apply(&x.view()).expect("Operation failed");
        assert_relative_eq!(y[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_block_diagonal_operator() {
        // blkdiag(diag(2, 3), diag(4)) * [1, 2, 3] = [2, 6, 12].
        let diag1 = array![2.0, 3.0];
        let diag_op1 = diagonal_operator(&diag1.view());
        let diag2 = array![4.0];
        let diag_op2 = diagonal_operator(&diag2.view());
        let block_op = block_diagonal_operator(vec![diag_op1, diag_op2]);
        let x = array![1.0, 2.0, 3.0];
        let y = block_op.apply(&x.view()).expect("Operation failed");
        assert_relative_eq!(y[0], 2.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 6.0, epsilon = 1e-10);
        assert_relative_eq!(y[2], 12.0, epsilon = 1e-10);
    }

    #[test]
    fn test_matrix_free_conjugate_gradient() {
        // SPD system [[4, 1], [1, 3]] x = [1, 2].
        let spd_op = LinearOperator::new(2, |v: &ArrayView1<f64>| {
            let mut result = Array1::zeros(2);
            result[0] = 4.0 * v[0] + 1.0 * v[1];
            result[1] = 1.0 * v[0] + 3.0 * v[1];
            result
        })
        .symmetric()
        .positive_definite();
        let b = array![1.0, 2.0];
        let x = conjugate_gradient(&spd_op, &b, 10, 1e-10).expect("Operation failed");
        assert!(check_solution(&spd_op, &x.view(), &b.view(), 1e-8));
    }

    #[test]
    fn test_matrix_free_gmres() {
        // General (here symmetric) system [[3, 1], [1, 2]] x = [4, 3].
        let op = LinearOperator::new_rectangular(2, 2, |v: &ArrayView1<f64>| {
            let mut result = Array1::zeros(2);
            result[0] = 3.0 * v[0] + 1.0 * v[1];
            result[1] = 1.0 * v[0] + 2.0 * v[1];
            result
        });
        let b = array![4.0, 3.0];
        let x = gmres(&op, &b, 10, 1e-10, None).expect("Operation failed");
        assert!(check_solution(&op, &x.view(), &b.view(), 1e-8));
    }

    #[test]
    fn test_jacobi_preconditioner() {
        // diag(A) = [4, 3], so M x = [x0 / 4, x1 / 3].
        let a_mat = array![[4.0, 1.0], [1.0, 3.0]];
        let op = LinearOperator::new(2, move |v: &ArrayView1<f64>| {
            let mut result = Array1::zeros(2);
            for i in 0..2 {
                for j in 0..2 {
                    result[i] += a_mat[[i, j]] * v[j];
                }
            }
            result
        });
        let precond = jacobi_preconditioner(&op).expect("Operation failed");
        let x = array![1.0, 2.0];
        let y = precond.apply(&x.view()).expect("Operation failed");
        assert_relative_eq!(y[0], 0.25, epsilon = 1e-10);
        assert_relative_eq!(y[1], 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_preconditioned_conjugate_gradient() {
        // Same SPD system as the CG test, with a Jacobi-style preconditioner.
        let spd_op = LinearOperator::new(2, |v: &ArrayView1<f64>| {
            let mut result = Array1::zeros(2);
            result[0] = 4.0 * v[0] + 1.0 * v[1];
            result[1] = 1.0 * v[0] + 3.0 * v[1];
            result
        })
        .symmetric()
        .positive_definite();
        let diag = array![1.0 / 4.0, 1.0 / 3.0];
        let precond = diagonal_operator(&diag.view());
        let b = array![1.0, 2.0];
        let x = preconditioned_conjugate_gradient(&spd_op, &precond, &b, 10, 1e-10)
            .expect("Operation failed");
        assert!(check_solution(&spd_op, &x.view(), &b.view(), 1e-8));
    }
}