use std::ops::Mul;
use super::MatrixOp;
#[cfg(not(feature = "ndarray"))]
use crate::utils::is_vector;
use crate::{
IterSolverError, IterSolverResult,
ops::Vector,
utils::{axpy, dot, norm_l2, zeros},
};
/// State of a conjugate-gradient (CG) solver for the linear system `mat * x = rhs`.
///
/// The solver is driven through its [`Iterator`] implementation: each call to
/// `next` performs one CG step and yields the new residual norm.
#[derive(Debug, Clone)]
pub struct CG<'mat, Mat: MatrixOp> {
    /// The system matrix, borrowed for the solver's lifetime.
    mat: &'mat Mat,
    /// Current iterate `x_k`.
    solution: Vector<f64>,
    /// Euclidean norm of the current residual vector `r`.
    residual: f64,
    /// Number of CG steps performed so far.
    iteration: usize,
    /// Residual vector, initialised to `rhs - mat * x_0` and updated recursively each step.
    r: Vector<f64>,
    /// Scratch vector holding `mat * u` for the current step.
    c: Vector<f64>,
    /// Current search (conjugate) direction.
    u: Vector<f64>,
    /// Stopping tolerance on `residual`: `max(abstol, reltol * initial residual norm)`.
    tol: f64,
    /// Residual norm from the previous step, used to form the `beta` coefficient.
    prev_residual: f64,
}
impl<'mat, Mat: MatrixOp> CG<'mat, Mat> {
    /// Creates a CG solver for `mat * x = rhs`, starting from `x = 0`.
    ///
    /// The stopping tolerance is `max(abstol, reltol * ||rhs||_2)`.
    ///
    /// # Errors
    ///
    /// Returns [`IterSolverError::DimensionError`] if `mat` is not square, if
    /// `rhs` is not a vector (checked when the `ndarray` feature is off), or if
    /// the length of `rhs` differs from the order of `mat`.
    pub fn new(
        mat: &'mat Mat,
        rhs: &'mat Vector<f64>,
        abstol: f64,
        reltol: f64,
    ) -> IterSolverResult<Self> {
        Self::check_system(mat, rhs)?;
        let n = mat.nrows();
        // With x_0 = 0 the initial residual is simply `rhs`.
        Ok(Self::from_initial_state(
            mat,
            zeros(n),
            rhs.clone(),
            abstol,
            reltol,
        ))
    }

    /// Creates a CG solver for `mat * x = rhs`, starting from `initial_guess`.
    ///
    /// The stopping tolerance is `max(abstol, reltol * ||rhs - mat * initial_guess||_2)`.
    ///
    /// # Errors
    ///
    /// Returns [`IterSolverError::DimensionError`] if `mat` is not square, if
    /// `rhs` or `initial_guess` is not a vector (checked when the `ndarray`
    /// feature is off), or if either length differs from the order of `mat`.
    pub fn new_with_initial_guess(
        mat: &'mat Mat,
        rhs: &'mat Vector<f64>,
        initial_guess: Vector<f64>,
        abstol: f64,
        reltol: f64,
    ) -> IterSolverResult<Self>
    where
        &'mat Mat: Mul<Vector<f64>, Output = Vector<f64>>,
    {
        Self::check_system(mat, rhs)?;
        #[cfg(not(feature = "ndarray"))]
        if !is_vector(&initial_guess) {
            return Err(IterSolverError::DimensionError(format!(
                "The `initial_guess` should be a vector, but got a matrix with shape ({}, {}).",
                initial_guess.nrows(),
                initial_guess.ncols()
            )));
        }
        if initial_guess.len() != mat.nrows() {
            return Err(IterSolverError::DimensionError(format!(
                "The initial guess with length {}, and the matrix with order {}, do not match",
                initial_guess.len(),
                mat.nrows()
            )));
        }
        // The guess is moved into the solver below, so the product needs its own copy.
        let r = rhs - mat * initial_guess.clone();
        Ok(Self::from_initial_state(
            mat,
            initial_guess,
            r,
            abstol,
            reltol,
        ))
    }

    /// Validates that `mat` is square and that `rhs` is a vector whose length
    /// matches the order of `mat`.
    fn check_system(mat: &Mat, rhs: &Vector<f64>) -> IterSolverResult<()> {
        if !mat.is_square() {
            return Err(IterSolverError::DimensionError(format!(
                "The matrix is not square, whose shape is ({}, {})",
                mat.nrows(),
                mat.ncols()
            )));
        }
        // Same gate as the `is_vector` import and the initial-guess check, so every
        // backend that provides `is_vector` validates `rhs` consistently.
        #[cfg(not(feature = "ndarray"))]
        if !is_vector(rhs) {
            return Err(IterSolverError::DimensionError(format!(
                "The `rhs` should be a vector, but got a matrix with shape ({}, {}).",
                rhs.nrows(),
                rhs.ncols()
            )));
        }
        if mat.nrows() != rhs.len() {
            return Err(IterSolverError::DimensionError(format!(
                "The matrix with order {}, and the rhs with length {}, do not match",
                mat.nrows(),
                rhs.len()
            )));
        }
        Ok(())
    }

