use crate::backsolver::{DenseLuBacksolver, SensBacksolver};
use crate::p_calculator::PCalculator;
use pounce_common::types::Number;
/// Driver that assembles a Schur complement matrix, factorizes it, and solves with it.
pub trait SchurDriver {
    /// Assembles the Schur complement associated with the data `b` and factorizes it.
    ///
    /// Returns `true` if a usable factorization was produced, `false` otherwise.
    fn schur_build_and_factor(&mut self, b: &dyn crate::SchurData) -> bool;

    /// Solves `S x = rhs` with the current Schur factorization, writing the result into `x`.
    ///
    /// Returns `false` if the solve could not be carried out.
    fn schur_solve(&self, rhs: &[Number], x: &mut [Number]) -> bool;

    /// Dimension of the currently factorized Schur complement, or `None` if there is none.
    fn schur_dim(&self) -> Option<usize>;
}
/// Schur driver that forms the Schur complement as a dense matrix and LU-factorizes it.
///
/// `P` computes the Schur complement entries; `B` is only a type-level parameter
/// (no value of type `B` is stored).
pub struct DenseGenSchurDriver<P, B>
where
    P: PCalculator,
    B: SensBacksolver,
{
    /// Calculator used to assemble the Schur complement.
    pcalc: P,
    /// LU factors of the most recently assembled Schur complement, if any.
    schur_lu: Option<DenseLuBacksolver>,
    /// Marker tying the otherwise-unused `B` parameter to the type.
    _b: std::marker::PhantomData<B>,
}
impl<P, B> DenseGenSchurDriver<P, B>
where
    P: PCalculator,
    B: SensBacksolver,
{
    /// Wraps `pcalc` in a driver that does not hold any Schur factorization yet.
    ///
    /// Call [`SchurDriver::schur_build_and_factor`] before attempting a solve.
    pub fn new(pcalc: P) -> Self {
        let schur_lu = None;
        Self {
            _b: std::marker::PhantomData,
            schur_lu,
            pcalc,
        }
    }

    /// Borrows the wrapped P-calculator.
    pub fn pcalc(&self) -> &P {
        &self.pcalc
    }
}
impl<P: PCalculator, B: SensBacksolver> SchurDriver for DenseGenSchurDriver<P, B> {
    /// Assembles the dense Schur complement for `b` through the wrapped P-calculator
    /// and LU-factorizes it.
    ///
    /// Returns `false` in three cases:
    /// - `b` and the calculator's `A` data have different row counts;
    /// - the calculator fails to produce the Schur matrix;
    /// - the LU factorization fails.
    ///
    /// Any previously cached factorization is discarded before the rebuild starts. So
    /// after a `false` return, `schur_solve` fails and `schur_dim` returns `None`. The
    /// driver never silently keeps solving with factors from an earlier build.
    fn schur_build_and_factor(&mut self, b: &dyn crate::SchurData) -> bool {
        // Invalidate first so that every early `return false` below leaves the driver
        // in the "not factored" state instead of holding a stale factorization.
        self.schur_lu = None;

        let n_b = b.nrows() as usize;
        let n_a = self.pcalc.data_a().nrows() as usize;
        // The LU backsolver needs a square matrix.
        if n_b != n_a {
            return false;
        }
        let n = n_b;

        // The calculator's output is treated as an n x n column-major buffer.
        let mut s_col_major = vec![0.0; n * n];
        if !self.pcalc.schur_matrix(b, &mut s_col_major) {
            return false;
        }

        // Transpose into the row-major layout handed to `from_dense`:
        // entry (i, j) sits at `j * n + i` column-major and at `i * n + j` row-major.
        let mut s_row_major = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                s_row_major[i * n + j] = s_col_major[j * n + i];
            }
        }

        // A failed factorization leaves the cache empty (it was cleared above).
        self.schur_lu = DenseLuBacksolver::from_dense(n, &s_row_major).ok();
        self.schur_lu.is_some()
    }

    /// Solves `S x = rhs` with the cached LU factors of the Schur complement.
    ///
    /// Returns `false` when no factorization is cached (nothing built yet, or the most
    /// recent build failed); otherwise returns the result of the LU solve.
    fn schur_solve(&self, rhs: &[Number], x: &mut [Number]) -> bool {
        match self.schur_lu.as_ref() {
            Some(lu) => lu.solve(rhs, x),
            None => false,
        }
    }

    /// Dimension of the cached Schur factorization, or `None` when none is cached.
    fn schur_dim(&self) -> Option<usize> {
        self.schur_lu.as_ref().map(|lu| lu.dim())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backsolver::DenseLuBacksolver;
    use crate::p_calculator::IndexPCalculator;
    use crate::schur_data::IndexSchurData;

    /// Factors the 2x2 Schur complement of a 3x3 tridiagonal K and checks the solve
    /// against a hand-computed answer.
    #[test]
    fn dense_gen_schur_driver_factors_and_solves() {
        // K = tridiag(-1, 2, -1), so K^{-1} = (1/4) [[3, 2, 1], [2, 4, 2], [1, 2, 3]].
        #[rustfmt::skip]
        let kkt = vec![
             2.0, -1.0,  0.0,
            -1.0,  2.0, -1.0,
             0.0, -1.0,  2.0,
        ];
        let k_solver = DenseLuBacksolver::from_dense(3, &kkt).unwrap();
        let a_data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let mut driver =
            DenseGenSchurDriver::<_, DenseLuBacksolver>::new(IndexPCalculator::new(k_solver, a_data));

        let b_data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        assert!(driver.schur_build_and_factor(&b_data));
        assert_eq!(driver.schur_dim(), Some(2));

        let rhs = [1.0, 0.0];
        let mut x = [0.0; 2];
        assert!(driver.schur_solve(&rhs, &mut x));
        assert!((x[0] - (-1.5)).abs() < 1e-10, "x[0] = {}", x[0]);
        assert!((x[1] - 0.5).abs() < 1e-10, "x[1] = {}", x[1]);

        // Expected Schur complement S = -[[0.75, 0.25], [0.25, 0.75]];
        // multiplying it back with x must reproduce rhs.
        let s = [[-0.75, -0.25], [-0.25, -0.75]];
        let recon = [
            s[0][0] * x[0] + s[0][1] * x[1],
            s[1][0] * x[0] + s[1][1] * x[1],
        ];
        assert!((recon[0] - 1.0).abs() < 1e-10);
        assert!((recon[1] - 0.0).abs() < 1e-10);
    }

    /// A freshly constructed driver has no factorization, so solving must fail.
    #[test]
    fn schur_solve_before_factor_fails() {
        #[rustfmt::skip]
        let identity = vec![
            1.0, 0.0,
            0.0, 1.0,
        ];
        let k_solver = DenseLuBacksolver::from_dense(2, &identity).unwrap();
        let a_data = IndexSchurData::from_parts(vec![0], vec![1]).unwrap();
        let driver =
            DenseGenSchurDriver::<_, DenseLuBacksolver>::new(IndexPCalculator::new(k_solver, a_data));

        let mut x = [0.0; 1];
        assert!(!driver.schur_solve(&[1.0], &mut x));
        assert_eq!(driver.schur_dim(), None);
    }

    /// B with a different row count than A cannot form a square Schur complement.
    #[test]
    fn schur_build_rejects_b_a_dim_mismatch() {
        #[rustfmt::skip]
        let diag_k = vec![
            2.0, 0.0, 0.0,
            0.0, 3.0, 0.0,
            0.0, 0.0, 4.0,
        ];
        let k_solver = DenseLuBacksolver::from_dense(3, &diag_k).unwrap();
        let a_data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let mut driver =
            DenseGenSchurDriver::<_, DenseLuBacksolver>::new(IndexPCalculator::new(k_solver, a_data));

        let single_row_b = IndexSchurData::from_parts(vec![1], vec![1]).unwrap();
        assert!(!driver.schur_build_and_factor(&single_row_b));
    }
}