use pounce_algorithm::kkt::aug_system_solver::{
AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver,
};
use pounce_common::types::{Index, Number};
use pounce_linalg::compound_vector::CompoundVector;
use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
use pounce_linalg::triplet::{GenTMatrix, GenTMatrixSpace, SymTMatrix, SymTMatrixSpace};
use pounce_linalg::{LowRankUpdateSymMatrix, Matrix, MultiVectorMatrix, Vector};
use pounce_linsol::ESymSolverStatus;
use std::rc::Rc;
/// Augmented-system solver for the restoration phase.
///
/// Wraps an inner [`AugSystemSolver`]. `solve` eliminates the four slack blocks
/// `(n_c, p_c, n_d, p_d)` of the 5-block restoration primal vector and hands the
/// reduced system over `(x_orig, y_c, y_d)` to `inner`.
///
/// The reduced-system sparsity structures are built on the first `solve` call.
/// Only their values are refreshed afterwards. NOTE(review): this assumes the
/// sparsity of `W`, `J_c`, `J_d` stays fixed between calls; a length change
/// makes the value refill panic.
pub struct AugRestoSystemSolver {
/// Solver applied to the reduced (slack-eliminated) system.
inner: Box<dyn AugSystemSolver>,
/// Set once `build_structure` has created the cached structures below.
initialized: bool,
/// Number of original variables: columns of `J_c` minus `2 * (m_eq + m_ineq)` slack columns.
n_orig: Index,
/// Number of equality constraints (rows of `J_c`).
m_eq: Index,
/// Number of inequality constraints (rows of `J_d`).
m_ineq: Index,
/// Count of leading `J_c` triplets kept for the original block (total minus `2 * m_eq`).
nz_jc_orig: usize,
/// Count of leading `J_d` triplets kept for the original block (total minus `2 * m_ineq`).
nz_jd_orig: usize,
/// Hessian handed to `inner`: dimension `n_orig`, same triplet structure as the resto `W`.
h_orig: Option<SymTMatrix>,
/// Original-variable part of `J_c`.
j_c_orig: Option<GenTMatrix>,
/// Original-variable part of `J_d`.
j_d_orig: Option<GenTMatrix>,
/// Space for the length-`m_eq` reduced vectors (`D_cR`, `rhs_cR`, `y_c`).
space_m_eq: Option<Rc<DenseVectorSpace>>,
/// Space for the length-`m_ineq` reduced vectors (`D_dR`, `rhs_dR`, `y_d`).
space_m_ineq: Option<Rc<DenseVectorSpace>>,
}
impl std::fmt::Debug for AugRestoSystemSolver {
    /// Prints only the dimension bookkeeping.
    ///
    /// The wrapped solver and the cached matrices are elided (rendered as `..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("AugRestoSystemSolver");
        out.field("initialized", &self.initialized);
        out.field("n_orig", &self.n_orig);
        out.field("m_eq", &self.m_eq);
        out.field("m_ineq", &self.m_ineq);
        out.finish_non_exhaustive()
    }
}
impl AugRestoSystemSolver {
pub fn new(inner: Box<dyn AugSystemSolver>) -> Self {
Self {
inner,
initialized: false,
n_orig: 0,
m_eq: 0,
m_ineq: 0,
nz_jc_orig: 0,
nz_jd_orig: 0,
h_orig: None,
j_c_orig: None,
j_d_orig: None,
space_m_eq: None,
space_m_ineq: None,
}
}
fn build_structure(&mut self, w: &SymTMatrix, j_c: &GenTMatrix, j_d: &GenTMatrix) {
let m_eq = j_c.n_rows();
let m_ineq = j_d.n_rows();
let n_total = j_c.n_cols();
let n_orig = n_total - 2 * m_eq - 2 * m_ineq;
let h_space = SymTMatrixSpace::new(n_orig, w.irows().to_vec(), w.jcols().to_vec());
self.h_orig = Some(SymTMatrix::new(h_space));
let nz_jc_orig = (j_c.nonzeros() as usize).saturating_sub(2 * m_eq as usize);
let jc_space = GenTMatrixSpace::new(
m_eq,
n_orig,
j_c.irows()[..nz_jc_orig].to_vec(),
j_c.jcols()[..nz_jc_orig].to_vec(),
);
self.j_c_orig = Some(GenTMatrix::new(jc_space));
let nz_jd_orig = (j_d.nonzeros() as usize).saturating_sub(2 * m_ineq as usize);
let jd_space = GenTMatrixSpace::new(
m_ineq,
n_orig,
j_d.irows()[..nz_jd_orig].to_vec(),
j_d.jcols()[..nz_jd_orig].to_vec(),
);
self.j_d_orig = Some(GenTMatrix::new(jd_space));
self.space_m_eq = Some(DenseVectorSpace::new(m_eq));
self.space_m_ineq = Some(DenseVectorSpace::new(m_ineq));
self.n_orig = n_orig;
self.m_eq = m_eq;
self.m_ineq = m_ineq;
self.nz_jc_orig = nz_jc_orig;
self.nz_jd_orig = nz_jd_orig;
self.initialized = true;
}
fn refill_values(&mut self, w: &SymTMatrix, j_c: &GenTMatrix, j_d: &GenTMatrix) {
let h_dst = self.h_orig.as_mut().unwrap().values_mut();
h_dst.copy_from_slice(w.values());
let jc_dst = self.j_c_orig.as_mut().unwrap().values_mut();
jc_dst.copy_from_slice(&j_c.values()[..self.nz_jc_orig]);
let jd_dst = self.j_d_orig.as_mut().unwrap().values_mut();
jd_dst.copy_from_slice(&j_d.values()[..self.nz_jd_orig]);
}
}
impl AugSystemSolver for AugRestoSystemSolver {
fn provides_inertia(&self) -> bool {
self.inner.provides_inertia()
}
fn number_of_neg_evals(&self) -> Index {
self.inner.number_of_neg_evals()
}
fn increase_quality(&mut self) -> bool {
self.inner.increase_quality()
}
fn last_solve_status(&self) -> ESymSolverStatus {
self.inner.last_solve_status()
}
fn solve(
&mut self,
coeffs: &AugSysCoeffs<'_>,
rhs: &AugSysRhs<'_>,
sol: &mut AugSysSol<'_>,
check_neg_evals: bool,
num_neg_evals: Index,
) -> ESymSolverStatus {
let w_dyn = coeffs
.w
.expect("AugRestoSystemSolver: W must be present (resto Hessian)");
let j_c = coeffs
.j_c
.as_any()
.downcast_ref::<GenTMatrix>()
.expect("AugRestoSystemSolver: J_c must be a GenTMatrix");
let j_d = coeffs
.j_d
.as_any()
.downcast_ref::<GenTMatrix>()
.expect("AugRestoSystemSolver: J_d must be a GenTMatrix");
let m_eq = j_c.n_rows();
let m_ineq = j_d.n_rows();
let n_orig = j_c.n_cols() - 2 * m_eq - 2 * m_ineq;
let dx_compound = coeffs
.d_x
.expect("AugRestoSystemSolver: D_x must be present (5-block compound)")
.as_any()
.downcast_ref::<CompoundVector>()
.expect("AugRestoSystemSolver: D_x must be a CompoundVector");
debug_assert_eq!(dx_compound.n_comps(), 5);
let w_owned;
let w = match w_dyn.as_any().downcast_ref::<SymTMatrix>() {
Some(w) => w,
None => {
w_owned = materialize_orig_block(w_dyn, n_orig);
&w_owned
}
};
if !self.initialized {
self.build_structure(w, j_c, j_d);
}
self.refill_values(w, j_c, j_d);
let m_eq = self.m_eq as usize;
let m_ineq = self.m_ineq as usize;
let dbg = std::env::var("POUNCE_RESTO_DBG").is_ok();
if dbg {
tracing::debug!(target: "pounce::restoration",
"[resto-aug] n_orig={} m_eq={} m_ineq={} W.nz={} J_c.nz={} J_d.nz={} delta_x={:.3e} delta_c={:.3e} delta_d={:.3e}",
self.n_orig, self.m_eq, self.m_ineq,
w.nonzeros(), j_c.nonzeros(), j_d.nonzeros(),
coeffs.delta_x, coeffs.delta_c, coeffs.delta_d,
);
}
let sigma_orig_dyn = dx_compound.comp(0); let sigma_n_c = dense_values(dx_compound.comp(1));
let sigma_p_c = dense_values(dx_compound.comp(2));
let sigma_n_d = dense_values(dx_compound.comp(3));
let sigma_p_d = dense_values(dx_compound.comp(4));
let dx = coeffs.delta_x;
let sig_tilde_n_c_inv: Vec<Option<Number>> = sigma_n_c
.iter()
.map(|&s| sigma_tilde_inv_elem(Some(s), dx))
.collect();
let sig_tilde_p_c_inv: Vec<Option<Number>> = sigma_p_c
.iter()
.map(|&s| sigma_tilde_inv_elem(Some(s), dx))
.collect();
let sig_tilde_n_d_inv: Vec<Option<Number>> = sigma_n_d
.iter()
.map(|&s| sigma_tilde_inv_elem(Some(s), dx))
.collect();
let sig_tilde_p_d_inv: Vec<Option<Number>> = sigma_p_d
.iter()
.map(|&s| sigma_tilde_inv_elem(Some(s), dx))
.collect();
let d_c_vals: Option<Vec<Number>> = coeffs.d_c.map(dense_values);
let mut d_c_r = vec![0.0; m_eq];
for i in 0..m_eq {
let n_term = sig_tilde_n_c_inv[i].unwrap_or(0.0);
let p_term = sig_tilde_p_c_inv[i].unwrap_or(0.0);
let d_term = d_c_vals.as_ref().map(|v| v[i]).unwrap_or(0.0);
d_c_r[i] = n_term + p_term + d_term;
}
let mut d_c_r_dense = self.space_m_eq.as_ref().unwrap().make_new_dense();
d_c_r_dense.set_values(&d_c_r);
let d_d_vals: Option<Vec<Number>> = coeffs.d_d.map(dense_values);
let mut d_d_r = vec![0.0; m_ineq];
for i in 0..m_ineq {
let n_term = sig_tilde_n_d_inv[i].unwrap_or(0.0);
let p_term = sig_tilde_p_d_inv[i].unwrap_or(0.0);
let d_term = d_d_vals.as_ref().map(|v| v[i]).unwrap_or(0.0);
d_d_r[i] = n_term + p_term + d_term;
}
let mut d_d_r_dense = self.space_m_ineq.as_ref().unwrap().make_new_dense();
d_d_r_dense.set_values(&d_d_r);
let rhs_x_compound = rhs
.rhs_x
.as_any()
.downcast_ref::<CompoundVector>()
.expect("AugRestoSystemSolver: rhs_x must be a CompoundVector");
debug_assert_eq!(rhs_x_compound.n_comps(), 5);
let rhs_x_r_dyn = rhs_x_compound.comp(0);
let rhs_n_c = dense_values(rhs_x_compound.comp(1));
let rhs_p_c = dense_values(rhs_x_compound.comp(2));
let rhs_n_d = dense_values(rhs_x_compound.comp(3));
let rhs_p_d = dense_values(rhs_x_compound.comp(4));
let rhs_c_vals = dense_values(rhs.rhs_c);
let rhs_d_vals = dense_values(rhs.rhs_d);
let mut rhs_c_r = vec![0.0; m_eq];
for i in 0..m_eq {
rhs_c_r[i] = rhs_cr_elem(
rhs_c_vals[i],
sig_tilde_n_c_inv[i],
rhs_n_c[i],
sig_tilde_p_c_inv[i],
rhs_p_c[i],
);
}
let mut rhs_c_r_dense = self.space_m_eq.as_ref().unwrap().make_new_dense();
rhs_c_r_dense.set_values(&rhs_c_r);
let mut rhs_d_r = vec![0.0; m_ineq];
for i in 0..m_ineq {
let n_contrib = sig_tilde_n_d_inv[i].map(|s| s * rhs_n_d[i]).unwrap_or(0.0);
let p_contrib = sig_tilde_p_d_inv[i].map(|s| s * rhs_p_d[i]).unwrap_or(0.0);
rhs_d_r[i] = rhs_d_vals[i] - n_contrib + p_contrib;
}
let mut rhs_d_r_dense = self.space_m_ineq.as_ref().unwrap().make_new_dense();
rhs_d_r_dense.set_values(&rhs_d_r);
let mut sol_y_c_dense = self.space_m_eq.as_ref().unwrap().make_new_dense();
let mut sol_y_d_dense = self.space_m_ineq.as_ref().unwrap().make_new_dense();
let sol_x_compound = sol
.sol_x
.as_any_mut()
.downcast_mut::<CompoundVector>()
.expect("AugRestoSystemSolver: sol_x must be a CompoundVector");
debug_assert_eq!(sol_x_compound.n_comps(), 5);
let status = {
let sol_x_r = sol_x_compound.comp_mut(0);
let inner_coeffs = AugSysCoeffs {
w: Some(self.h_orig.as_ref().unwrap()),
w_factor: coeffs.w_factor,
d_x: Some(sigma_orig_dyn),
delta_x: coeffs.delta_x,
d_s: coeffs.d_s,
delta_s: coeffs.delta_s,
j_c: self.j_c_orig.as_ref().unwrap(),
d_c: Some(&d_c_r_dense),
delta_c: coeffs.delta_c,
j_d: self.j_d_orig.as_ref().unwrap(),
d_d: Some(&d_d_r_dense),
delta_d: coeffs.delta_d,
};
let inner_rhs = AugSysRhs {
rhs_x: rhs_x_r_dyn,
rhs_s: rhs.rhs_s,
rhs_c: &rhs_c_r_dense,
rhs_d: &rhs_d_r_dense,
};
let mut inner_sol = AugSysSol {
sol_x: sol_x_r,
sol_s: sol.sol_s,
sol_c: &mut sol_y_c_dense,
sol_d: &mut sol_y_d_dense,
};
self.inner.solve(
&inner_coeffs,
&inner_rhs,
&mut inner_sol,
check_neg_evals,
num_neg_evals,
)
};
if status != ESymSolverStatus::Success {
return status;
}
let sol_y_c_vals = sol_y_c_dense.expanded_values();
let sol_y_d_vals = sol_y_d_dense.expanded_values();
if dbg {
let sigma_orig_vals = dense_values(sigma_orig_dyn);
let rhs_x_orig_vals = dense_values(rhs_x_r_dyn);
let inf_norm = |v: &[f64]| v.iter().fold(0.0_f64, |a, &x| a.max(x.abs()));
let sol_x_r = sol_x_compound.comp(0);
let sol_x_orig_vals = dense_values(sol_x_r);
tracing::debug!(target: "pounce::restoration",
"[resto-aug] ||sigma_orig||={:.3e} ||sigma_n_c||={:.3e} ||sigma_p_c||={:.3e} ||sigma_n_d||={:.3e} ||sigma_p_d||={:.3e}",
inf_norm(&sigma_orig_vals),
inf_norm(&sigma_n_c), inf_norm(&sigma_p_c), inf_norm(&sigma_n_d), inf_norm(&sigma_p_d),
);
tracing::debug!(target: "pounce::restoration",
"[resto-aug] ||rhs_x_orig||={:.3e} ||rhs_n_c||={:.3e} ||rhs_p_c||={:.3e} ||rhs_n_d||={:.3e} ||rhs_p_d||={:.3e} ||rhs_c||={:.3e} ||rhs_d||={:.3e}",
inf_norm(&rhs_x_orig_vals), inf_norm(&rhs_n_c), inf_norm(&rhs_p_c),
inf_norm(&rhs_n_d), inf_norm(&rhs_p_d), inf_norm(&rhs_c_vals), inf_norm(&rhs_d_vals),
);
tracing::debug!(target: "pounce::restoration",
"[resto-aug] ||rhs_cR||={:.3e} ||rhs_dR||={:.3e} ||D_cR||={:.3e} ||D_dR||={:.3e} ||sol_x_orig||={:.3e} ||sol_y_c||={:.3e} ||sol_y_d||={:.3e}",
inf_norm(&rhs_c_r), inf_norm(&rhs_d_r),
inf_norm(&d_c_r), inf_norm(&d_d_r),
inf_norm(&sol_x_orig_vals),
inf_norm(&sol_y_c_vals), inf_norm(&sol_y_d_vals),
);
}
downcast_dense_mut(sol.sol_c).set_values(&sol_y_c_vals);
downcast_dense_mut(sol.sol_d).set_values(&sol_y_d_vals);
let mut sol_n_c_vals = vec![0.0; m_eq];
let mut sol_p_c_vals = vec![0.0; m_eq];
for i in 0..m_eq {
sol_n_c_vals[i] =
expand_sol_n_c_elem(rhs_n_c[i], sol_y_c_vals[i], sig_tilde_n_c_inv[i]);
sol_p_c_vals[i] =
expand_sol_p_c_elem(rhs_p_c[i], sol_y_c_vals[i], sig_tilde_p_c_inv[i]);
}
let mut sol_n_d_vals = vec![0.0; m_ineq];
let mut sol_p_d_vals = vec![0.0; m_ineq];
for i in 0..m_ineq {
sol_n_d_vals[i] =
expand_sol_n_c_elem(rhs_n_d[i], sol_y_d_vals[i], sig_tilde_n_d_inv[i]);
sol_p_d_vals[i] =
expand_sol_p_c_elem(rhs_p_d[i], sol_y_d_vals[i], sig_tilde_p_d_inv[i]);
}
downcast_dense_mut(sol_x_compound.comp_mut(1)).set_values(&sol_n_c_vals);
downcast_dense_mut(sol_x_compound.comp_mut(2)).set_values(&sol_p_c_vals);
downcast_dense_mut(sol_x_compound.comp_mut(3)).set_values(&sol_n_d_vals);
downcast_dense_mut(sol_x_compound.comp_mut(4)).set_values(&sol_p_d_vals);
ESymSolverStatus::Success
}
}
/// Copies the entries of `v` into a freshly allocated `Vec`.
///
/// # Panics
/// If `v` is not a `DenseVector`.
fn dense_values(v: &dyn Vector) -> Vec<Number> {
    let dense: &DenseVector = v
        .as_any()
        .downcast_ref()
        .expect("AugRestoSystemSolver: expected DenseVector argument");
    dense.expanded_values()
}
/// Views `v` mutably as its concrete `DenseVector`.
///
/// # Panics
/// If `v` is not a `DenseVector`.
fn downcast_dense_mut(v: &mut dyn Vector) -> &mut DenseVector {
    let any = v.as_any_mut();
    any.downcast_mut::<DenseVector>()
        .expect("AugRestoSystemSolver: expected DenseVector argument")
}
/// Densifies the `n_orig × n_orig` original-variable block of a low-rank resto Hessian.
///
/// Entry `(i, j)` of the lower triangle is `D[i]·[i == j] + Σ_k V_k[i]V_k[j] − Σ_k U_k[i]U_k[j]`.
/// Only the leading `n_orig` rows of each vector are used. The result is a full
/// lower-triangular `SymTMatrix` with 1-based triplets, stored row by row.
///
/// # Panics
/// - If `w` is not a `LowRankUpdateSymMatrix`.
/// - If it uses the `p_lowrank`/`reduced_diag` form.
fn materialize_orig_block(w: &dyn Matrix, n_orig: Index) -> SymTMatrix {
    let n = n_orig as usize;
    let lr = w
        .as_any()
        .downcast_ref::<LowRankUpdateSymMatrix>()
        .expect("AugRestoSystemSolver: resto W must be a SymTMatrix or LowRankUpdateSymMatrix");
    assert!(
        lr.p_lowrank().is_none() && !lr.reduced_diag(),
        "AugRestoSystemSolver: resto W has a p_lowrank/reduced_diag low-rank form \
         that the orig-block densification does not cover (pounce#102)"
    );

    // Row-major lower triangle: (1,1), (2,1), (2,2), (3,1), ...
    let (irows, jcols): (Vec<Index>, Vec<Index>) = (1..=n_orig)
        .flat_map(|i| (1..=i).map(move |j| (i, j)))
        .unzip();
    let mut sym = SymTMatrix::new(SymTMatrixSpace::new(n_orig, irows, jcols));

    let diag = match lr.get_diag() {
        Some(d) => orig_rows(d.as_ref(), n),
        None => vec![0.0; n],
    };
    let v_cols = multi_vector_orig_cols(lr.get_v(), n);
    let u_cols = multi_vector_orig_cols(lr.get_u(), n);

    // Visit the value slots in the same (row, col) order the triplets were laid out.
    let lower = (0..n).flat_map(|row| (0..=row).map(move |col| (row, col)));
    for (slot, (row, col)) in sym.values_mut().iter_mut().zip(lower) {
        let start = if row == col { diag[row] } else { 0.0 };
        let with_v = v_cols.iter().fold(start, |acc, c| acc + c[row] * c[col]);
        *slot = u_cols.iter().fold(with_v, |acc, c| acc - c[row] * c[col]);
    }
    sym
}
/// Returns the first `n` entries of the original-variable part of `v`.
///
/// A `CompoundVector` contributes its component 0, which must be dense. A plain
/// `DenseVector` is used directly.
///
/// # Panics
/// - If `v` is neither a `CompoundVector` nor a `DenseVector`.
/// - If the selected vector has fewer than `n` entries.
fn orig_rows(v: &dyn Vector, n: usize) -> Vec<Number> {
    let any = v.as_any();
    let dense: &DenseVector = if let Some(compound) = any.downcast_ref::<CompoundVector>() {
        compound
            .comp(0)
            .as_any()
            .downcast_ref::<DenseVector>()
            .expect("AugRestoSystemSolver: resto W orig block must be a DenseVector")
    } else if let Some(plain) = any.downcast_ref::<DenseVector>() {
        plain
    } else {
        panic!("AugRestoSystemSolver: resto W component must be Dense or Compound");
    };
    dense.expanded_values()[..n].to_vec()
}
/// Extracts the original-variable rows (first `n`) of every column of `m`.
///
/// Returns an empty list when `m` is absent.
fn multi_vector_orig_cols(m: Option<&Rc<MultiVectorMatrix>>, n: usize) -> Vec<Vec<Number>> {
    m.map_or_else(Vec::new, |mv| {
        let n_cols = mv.space().n_cols();
        (0..n_cols)
            .map(|k| orig_rows(mv.get_vector(k).as_ref(), n))
            .collect::<Vec<_>>()
    })
}
/// Returns `Σ̃⁻¹ = 1 / (σ + δ_x)` for one slack entry.
///
/// When `delta_x == 0.0` the shift is skipped:
/// - `Some(s)` yields `1 / s` (so `-0.0` gives `-inf`).
/// - An absent `sigma` yields `None`.
///
/// With a nonzero shift, an absent `sigma` contributes `1 / δ_x`.
pub fn sigma_tilde_inv_elem(sigma: Option<f64>, delta_x: f64) -> Option<f64> {
    if delta_x == 0.0 {
        return sigma.map(|s| 1.0 / s);
    }
    let shifted = match sigma {
        Some(s) => s + delta_x,
        None => delta_x,
    };
    Some(1.0 / shifted)
}
/// Reduced diagonal entry `D − Σ̃_n⁻¹ − Σ̃_p⁻¹` (i.e. `−Ω + D`).
///
/// Absent terms count as zero. Returns `None` only when all three inputs are absent.
pub fn neg_omega_plus_d_elem(
    sigma_tilde_n_inv: Option<f64>,
    sigma_tilde_p_inv: Option<f64>,
    d_c: Option<f64>,
) -> Option<f64> {
    match (sigma_tilde_n_inv, sigma_tilde_p_inv, d_c) {
        (None, None, None) => None,
        (n, p, d) => Some(-n.unwrap_or(0.0) - p.unwrap_or(0.0) + d.unwrap_or(0.0)),
    }
}
/// Reduced right-hand side entry `rhs − Σ̃_n⁻¹ rhs_n + Σ̃_p⁻¹ rhs_p`.
///
/// A slack term whose `Σ̃⁻¹` is absent contributes nothing.
pub fn rhs_cr_elem(
    rhs_c: f64,
    sigma_tilde_n_inv: Option<f64>,
    rhs_n_c: f64,
    sigma_tilde_p_inv: Option<f64>,
    rhs_p_c: f64,
) -> f64 {
    let scaled = |inv: Option<f64>, r: f64| inv.map_or(0.0, |s| s * r);
    rhs_c - scaled(sigma_tilde_n_inv, rhs_n_c) + scaled(sigma_tilde_p_inv, rhs_p_c)
}
/// Recovers the n-slack step `Σ̃_n⁻¹ (rhs_n − Δy)`.
///
/// Returns `0.0` when `Σ̃_n⁻¹` is absent.
pub fn expand_sol_n_c_elem(rhs_n_c: f64, sol_y_c: f64, sigma_tilde_n_inv: Option<f64>) -> f64 {
    match sigma_tilde_n_inv {
        Some(inv) => inv * (rhs_n_c - sol_y_c),
        None => 0.0,
    }
}
/// Recovers the p-slack step `Σ̃_p⁻¹ (rhs_p + Δy)`.
///
/// Returns `0.0` when `Σ̃_p⁻¹` is absent.
pub fn expand_sol_p_c_elem(rhs_p_c: f64, sol_y_c: f64, sigma_tilde_p_inv: Option<f64>) -> f64 {
    match sigma_tilde_p_inv {
        Some(inv) => inv * (rhs_p_c + sol_y_c),
        None => 0.0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sigma_tilde_inv_combines_sigma_and_delta() {
        assert_eq!(sigma_tilde_inv_elem(Some(0.25), 0.75), Some(1.0));
    }

    #[test]
    fn sigma_tilde_inv_pure_sigma_path() {
        assert_eq!(sigma_tilde_inv_elem(Some(0.5), 0.0), Some(2.0));
    }

    #[test]
    fn sigma_tilde_inv_pure_delta_path() {
        assert_eq!(sigma_tilde_inv_elem(None, 0.5), Some(2.0));
    }

    #[test]
    fn sigma_tilde_inv_skips_when_both_absent() {
        assert_eq!(sigma_tilde_inv_elem(None, 0.0), None);
    }

    #[test]
    fn neg_omega_returns_none_when_all_absent() {
        assert_eq!(neg_omega_plus_d_elem(None, None, None), None);
    }

    #[test]
    fn neg_omega_sums_negated_inverses() {
        let r = neg_omega_plus_d_elem(Some(2.0), Some(3.0), Some(0.5));
        assert_eq!(r, Some(-2.0 - 3.0 + 0.5));
    }

    #[test]
    fn neg_omega_propagates_d_alone() {
        assert_eq!(neg_omega_plus_d_elem(None, None, Some(0.7)), Some(0.7));
    }

    #[test]
    fn rhs_cr_combines_three_terms() {
        let r = rhs_cr_elem(1.0, Some(0.5), 2.0, Some(0.25), 4.0);
        assert_eq!(r, 1.0);
    }

    #[test]
    fn rhs_cr_drops_terms_when_sigma_absent() {
        let r = rhs_cr_elem(2.0, None, 3.0, Some(0.5), 6.0);
        assert_eq!(r, 2.0 + 0.5 * 6.0);
        let r = rhs_cr_elem(2.0, None, 3.0, None, 6.0);
        assert_eq!(r, 2.0);
    }

    #[test]
    fn expand_sol_n_c_zero_when_sigma_absent() {
        assert_eq!(expand_sol_n_c_elem(1.0, 2.0, None), 0.0);
    }

    #[test]
    fn expand_sol_n_c_signs() {
        assert_eq!(expand_sol_n_c_elem(5.0, 1.0, Some(0.5)), 2.0);
        assert_eq!(expand_sol_n_c_elem(1.0, 5.0, Some(0.5)), -2.0);
    }

    #[test]
    fn expand_sol_p_c_signs() {
        assert_eq!(expand_sol_p_c_elem(5.0, 1.0, Some(0.5)), 3.0);
        assert_eq!(expand_sol_p_c_elem(1.0, 5.0, Some(0.5)), 3.0);
    }

    #[test]
    fn expand_sol_p_c_zero_when_sigma_absent() {
        assert_eq!(expand_sol_p_c_elem(1.0, 2.0, None), 0.0);
    }

    /// End-to-end check that the helpers, composed as in `solve`, reproduce the
    /// solution of the un-reduced system.
    ///
    /// Scalar resto system, δ = 0, unknowns `(dx, dn, dp, dy)`:
    ///
    /// ```text
    /// [ h    0    0    a  ] [dx]   [r_x]
    /// [ 0   s_n   0    1  ] [dn] = [r_n]
    /// [ 0    0   s_p  -1  ] [dp]   [r_p]
    /// [ a    1   -1   d_c ] [dy]   [r_c]
    /// ```
    ///
    /// This pins the sign of the reduced diagonal: `D_R = d_c − 1/s_n − 1/s_p`.
    #[test]
    fn elimination_helpers_reproduce_full_system_solution() {
        let (h, a, s_n, s_p, d_c) = (2.0_f64, 1.0_f64, 2.0_f64, 4.0_f64, 0.5_f64);
        let (r_x, r_n, r_p, r_c) = (1.0_f64, 2.0_f64, 4.0_f64, 3.0_f64);

        let inv_n = sigma_tilde_inv_elem(Some(s_n), 0.0);
        let inv_p = sigma_tilde_inv_elem(Some(s_p), 0.0);
        let d_c_r = neg_omega_plus_d_elem(inv_n, inv_p, Some(d_c)).unwrap();
        let r_c_r = rhs_cr_elem(r_c, inv_n, r_n, inv_p, r_p);

        // Reduced 2x2 system [h a; a d_c_r] [dx; dy] = [r_x; r_c_r], by Cramer's rule.
        let det = h * d_c_r - a * a;
        let dx = (r_x * d_c_r - a * r_c_r) / det;
        let dy = (h * r_c_r - a * r_x) / det;
        let dn = expand_sol_n_c_elem(r_n, dy, inv_n);
        let dp = expand_sol_p_c_elem(r_p, dy, inv_p);

        let close = |lhs: f64, rhs: f64| (lhs - rhs).abs() < 1e-12;
        assert!(close(h * dx + a * dy, r_x));
        assert!(close(s_n * dn + dy, r_n));
        assert!(close(s_p * dp - dy, r_p));
        assert!(close(a * dx + dn - dp + d_c * dy, r_c));
    }
}