use super::{ComplexCooMatrix, CooMatrix};
use crate::StrError;
use russell_lab::{complex_vec_norm, complex_vec_update, cpx, Complex64, ComplexVector};
use russell_lab::{find_index_abs_max, vec_norm, vec_update, Norm, Vector};
use serde::{Deserialize, Serialize};
/// Summary of how well a candidate solution `x` satisfies the linear system `a ⋅ x = rhs`
///
/// All quantities are non-negative "max" (infinity-style) measures.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct VerifyLinSys {
    /// Largest absolute value (modulus, for complex matrices) among the stored entries of `a`
    pub max_abs_a: f64,

    /// Largest absolute value (modulus) among the components of the product `a ⋅ x`
    pub max_abs_ax: f64,

    /// Largest absolute value (modulus) among the components of the residual `a ⋅ x - rhs`
    pub max_abs_diff: f64,

    /// Scaled residual, computed as `max_abs_diff / (max_abs_a + 1)`
    pub relative_error: f64,
}
impl VerifyLinSys {
    /// Verifies a real linear system `a ⋅ x = rhs` by evaluating the residual `a ⋅ x - rhs`
    ///
    /// # Input
    ///
    /// * `mat` -- the coefficient matrix `a` (nrow × ncol)
    /// * `x` -- the (approximate) solution; must have `ncol` components
    /// * `rhs` -- the right-hand side; must have `nrow` components
    ///
    /// # Errors
    ///
    /// * `"x.dim() must be equal to ncol"` if `x` has the wrong length
    /// * `"rhs.dim() must be equal to nrow"` if `rhs` has the wrong length
    /// * `"matrix is empty"` if `mat` has no stored values
    /// * any error reported by the matrix-vector product or the vector update
    ///
    /// NOTE(review): `max_abs_a` is taken over the values returned by `get_values()`;
    /// if the COO storage keeps duplicate (i, j) entries separately, they are not summed
    /// before taking the maximum -- confirm whether that matters for callers.
    pub fn from(mat: &CooMatrix, x: &Vector, rhs: &Vector) -> Result<Self, StrError> {
        let (nrow, ncol, _, _) = mat.get_info();
        let values = mat.get_values();
        Self::check_input(nrow, ncol, x.dim(), rhs.dim(), values.len())?;

        // largest stored entry in absolute value (values is non-empty, checked above)
        let idx = find_index_abs_max(values);
        let max_abs_a = f64::abs(values[idx as usize]);

        // ax := a ⋅ x
        let mut ax = Vector::new(nrow);
        mat.mat_vec_mul(&mut ax, 1.0, x)?;
        let max_abs_ax = vec_norm(&ax, Norm::Max);

        // ax := a ⋅ x - rhs (the same buffer is reused to hold the residual)
        vec_update(&mut ax, -1.0, rhs)?;
        let max_abs_diff = vec_norm(&ax, Norm::Max);

        Ok(Self::from_norms(max_abs_a, max_abs_ax, max_abs_diff))
    }

    /// Verifies a complex linear system `a ⋅ x = rhs` by evaluating the residual `a ⋅ x - rhs`
    ///
    /// Magnitudes are measured with the complex modulus.
    ///
    /// # Input
    ///
    /// * `mat` -- the coefficient matrix `a` (nrow × ncol)
    /// * `x` -- the (approximate) solution; must have `ncol` components
    /// * `rhs` -- the right-hand side; must have `nrow` components
    ///
    /// # Errors
    ///
    /// * `"x.dim() must be equal to ncol"` if `x` has the wrong length
    /// * `"rhs.dim() must be equal to nrow"` if `rhs` has the wrong length
    /// * `"matrix is empty"` if `mat` has no stored values
    /// * any error reported by the matrix-vector product or the vector update
    pub fn from_complex(mat: &ComplexCooMatrix, x: &ComplexVector, rhs: &ComplexVector) -> Result<Self, StrError> {
        let (nrow, ncol, _, _) = mat.get_info();
        let values = mat.get_values();
        Self::check_input(nrow, ncol, x.dim(), rhs.dim(), values.len())?;

        // largest modulus among the stored entries; starting at 0.0 and using f64::max
        // means a NaN modulus is ignored rather than propagated
        let max_abs_a = values.iter().map(|v| v.norm()).fold(0.0, f64::max);

        // ax := a ⋅ x
        let mut ax = ComplexVector::new(nrow);
        mat.mat_vec_mul(&mut ax, cpx!(1.0, 0.0), x)?;
        let max_abs_ax = complex_vec_norm(&ax, Norm::Max);

        // ax := a ⋅ x - rhs (the same buffer is reused to hold the residual)
        complex_vec_update(&mut ax, cpx!(-1.0, 0.0), rhs)?;
        let max_abs_diff = complex_vec_norm(&ax, Norm::Max);

        Ok(Self::from_norms(max_abs_a, max_abs_ax, max_abs_diff))
    }