    /// Assembles the solver state from an initial iterate `solution` and its
    /// residual vector `r`, with stopping tolerance `max(abstol, reltol * ||r||_2)`.
    fn from_initial_state(
        mat: &'mat Mat,
        solution: Vector<f64>,
        r: Vector<f64>,
        abstol: f64,
        reltol: f64,
    ) -> Self {
        let n = mat.nrows();
        let residual = norm_l2(&r);
        Self {
            mat,
            solution,
            residual,
            iteration: 0,
            r,
            c: zeros(n),
            u: zeros(n),
            tol: abstol.max(reltol * residual),
            // Equal to `residual`, so the first step's beta is 1. Because `u` starts
            // at zero, the first search direction is then exactly `r`.
            prev_residual: residual,
        }
    }

    /// Whether the residual norm has reached the stopping tolerance.
    #[inline]
    fn converged(&self) -> bool {
        self.residual <= self.tol
    }

    /// Whether iteration should stop (converged or iteration cap reached).
    #[inline]
    fn done(&self) -> bool {
        (self.iteration >= self.max_iter()) || self.converged()
    }

    /// Iteration cap: the system order. In exact arithmetic, CG on an SPD
    /// matrix terminates within that many steps.
    #[inline]
    fn max_iter(&self) -> usize {
        self.solution.len()
    }

    /// Runs CG steps until convergence or the iteration cap, then returns the solver.
    pub fn solve(mut self) -> Self {
        self.by_ref().count();
        self
    }

    /// The current iterate.
    pub fn solution(&self) -> &Vector<f64> {
        &self.solution
    }

    /// Euclidean norm of the current residual vector.
    pub fn residual(&self) -> f64 {
        self.residual
    }

    /// Number of CG steps performed so far.
    pub fn iteration(&self) -> usize {
        self.iteration
    }

    /// The system matrix.
    pub fn mat(&self) -> &Mat {
        self.mat
    }

    /// The current residual vector.
    pub fn residual_vector(&self) -> &Vector<f64> {
        &self.r
    }

    /// The most recent search (conjugate) direction.
    pub fn conjugate_direction(&self) -> &Vector<f64> {
        &self.u
    }
}
impl<'mat, Mat: MatrixOp> Iterator for CG<'mat, Mat> {
    /// The residual norm after each step.
    type Item = f64;

