use crate::utils::{LinearSystem, M, S, V};
use ndarray::ArrayBase;
use std::f64::{MIN_POSITIVE, NAN};
use streaming_iterator::*;
/// State of a conjugate-gradient (CG) iteration for the linear system `a * x = b`.
///
/// Built by [`ConjugateGradient::for_problem`] and driven through its
/// `StreamingIterator` implementation, whose `Item` is the whole state (`Self`).
/// Residuals use the sign convention `r = a * x - b` (not `b - a * x`).
#[derive(Clone, Debug)]
pub struct ConjugateGradient {
/// System matrix, cloned from the problem.
pub a: M,
/// Right-hand side vector, cloned from the problem.
pub b: V,
/// Current iterate; starts at the problem's `x0`, or at zero when none is given.
pub x_k: V,
/// The value `x_k` held before the most recent successful step
/// (equal to the starting point until a step has been taken).
pub solution: V,
/// Current residual `a * x_k - b`, updated incrementally by `advance`.
pub r_k: V,
/// Current search direction; initialised to `-r_k`.
pub p_k: V,
/// Step length `r_k2 / pap_k` computed at the start of the latest `advance`;
/// `NaN` before the first call.
pub alpha_k: S,
/// Direction-update coefficient `r_k2 / r_km2` from the latest successful step;
/// `NaN` before the first one.
pub beta_k: S,
/// Squared residual norm `r_k · r_k`.
pub r_k2: S,
/// The value `r_k2` had before the latest `advance`; `NaN` initially.
pub r_km2: S,
/// Matrix-vector product `a · p_k`.
pub ap_k: V,
/// Quadratic form `p_k · (a · p_k)` along the current search direction.
pub pap_k: S,
/// The value `pap_k` had before the latest `advance`; `NaN` initially.
pub pap_km: S,
}
impl ConjugateGradient {
    /// Creates the initial CG state for the system described by `p`.
    ///
    /// The starting point is `p.x0` when present; otherwise it is a zero vector whose
    /// length is the number of rows of `p.a`. The initial residual is `a * x_0 - b`, and
    /// the first search direction is its negation. The coefficients `alpha_k`/`beta_k`
    /// and the "previous step" scalars `r_km2`/`pap_km` start as `NaN`, because no step
    /// has been taken yet.
    pub fn for_problem(p: &LinearSystem) -> ConjugateGradient {
        let start = p
            .x0
            .clone()
            .unwrap_or_else(|| ArrayBase::zeros(p.a.shape()[0]));

        // Residual in the `a * x - b` convention; the descent direction is its negative.
        let residual = (&p.a.dot(&start) - &p.b).to_shared();
        let direction = -residual.clone();
        let a_direction = p.a.dot(&direction).to_shared();

        let residual_sq = residual.dot(&residual);
        let curvature = direction.dot(&a_direction);

        ConjugateGradient {
            a: p.a.clone(),
            b: p.b.clone(),
            x_k: start.clone(),
            solution: start,
            r_k: residual,
            p_k: direction,
            alpha_k: NAN,
            beta_k: NAN,
            r_k2: residual_sq,
            r_km2: NAN,
            ap_k: a_direction,
            pap_k: curvature,
            pap_km: NAN,
        }
    }
}
fn too_small(v: S) -> bool {
v < 10. * MIN_POSITIVE
}
impl StreamingIterator for ConjugateGradient {
    type Item = Self;

    /// Takes one conjugate-gradient step, or stops if the current scalars have degenerated.
    ///
    /// Every call records the step length `alpha_k = r_k2 / pap_k` and copies the current
    /// `r_k2` / `pap_k` into `r_km2` / `pap_km`. If either copied value is below the
    /// `too_small` threshold, the vectors are left untouched and the following `get`
    /// returns `None`.
    ///
    /// Otherwise the step proceeds as follows:
    /// - `solution` receives the pre-step iterate.
    /// - `x_k` and `r_k` move by `alpha_k` along `p_k` and `a · p_k`.
    /// - `r_k2` is recomputed.
    /// - The new direction is `-r_k + beta_k * p_k`, with `beta_k = r_k2 / r_km2`.
    /// - `ap_k` and `pap_k` are refreshed.
    fn advance(&mut self) {
        let alpha = self.r_k2 / self.pap_k;
        self.alpha_k = alpha;

        // Both outcomes keep the current scalars as the "previous" ones, so do it up front.
        self.r_km2 = self.r_k2;
        self.pap_km = self.pap_k;

        // Residual norm or curvature has collapsed: leave the vectors as they are.
        if too_small(self.r_km2) || too_small(self.pap_km) {
            return;
        }

        // Remember the iterate we are stepping away from.
        self.solution = self.x_k.clone();
        self.x_k = (self.x_k.clone() + &self.p_k * alpha).to_shared();
        // With r = a*x - b, moving x by alpha*p moves r by alpha*(a*p).
        self.r_k = (self.r_k.clone() + &self.ap_k * alpha).to_shared();
        self.r_k2 = self.r_k.dot(&self.r_k);

        // New direction mixes the negative residual with the old direction, weighted by
        // the ratio of consecutive squared residual norms.
        self.beta_k = self.r_k2 / self.r_km2;
        self.p_k = (-&self.r_k + (self.beta_k * &self.p_k)).to_shared();
        self.ap_k = self.a.dot(&self.p_k).to_shared();
        self.pap_k = self.p_k.dot(&self.ap_k);
    }

    /// Yields the current state unless the scalars saved by the last `advance` have
    /// degenerated.
    ///
    /// Before any `advance` both saved scalars are `NaN`, which `too_small` does not flag,
    /// so the initial state is reported as available.
    fn get(&self) -> Option<&Self::Item> {
        let finished = too_small(self.r_km2) || too_small(self.pap_km);
        if finished {
            None
        } else {
            Some(self)
        }
    }
}