use crate::backsolver::SensBacksolver;
use crate::p_calculator::PCalculator;
use crate::schur_driver::SchurDriver;
use pounce_common::types::Number;
/// Computes a sensitivity step: a reduced-space step `du` together with the
/// corresponding full-space step `dx_full`.
pub trait SensStepCalc {
    /// Solves for the reduced step and lifts it into the full space.
    ///
    /// * `rhs_u` - right-hand side of the reduced (Schur) system.
    /// * `du` - output buffer receiving the reduced step.
    /// * `dx_full` - output buffer receiving the full-space step.
    ///
    /// Returns `true` on success. When `false` is returned, the output buffers
    /// may have been partially written and should not be relied on; the flag
    /// is the only failure signal, so it must not be ignored.
    #[must_use]
    fn compute_step(&self, rhs_u: &[Number], du: &mut [Number], dx_full: &mut [Number]) -> bool;
}
/// Standard two-stage sensitivity step calculator.
///
/// Borrows a Schur driver, which also provides the full-space solve through
/// [`WithBacksolver`]. It also borrows the [`PCalculator`] whose `A` data maps
/// the reduced step back into the full space. The trait bounds are declared on
/// the `impl` blocks that need them rather than on the type itself.
pub struct StdStepCalc<'d, D, P> {
    /// Driver used for the reduced (Schur) solve and the full-space solve.
    driver: &'d D,
    /// Provider of the `A` data used to lift the reduced step.
    pcalc: &'d P,
}
impl<'d, D: SchurDriver + WithBacksolver, P: PCalculator> StdStepCalc<'d, D, P> {
pub fn new(driver: &'d D, pcalc: &'d P) -> Self {
Self { driver, pcalc }
}
}
/// Access to a full-space linear solve with the system matrix `K`.
///
/// [`StdStepCalc`] uses this to map a reduced step back into the full space.
pub trait WithBacksolver {
    /// Solves `K * out = rhs`, writing the solution into `out`.
    ///
    /// Returns `true` on success and `false` if the solve failed. The flag is
    /// the only failure signal, so it must not be ignored.
    #[must_use]
    fn k_solve(&self, rhs: &[Number], out: &mut [Number]) -> bool;
}
impl<'d, D, P> SensStepCalc for StdStepCalc<'d, D, P>
where
    D: SchurDriver + WithBacksolver,
    P: PCalculator,
{
    /// Runs the two-stage pipeline: a reduced Schur solve, then a full-space solve.
    ///
    /// 1. `driver.schur_solve(rhs_u, du)` produces the reduced step `du`.
    /// 2. `du` is mapped into the full space with `A^T`, where `A` comes from
    ///    `pcalc.data_a()`. This gives a right-hand side of length `dx_full.len()`.
    /// 3. `driver.k_solve` solves the full system for `dx_full`.
    ///
    /// Returns `false` as soon as any stage fails. If the failure happens after
    /// the Schur solve, `du` already holds the reduced step.
    fn compute_step(&self, rhs_u: &[Number], du: &mut [Number], dx_full: &mut [Number]) -> bool {
        if !self.driver.schur_solve(rhs_u, du) {
            return false;
        }

        let a = self.pcalc.data_a();
        // Scratch right-hand side for the full-space solve: A^T * du.
        let mut rhs_full = vec![0.0; dx_full.len()];
        // This API only reports success/failure, so the error detail is dropped
        // and just the failure is propagated.
        if a.trans_multiply(du, &mut rhs_full).is_err() {
            return false;
        }

        self.driver.k_solve(&rhs_full, dx_full)
    }
}
impl<B: SensBacksolver> WithBacksolver
    for crate::schur_driver::DenseGenSchurDriver<crate::p_calculator::IndexPCalculator<B>, B>
{
    /// Performs the full-space solve by delegating to the backsolver held by
    /// this driver's `IndexPCalculator`.
    fn k_solve(&self, rhs: &[Number], out: &mut [Number]) -> bool {
        let backsolver = self.pcalc().backsolver();
        backsolver.solve(rhs, out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backsolver::DenseLuBacksolver;
    use crate::p_calculator::IndexPCalculator;
    use crate::schur_data::IndexSchurData;
    use crate::schur_driver::DenseGenSchurDriver;

    /// Absolute tolerance for comparing solver output against hand-derived values.
    const TOL: Number = 1e-10;

    /// Returns `true` when `got` lies strictly within `TOL` of `want`.
    fn near(got: Number, want: Number) -> bool {
        (got - want).abs() < TOL
    }

    /// End-to-end check of `StdStepCalc` on a small system. `K` is a 3x3
    /// tridiagonal matrix, and `A` and `B` both select entries 0 and 2 with
    /// unit weight.
    #[test]
    fn std_step_calc_runs_two_step_pipeline() {
        #[rustfmt::skip]
        let k_dense = vec![
             2.0, -1.0,  0.0,
            -1.0,  2.0, -1.0,
             0.0, -1.0,  2.0,
        ];
        let lu = DenseLuBacksolver::from_dense(3, &k_dense).unwrap();
        let a_data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let mut driver =
            DenseGenSchurDriver::<_, DenseLuBacksolver>::new(IndexPCalculator::new(lu, a_data));

        let b_data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        assert!(driver.schur_build_and_factor(&b_data));

        let calc = StdStepCalc::new(&driver, driver.pcalc());
        let rhs_u = [1.0, 0.0];
        let mut du = [0.0; 2];
        let mut dx = [0.0; 3];
        assert!(calc.compute_step(&rhs_u, &mut du, &mut dx));

        // Reduced step.
        assert!(near(du[0], -1.5), "du[0] = {}", du[0]);
        assert!(near(du[1], 0.5), "du[1] = {}", du[1]);

        // Full-space step.
        assert!(near(dx[0], -1.0), "dx[0] = {}", dx[0]);
        assert!(near(dx[1], -0.5), "dx[1] = {}", dx[1]);
        assert!(near(dx[2], 0.0), "dx[2] = {}", dx[2]);
    }
}