    /// Performs one CG step and returns the new residual norm. Returns `None`
    /// once the solver has converged or reached its iteration cap.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done() {
            return None;
        }
        // ||r_k||^2 appears in both the beta and the alpha coefficients.
        let rho = self.residual.powi(2);
        let beta = rho / self.prev_residual.powi(2);
        // u <- r + beta * u. This assumes `axpy(y, a, x, b)` computes y <- a*x + b*y.
        axpy(&mut self.u, 1.0, &self.r, beta);
        // c <- mat * u. This assumes `gemv(alpha, x, beta, y)` computes y <- alpha*A*x + beta*y.
        self.mat.gemv(1.0, &self.u, 0.0, &mut self.c);
        let curvature = dot(&self.u, &self.c)
            .expect("`u` and `c` both have the matrix order as length by construction");
        let alpha = rho / curvature;
        // x <- x + alpha * u
        axpy(&mut self.solution, alpha, &self.u, 1.0);
        // r <- r - alpha * (mat * u)
        axpy(&mut self.r, -alpha, &self.c, 1.0);
        self.prev_residual = self.residual;
        self.residual = norm_l2(&self.r);
        self.iteration += 1;
        Some(self.residual)
    }
}
/// Solves `mat * x = rhs` with conjugate gradients, starting from `x = 0`.
///
/// Iterates until the residual norm drops to `max(abstol, reltol * ||rhs||_2)`
/// or the iteration cap is reached, and returns the finished solver.
///
/// # Errors
///
/// Propagates the dimension errors reported by [`CG::new`].
pub fn cg<'mat, Mat: MatrixOp>(
    mat: &'mat Mat,
    rhs: &'mat Vector<f64>,
    abstol: f64,
    reltol: f64,
) -> IterSolverResult<CG<'mat, Mat>> {
    CG::new(mat, rhs, abstol, reltol).map(CG::solve)
}
/// Solves `mat * x = rhs` with conjugate gradients, starting from `initial_guess`.
///
/// Iterates until the residual norm drops to
/// `max(abstol, reltol * ||rhs - mat * initial_guess||_2)` or the iteration cap
/// is reached, and returns the finished solver.
///
/// # Errors
///
/// Propagates the dimension errors reported by [`CG::new_with_initial_guess`].
pub fn cg_with_initial_guess<'mat, Mat: MatrixOp>(
    mat: &'mat Mat,
    rhs: &'mat Vector<f64>,
    initial_guess: Vector<f64>,
    abstol: f64,
    reltol: f64,
) -> IterSolverResult<CG<'mat, Mat>>
where
    &'mat Mat: Mul<Vector<f64>, Output = Vector<f64>>,
{
    CG::new_with_initial_guess(mat, rhs, initial_guess, abstol, reltol).map(CG::solve)
}
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use super::*;
    use crate::utils::{dense::symmetric_tridiagonal, sparse::symmetric_tridiagonal_csc};

    /// Finite-difference discretisation of `-u'' = pi^2 sin(pi x)` on (0, 1),
    /// using `n` intervals of width `h = 1/n` (so `n - 1` unknowns).
    ///
    /// Returns `(diag, off_diag, rhs, exact)`:
    /// - `diag` and `off_diag` describe the symmetric tridiagonal matrix.
    /// - `rhs` is the source term.
    /// - `exact` samples `sin(pi x)` at the interior nodes.
    #[cfg(any(feature = "nalgebra", feature = "faer", feature = "ndarray"))]
    fn poisson_1d(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let h = 1.0 / (n as f64);
        let diag = vec![2.0 / (h * h); n - 1];
        let off_diag = vec![-1.0 / (h * h); n - 2];
        let exact: Vec<f64> = (1..n).map(|i| (i as f64 * h * PI).sin()).collect();
        let rhs: Vec<f64> = exact.iter().map(|&u| PI * PI * u).collect();
        (diag, off_diag, rhs, exact)
    }

    #[test]
    #[cfg(feature = "nalgebra")]
    fn test_cg_dense() {
        let (a, b, rhs, solution) = poisson_1d(1024);
        let mat = symmetric_tridiagonal(&a, &b).unwrap();
        let solution = Vector::from_vec(solution);
        let rhs = Vector::from_vec(rhs);
        let solver = cg(&mat, &rhs, 1e-10, 1e-8).unwrap();
        let e = (solution - solver.solution()).norm();
        assert!(e < 1e-4);
    }

    #[test]
    #[cfg(feature = "faer")]
    fn test_cg_dense() {
        let (a, b, rhs, solution) = poisson_1d(1024);
        let mat = symmetric_tridiagonal(&a, &b).unwrap();
        let solution = faer::Mat::from_fn(solution.len(), 1, |i, _| solution[i]);
        let rhs = faer::Mat::from_fn(rhs.len(), 1, |i, _| rhs[i]);
        let solver = cg(&mat, &rhs, 1e-10, 1e-8).unwrap();
        let e = (solution - solver.solution()).norm_l2();
        assert!(e < 1e-4);
    }

    #[test]
    #[cfg(feature = "ndarray")]
    fn test_cg_dense() {
        use ndarray::arr1;
        use ndarray_linalg::Norm;
        let (a, b, rhs, solution) = poisson_1d(1024);
        let mat = symmetric_tridiagonal(&a, &b).unwrap();
        let solution = arr1(&solution);
        let rhs = arr1(&rhs);
        let solver = cg(&mat, &rhs, 1e-10, 1e-8).unwrap();
        let e = (solution - solver.solution()).norm_l2();
        assert!(e < 1e-4);
    }

    #[test]
    #[cfg(feature = "nalgebra")]
    fn test_cg_sparse() {
        let (a, b, rhs, solution) = poisson_1d(1024);
        let mat = symmetric_tridiagonal_csc(&a, &b).unwrap();
        let solution = Vector::from_vec(solution);
        let rhs = Vector::from_vec(rhs);
        let solver = cg(&mat, &rhs, 1e-10, 1e-8).unwrap();
        let e = (solution - solver.solution()).norm();
        assert!(e < 1e-4);
    }

    #[test]
    #[cfg(feature = "faer")]
    fn test_cg_sparse() {
        let (a, b, rhs, solution) = poisson_1d(1024);
        let mat = symmetric_tridiagonal_csc(&a, &b).unwrap();
        let solution = faer::Mat::from_fn(solution.len(), 1, |i, _| solution[i]);
        let rhs = faer::Mat::from_fn(rhs.len(), 1, |i, _| rhs[i]);
        let solver = cg(&mat, &rhs, 1e-10, 1e-8).unwrap();
        let e = (solution - solver.solution()).norm_l2();
        assert!(e < 1e-4);
    }

    #[test]
    #[cfg(feature = "ndarray")]
    fn test_cg_sparse() {
        use ndarray::arr1;
        use ndarray_linalg::Norm;
        let (a, b, rhs, solution) = poisson_1d(1024);
        let mat = symmetric_tridiagonal_csc(&a, &b).unwrap();
        let solution = arr1(&solution);
        let rhs = arr1(&rhs);
        let solver = cg(&mat, &rhs, 1e-10, 1e-8).unwrap();
        let e = (solution - solver.solution()).norm_l2();
        assert!(e < 1e-4);
    }
}