    /// Validates vector lengths against the matrix dimensions and rejects a matrix with no values
    ///
    /// The checks run in this order: `x`, then `rhs`, then emptiness.
    fn check_input(nrow: usize, ncol: usize, x_dim: usize, rhs_dim: usize, nnz: usize) -> Result<(), StrError> {
        if x_dim != ncol {
            return Err("x.dim() must be equal to ncol");
        }
        if rhs_dim != nrow {
            return Err("rhs.dim() must be equal to nrow");
        }
        if nnz == 0 {
            return Err("matrix is empty");
        }
        Ok(())
    }

    /// Builds the result and derives `relative_error` from the computed norms
    fn from_norms(max_abs_a: f64, max_abs_ax: f64, max_abs_diff: f64) -> Self {
        // max_abs_a is non-negative, so the denominator is always >= 1 and never zero
        let relative_error = max_abs_diff / (max_abs_a + 1.0);
        VerifyLinSys {
            max_abs_a,
            max_abs_ax,
            max_abs_diff,
            relative_error,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::VerifyLinSys;
    use crate::{ComplexCooMatrix, CooMatrix, Samples, Sym};
    use russell_lab::{approx_eq, cpx, Complex64, ComplexVector, Vector};

    #[test]
    fn from_captures_errors() {
        // real version: empty matrix with consistent dimensions
        let coo = CooMatrix::new(2, 1, 1, Sym::No).unwrap();
        let x = Vector::new(1);
        let rhs = Vector::new(2);
        assert_eq!(VerifyLinSys::from(&coo, &x, &rhs).err(), Some("matrix is empty"));
        // dimension errors are reported before the emptiness check
        let x_wrong = Vector::new(2);
        let rhs_wrong = Vector::new(1);
        assert_eq!(
            VerifyLinSys::from(&coo, &x_wrong, &rhs).err(),
            Some("x.dim() must be equal to ncol")
        );
        assert_eq!(
            VerifyLinSys::from(&coo, &x, &rhs_wrong).err(),
            Some("rhs.dim() must be equal to nrow")
        );

        // complex version: same checks
        let coo = ComplexCooMatrix::new(2, 1, 1, Sym::No).unwrap();
        let x = ComplexVector::new(1);
        let rhs = ComplexVector::new(2);
        assert_eq!(
            VerifyLinSys::from_complex(&coo, &x, &rhs).err(),
            Some("matrix is empty")
        );
        let x_wrong = ComplexVector::new(2);
        let rhs_wrong = ComplexVector::new(1);
        assert_eq!(
            VerifyLinSys::from_complex(&coo, &x_wrong, &rhs).err(),
            Some("x.dim() must be equal to ncol")
        );
        assert_eq!(
            VerifyLinSys::from_complex(&coo, &x, &rhs_wrong).err(),
            Some("rhs.dim() must be equal to nrow")
        );
    }

    #[test]
    fn from_works() {
        let mut coo = CooMatrix::new(3, 3, 9, Sym::No).unwrap();
        coo.put(0, 0, 1.0).unwrap();
        coo.put(0, 1, 3.0).unwrap();
        coo.put(0, 2, -2.0).unwrap();
        coo.put(1, 0, 3.0).unwrap();
        coo.put(1, 1, 5.0).unwrap();
        coo.put(1, 2, 6.0).unwrap();
        coo.put(2, 0, 2.0).unwrap();
        coo.put(2, 1, 4.0).unwrap();
        coo.put(2, 2, 3.0).unwrap();
        // x is the exact solution, so the residual vanishes
        let x = Vector::from(&[-15.0, 8.0, 2.0]);
        let rhs = Vector::from(&[5.0, 7.0, 8.0]);
        let verify = VerifyLinSys::from(&coo, &x, &rhs).unwrap();
        assert_eq!(verify.max_abs_a, 6.0);
        assert_eq!(verify.max_abs_ax, 8.0);
        assert_eq!(verify.max_abs_diff, 0.0);
        assert_eq!(verify.relative_error, 0.0);
    }

    #[test]
    fn from_uses_absolute_values_of_entries() {
        // the largest entry in absolute value is negative
        let mut coo = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        coo.put(0, 0, -7.0).unwrap();
        coo.put(1, 1, 2.0).unwrap();
        let x = Vector::from(&[1.0, 1.0]);
        let rhs = Vector::from(&[-7.0, 2.0]);
        let verify = VerifyLinSys::from(&coo, &x, &rhs).unwrap();
        assert_eq!(verify.max_abs_a, 7.0);
        assert_eq!(verify.max_abs_ax, 7.0);
        assert_eq!(verify.max_abs_diff, 0.0);
        assert_eq!(verify.relative_error, 0.0);
    }

    #[test]
    fn from_rectangular_matrix_works() {
        let (coo, _, _, _) = Samples::rectangular_3x4();
        let x = Vector::from(&[1.0, 3.0, 8.0, 5.0]);
        let rhs = Vector::from(&[0.0, 0.0, 0.0]);
        let a_times_x = &[4.0, 8.0, 12.0];
        // with a zero right-hand side, the residual equals a ⋅ x
        let verify = VerifyLinSys::from(&coo, &x, &rhs).unwrap();
        assert_eq!(verify.max_abs_a, 15.0);
        assert_eq!(verify.max_abs_ax, 12.0);
        assert_eq!(verify.max_abs_diff, 12.0);
        approx_eq(verify.relative_error, 12.0 / (15.0 + 1.0), 1e-15);
        // with rhs = a ⋅ x, the residual vanishes
        let verify = VerifyLinSys::from(&coo, &x, &Vector::from(a_times_x)).unwrap();
        assert_eq!(verify.max_abs_a, 15.0);
        assert_eq!(verify.max_abs_ax, 12.0);
        assert_eq!(verify.max_abs_diff, 0.0);
        approx_eq(verify.relative_error, 0.0, 1e-15);
    }

    #[test]
    fn from_complex_works() {
        let (coo, _, _, _) = Samples::complex_rectangular_4x3();
        let x = ComplexVector::from(&[cpx!(1.0, 2.0), cpx!(2.0, -1.0), cpx!(0.0, 1.0)]);
        // exact right-hand side
        let rhs = ComplexVector::from(&[cpx!(-6.0, 14.0), cpx!(-1.0, 2.0), cpx!(14.0, 6.0), cpx!(1.0, 2.0)]);
        let verify = VerifyLinSys::from_complex(&coo, &x, &rhs).unwrap();
        approx_eq(verify.max_abs_a, 7.0710678118654755, 1e-15);
        approx_eq(verify.max_abs_ax, 15.231546211727817, 1e-15);
        approx_eq(verify.max_abs_diff, 0.0, 1e-15);
        approx_eq(verify.relative_error, 0.0, 1e-15);
        // perturbed last component: residual modulus is |(1+2i) - (1+0i)| = 2
        let rhs = ComplexVector::from(&[cpx!(-6.0, 14.0), cpx!(-1.0, 2.0), cpx!(14.0, 6.0), cpx!(1.0, 0.0)]);
        let verify = VerifyLinSys::from_complex(&coo, &x, &rhs).unwrap();
        approx_eq(verify.max_abs_a, 7.0710678118654755, 1e-15);
        approx_eq(verify.max_abs_ax, 15.231546211727817, 1e-15);
        approx_eq(verify.max_abs_diff, 2.0, 1e-15);
        approx_eq(verify.relative_error, 2.0 / (7.0710678118654755 + 1.0), 1e-15);
    }
}