use crate::backsolver::SensBacksolver;
use crate::schur_data::{IndexSchurData, SchurData};
use pounce_common::types::{Index, Number};
use std::collections::HashMap;
/// Interface of an object that owns some Schur data `A` and can derive a
/// dense Schur matrix from it.
///
/// Only [`PCalculator::data_a`] is mandatory. The two computational methods
/// have placeholder defaults that perform no work and report failure, so an
/// implementor that does not override them never claims success.
pub trait PCalculator {
    /// Returns the Schur data `A` held by this calculator.
    fn data_a(&self) -> &dyn SchurData;

    /// Computes whatever intermediate data the implementor needs.
    ///
    /// Returns `true` on success. The default does nothing and returns
    /// `false`.
    fn compute_p(&mut self) -> bool {
        false
    }

    /// Writes the dense Schur matrix for `_b` into `_dense_schur`.
    ///
    /// Returns `true` on success. The default leaves `_dense_schur` untouched
    /// and returns `false`.
    fn schur_matrix(&mut self, _b: &dyn SchurData, _dense_schur: &mut [Number]) -> bool {
        false
    }
}
/// Calculator that keeps one solved vector per column index of its
/// [`IndexSchurData`], using a [`SensBacksolver`] to produce each vector.
pub struct IndexPCalculator<B: SensBacksolver> {
/// Solver invoked once for every column that is not cached yet.
backsolver: B,
/// Column indices and signs that define the right-hand sides to solve for.
data_a: IndexSchurData,
/// Value of `backsolver.dim()` captured at construction; used as the length
/// of every right-hand side and solution vector.
n_full: usize,
/// Solved vectors, keyed by the column index they were computed for.
p_cols: HashMap<Index, Vec<Number>>,
}
impl<B: SensBacksolver> IndexPCalculator<B> {
    /// Creates a calculator around `backsolver` and `data_a`.
    ///
    /// The full dimension is read from `backsolver.dim()` once, here, and the
    /// column cache starts out empty.
    pub fn new(backsolver: B, data_a: IndexSchurData) -> Self {
        Self {
            // Evaluated first, while `backsolver` is still only borrowed; the
            // next field initialiser moves it into the struct.
            n_full: backsolver.dim(),
            backsolver,
            data_a,
            p_cols: HashMap::default(),
        }
    }

    /// Dimension captured from the backsolver at construction time.
    pub fn n_full(&self) -> usize {
        self.n_full
    }

    /// Solved columns cached so far, keyed by column index.
    pub fn p_columns(&self) -> &HashMap<Index, Vec<Number>> {
        &self.p_cols
    }

    /// Shared access to the wrapped backsolver.
    pub fn backsolver(&self) -> &B {
        &self.backsolver
    }
}
impl<B: SensBacksolver> PCalculator for IndexPCalculator<B> {
    /// Returns the `A` data whose column indices and signs drive
    /// [`PCalculator::compute_p`].
    fn data_a(&self) -> &dyn SchurData {
        &self.data_a
    }

    /// Solves for every column of `A` that is not cached yet.
    ///
    /// For column index `c` with sign `s`, the right-hand side is the vector
    /// of length `n_full` that is zero except for `s` at position `c`; the
    /// backsolver's solution is stored in the cache under key `c`.
    ///
    /// Returns `false` if a column index does not fit in `0..n_full`, if no
    /// sign is available for a column, or if the backsolver reports failure.
    /// Columns cached before the failure stay cached.
    ///
    /// NOTE(review): the cache is keyed by column index only, so a cached
    /// column is reused even if its sign in `data_a` differs from the one it
    /// was solved with.
    fn compute_p(&mut self) -> bool {
        // `data_a` is only read while `p_cols` is written; the borrows touch
        // different fields, so no copies of the index/sign lists are needed.
        let cols = self.data_a.col_indices();
        let signs = self.data_a.signs();

        // Single scratch right-hand side, kept all-zero between solves.
        let mut rhs: Vec<Number> = vec![0.0; self.n_full];

        for (i, &col) in cols.iter().enumerate() {
            if self.p_cols.contains_key(&col) {
                continue;
            }
            // A negative index wraps to a huge value here and is rejected by
            // the range check below.
            let c_us = col as usize;
            if c_us >= self.n_full {
                return false;
            }
            let sign = match signs.get(i) {
                Some(&s) => s as Number,
                None => return false,
            };

            rhs[c_us] = sign;
            let mut p_col = vec![0.0; self.n_full];
            let solved = self.backsolver.solve(&rhs, &mut p_col);
            // Restore the scratch vector before it is used for the next column.
            rhs[c_us] = 0.0;
            if !solved {
                return false;
            }
            self.p_cols.insert(col, p_col);
        }
        true
    }

