use std::slice;
use num_traits::{Float, Zero};
use crate::error::{Result, Error};
use crate::matrix::{matrix_mul, matrix_transpose_mul, FloatComplex};
/// Compile-time switch for solver tracing: when set to `true`, `matrix_cgsolve`
/// prints the per-iteration state (iteration counter, convergence comparison,
/// `alpha`, `delta0` and the residual) to stdout.
const DEBUG_CGSOLVE: bool = false;
pub fn matrix_cgsolve<T>(a: &[T], n: usize, b: &[T], x: &mut [T], _opts: Option<()>) -> Result<()>
where
T: FloatComplex,
{
if n == 0 {
return Err(Error::Range("matrix_cgsolve(), system dimension cannot be zero".to_owned()));
}
let max_iterations = 4 * n; let tol = T::Real::from(1e-6);
let mut x0 = vec![T::zero(); n];
let mut x1 = vec![T::zero(); n];
let mut d0 = vec![T::zero(); n];
let mut d1 = vec![T::zero(); n];
let mut r0 = vec![T::zero(); n];
let mut r1 = vec![T::zero(); n];
let mut q = vec![T::zero(); n];
let mut ax1 = vec![T::zero(); n];
for j in 0..n {
d0[j] = b[j];
}
r0.copy_from_slice(&d0);
let mut delta_init = T::zero();
let mut delta0 = T::zero();
matrix_transpose_mul(b, n, 1, slice::from_mut(&mut delta_init));
matrix_transpose_mul(&r0, n, 1, slice::from_mut(&mut delta0));
x.copy_from_slice(&x0);
let mut i = 0; let mut res_opt = T::Real::zero();
while i < max_iterations && delta0.re() > tol * tol * delta_init.re() {
if DEBUG_CGSOLVE {
println!("*********** {} / {} (max) **************", i, max_iterations);
println!(" comparing {:e} > {:e}", delta0.re(), tol * tol * delta_init.re());
}
matrix_mul(a, n, n, &d0, n, 1, &mut q, n, 1)?;
let gamma: T = d0.iter().zip(q.iter()).map(|(&d, &q)| d.conj() * q).sum::<T>();
let alpha = delta0 / gamma;
if DEBUG_CGSOLVE {
println!(" alpha = {:e}", alpha.re());
println!(" delta0 = {:e}", delta0.re());
}
for j in 0..n {
x1[j] = x0[j] + alpha * d0[j];
}
if DEBUG_CGSOLVE {
println!(" x:");
}
if (i + 1) % 50 == 0 {
matrix_mul(a, n, n, &x1, n, 1, &mut ax1, n, 1)?;
for j in 0..n {
r1[j] = b[j] - ax1[j];
}
} else {
for j in 0..n {
r1[j] = r0[j] - alpha * q[j];
}
}
let mut delta1 = T::zero();
matrix_transpose_mul(&r1, n, 1, slice::from_mut(&mut delta1));
let beta = delta1 / delta0;
for j in 0..n {
d1[j] = r1[j] + beta * d0[j];
}
let res = Float::sqrt(delta1.abs() / delta_init.abs());
if i == 0 || res < res_opt {
res_opt = res;
x.copy_from_slice(&x1);
}
if DEBUG_CGSOLVE {
println!(" res = {:e}", res);
}
x0.copy_from_slice(&x1);
d0.copy_from_slice(&d1);
r0.copy_from_slice(&r1);
delta0 = delta1;
i += 1;
}
Ok(())
}