    /// Fills `dense_schur` with the negated entries of the cached columns
    /// selected by the rows of `b`.
    ///
    /// With `n_b = b.nrows()` and `n_a = data_a.nrows()`, `dense_schur` must
    /// hold exactly `n_b * n_a` values. The entry for row `i` of `b` and
    /// column `j` of `A` is written to `dense_schur[j * n_b + i]` and equals
    /// minus the cached column `j` evaluated at the first index that
    /// `b.multiplying_row(i)` reports.
    ///
    /// Returns `false` (instead of panicking) when the output length is
    /// wrong, when `compute_p` fails, when a row of `b` cannot be resolved or
    /// lists no index, or when an index falls outside a cached column. On
    /// failure `dense_schur` may be partially written.
    ///
    /// NOTE(review): the factors returned by `multiplying_row` are not
    /// applied, and only the first index of each row is used — confirm that
    /// `b` always carries a single `+1` entry per row.
    fn schur_matrix(&mut self, b: &dyn SchurData, dense_schur: &mut [Number]) -> bool {
        let n_b = b.nrows() as usize;
        let n_a = self.data_a.nrows() as usize;
        if dense_schur.len() != n_b * n_a {
            return false;
        }
        if !self.compute_p() {
            return false;
        }

        let a_cols = self.data_a.col_indices();
        if a_cols.is_empty() {
            // No columns means no entries to write, so `b` is never consulted.
            return true;
        }

        // Resolve every row of `b` once; the result is the same for all
        // columns of `A`.
        let mut b_cols: Vec<usize> = Vec::with_capacity(n_b);
        for i in 0..n_b {
            let (b_idx_vec, _) = match b.multiplying_row(i as Index) {
                Ok(t) => t,
                Err(_) => return false,
            };
            match b_idx_vec.first() {
                Some(&idx) => b_cols.push(idx as usize),
                None => return false,
            }
        }

        for (j, &a_col) in a_cols.iter().enumerate() {
            let p_col = match self.p_cols.get(&a_col) {
                Some(v) => v,
                None => return false,
            };
            let out_col = match dense_schur.get_mut(j * n_b..(j + 1) * n_b) {
                Some(s) => s,
                None => return false,
            };
            for (dst, &b_col) in out_col.iter_mut().zip(b_cols.iter()) {
                match p_col.get(b_col) {
                    Some(&value) => *dst = -value,
                    None => return false,
                }
            }
        }
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backsolver::DenseLuBacksolver;
    use crate::schur_data::IndexSchurData;

    /// Implementor that supplies only `data_a`, so every other method is the
    /// trait default.
    struct StubPCalculator {
        a: IndexSchurData,
    }

    impl PCalculator for StubPCalculator {
        fn data_a(&self) -> &dyn SchurData {
            &self.a
        }
    }

    /// Symmetric 3x3 tridiagonal matrix: 2 on the diagonal, -1 beside it.
    ///
    /// Its inverse is `[[3, 2, 1], [2, 4, 2], [1, 2, 3]] / 4`, which is where
    /// the literal expectations in the tests below come from.
    #[rustfmt::skip]
    fn tridiag_3x3() -> Vec<Number> {
        vec![
             2.0, -1.0,  0.0,
            -1.0,  2.0, -1.0,
             0.0, -1.0,  2.0,
        ]
    }

    /// Asserts that both slices have the same length and agree element-wise
    /// to within `tol`; `what` labels the failure message.
    fn assert_all_close(what: &str, actual: &[Number], expected: &[Number], tol: Number) {
        assert_eq!(actual.len(), expected.len(), "{}: length mismatch", what);
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "{}[{}] actual={}, expected={}",
                what,
                idx,
                a,
                e,
            );
        }
    }

    #[test]
    fn trait_default_compute_p_returns_false() {
        let a = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).unwrap();
        let mut pc = StubPCalculator { a };
        assert!(
            !pc.compute_p(),
            "default compute_p must return false until Phase B"
        );
    }

    #[test]
    fn trait_default_schur_matrix_returns_false() {
        let a = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).unwrap();
        let b = IndexSchurData::from_parts(vec![1], vec![1]).unwrap();
        let mut pc = StubPCalculator { a };
        // One row of B times two columns of A.
        let mut out = vec![0.0; 2];
        assert!(!pc.schur_matrix(&b, &mut out));
    }

    #[test]
    fn data_a_round_trips_to_concrete_schur_data() {
        let a = IndexSchurData::from_parts(vec![3, 5], vec![1, 1]).unwrap();
        let pc = StubPCalculator { a };
        assert_eq!(pc.data_a().nrows(), 2);
    }

    #[test]
    fn compute_p_solves_each_a_column_against_k() {
        let k = tridiag_3x3();
        let backsolver = DenseLuBacksolver::from_dense(3, &k).expect("factor");
        let a = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let mut pc = IndexPCalculator::new(backsolver, a);
        assert!(pc.compute_p());

        let p0 = pc.p_columns().get(&0).expect("col 0 cached");
        assert_all_close("p0", p0, &[0.75, 0.50, 0.25], 1e-12);

        let p2 = pc.p_columns().get(&2).expect("col 2 cached");
        assert_all_close("p2", p2, &[0.25, 0.50, 0.75], 1e-12);
    }

    #[test]
    fn compute_p_uses_sign_from_a_data() {
        let k = tridiag_3x3();
        let backsolver = DenseLuBacksolver::from_dense(3, &k).unwrap();
        let a = IndexSchurData::from_parts(vec![1], vec![-1]).unwrap();
        let mut pc = IndexPCalculator::new(backsolver, a);
        assert!(pc.compute_p());

        let p1 = pc.p_columns().get(&1).expect("col 1 cached");
        assert_all_close("p1", p1, &[-0.5, -1.0, -0.5], 1e-12);
    }

    #[test]
    fn schur_matrix_matches_closed_form_minus_b_kinv_a() {
        let k = tridiag_3x3();
        let backsolver = DenseLuBacksolver::from_dense(3, &k).unwrap();
        let a = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let b = IndexSchurData::from_parts(vec![1], vec![1]).unwrap();
        let mut pc = IndexPCalculator::new(backsolver, a);
        // One row of B times two columns of A.
        let mut s = vec![0.0; 2];
        assert!(pc.schur_matrix(&b, &mut s));
        assert_all_close("S", &s, &[-0.5, -0.5], 1e-12);
    }

    #[test]
    fn schur_matrix_reproduces_independent_computation() {
        #[rustfmt::skip]
        let k = vec![
            4.0, 1.0, 0.0, 0.0,
            1.0, 4.0, 1.0, 0.0,
            0.0, 1.0, 4.0, 1.0,
            0.0, 0.0, 1.0, 4.0,
        ];
        let backsolver = DenseLuBacksolver::from_dense(4, &k).unwrap();
        let a_data = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).unwrap();
        let mut pc = IndexPCalculator::new(backsolver, a_data);
        assert!(pc.compute_p());

        // Reference solutions from a second factorization that shares no
        // state with the calculator: one per column of A (+e1 and -e3).
        let reference = DenseLuBacksolver::from_dense(4, &k).unwrap();
        let mut kinv_e1 = vec![0.0; 4];
        assert!(
            reference.solve(&[0.0, 1.0, 0.0, 0.0], &mut kinv_e1),
            "reference solve for +e1 failed"
        );
        let mut kinv_minus_e3 = vec![0.0; 4];
        assert!(
            reference.solve(&[0.0, 0.0, 0.0, -1.0], &mut kinv_minus_e3),
            "reference solve for -e3 failed"
        );
        let kinv_a_cols = [&kinv_e1, &kinv_minus_e3];

        let b_rows = [0usize, 2];
        let b_data = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let mut s_actual = vec![0.0; b_rows.len() * kinv_a_cols.len()];
        assert!(pc.schur_matrix(&b_data, &mut s_actual));

        // Expected layout: for column j of A and row i of B the entry sits at
        // `j * n_b + i`, so pushing column by column reproduces it.
        let mut s_expected = Vec::with_capacity(s_actual.len());
        for col in kinv_a_cols.iter() {
            for &row in b_rows.iter() {
                s_expected.push(-col[row]);
            }
        }
        assert_all_close("S", &s_actual, &s_expected, 1e-10);
    }
}