use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
use pounce_common::tagged::Tag;
use pounce_common::timing::TimingStatistics;
use pounce_common::types::{Index, Number};
use pounce_linalg::dense_gen_matrix::{DenseGenMatrix, DenseGenMatrixSpace};
use pounce_linalg::dense_sym_matrix::DenseSymMatrixSpace;
use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
use pounce_linalg::diag_matrix::DiagMatrix;
use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrix;
use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
use pounce_linalg::{Matrix, SymMatrix, Vector};
use pounce_linsol::ESymSolverStatus;
use std::rc::Rc;
/// Augmented-system solver for a Hessian given as a low-rank update
/// `W = B0 + V V^T - U U^T` (a [`LowRankUpdateSymMatrix`]).
///
/// The system is solved with only the diagonal part of `W` by the wrapped
/// solver, and the `V`/`U` terms are accounted for with
/// Sherman–Morrison–Woodbury corrections. Any other `W` is forwarded to the
/// wrapped solver unchanged.
pub struct LowRankAugSystemSolver {
    /// Solver used for every system built with the diagonal part of `W`.
    inner: Box<dyn AugSystemSolver>,
    /// `true` until the Woodbury data has been built successfully once.
    first_call: bool,
    /// Negative-eigenvalue count last reported by the inner solver, incremented
    /// by one when a Woodbury capacitance matrix fails its Cholesky factorization.
    num_neg_evals: Index,
    /// Tags and scalars of the coefficients the current Woodbury data belongs to.
    cache: AugSysCache,
    /// Precomputed diagonal matrix and Woodbury quantities.
    factor: Factorization,
}
/// Snapshot of the augmented-system coefficients (object tags and scalar
/// factors) for which the low-rank factorization was last computed.
///
/// An optional coefficient that was absent is recorded as `Tag::NONE`.
#[derive(Debug, Clone)]
pub struct AugSysCache {
    /// Tag of `W`.
    pub w_tag: Tag,
    /// Scalar factor applied to `W`.
    pub w_factor: Number,
    /// Tag of `D_x`.
    pub d_x_tag: Tag,
    /// The `delta_x` coefficient.
    pub delta_x: Number,
    /// Tag of `D_s`.
    pub d_s_tag: Tag,
    /// The `delta_s` coefficient.
    pub delta_s: Number,
    /// Tag of the Jacobian `J_c`.
    pub j_c_tag: Tag,
    /// Tag of `D_c`.
    pub d_c_tag: Tag,
    /// The `delta_c` coefficient.
    pub delta_c: Number,
    /// Tag of the Jacobian `J_d`.
    pub j_d_tag: Tag,
    /// Tag of `D_d`.
    pub d_d_tag: Tag,
    /// The `delta_d` coefficient.
    pub delta_d: Number,
}
impl Default for AugSysCache {
fn default() -> Self {
Self {
w_tag: Tag::NONE,
w_factor: 0.0,
d_x_tag: Tag::NONE,
delta_x: 0.0,
d_s_tag: Tag::NONE,
delta_s: 0.0,
j_c_tag: Tag::NONE,
d_c_tag: Tag::NONE,
delta_c: 0.0,
j_d_tag: Tag::NONE,
d_d_tag: Tag::NONE,
delta_d: 0.0,
}
}
}
/// Precomputed data for the Sherman–Morrison–Woodbury solve.
///
/// With `K0` the augmented system that uses `wdiag` in place of `W`:
/// the `vtilde1_*` blocks hold `K0^{-1} [V; 0; 0; 0]` and `j1` the Cholesky
/// factor of `I + V^T Vtilde1_x`; the `utilde2_*` blocks hold the solves of
/// `[U; 0; 0; 0]` against `K0` plus the `V V^T` term, and `j2` the Cholesky
/// factor of `I - U^T Utilde2_x`. A `None` group means the corresponding
/// low-rank term is absent.
#[derive(Default)]
struct Factorization {
    /// Diagonal matrix passed to the inner solver instead of `W`.
    wdiag: Option<Box<DiagMatrix>>,
    /// Cholesky factor of `I + V^T Vtilde1_x`.
    j1: Option<DenseGenMatrix>,
    /// Cholesky factor of `I - U^T Utilde2_x`.
    j2: Option<DenseGenMatrix>,
    /// x-block of `Vtilde1`.
    vtilde1_x: Option<MultiVectorMatrix>,
    /// s-block of `Vtilde1`.
    vtilde1_s: Option<MultiVectorMatrix>,
    /// c-block of `Vtilde1`.
    vtilde1_c: Option<MultiVectorMatrix>,
    /// d-block of `Vtilde1`.
    vtilde1_d: Option<MultiVectorMatrix>,
    /// x-block of `Utilde2`.
    utilde2_x: Option<MultiVectorMatrix>,
    /// s-block of `Utilde2`.
    utilde2_s: Option<MultiVectorMatrix>,
    /// c-block of `Utilde2`.
    utilde2_c: Option<MultiVectorMatrix>,
    /// d-block of `Utilde2`.
    utilde2_d: Option<MultiVectorMatrix>,
}
impl LowRankAugSystemSolver {
pub fn new(inner: Box<dyn AugSystemSolver>) -> Self {
Self {
inner,
first_call: true,
num_neg_evals: 0,
cache: AugSysCache::default(),
factor: Factorization::default(),
}
}
pub fn augmented_system_requires_change(&self, coeffs: &AugSysCoeffs<'_>) -> bool {
let cache = &self.cache;
let zero_tag: Tag = Tag::NONE;
let w_changed = match coeffs.w {
Some(w) => w.as_tagged().get_tag() != cache.w_tag,
None => cache.w_tag != zero_tag,
};
if w_changed || coeffs.w_factor != cache.w_factor {
return true;
}
let dx_changed = match coeffs.d_x {
Some(d) => d.as_tagged().get_tag() != cache.d_x_tag,
None => cache.d_x_tag != zero_tag,
};
if dx_changed || coeffs.delta_x != cache.delta_x {
return true;
}
let ds_changed = match coeffs.d_s {
Some(d) => d.as_tagged().get_tag() != cache.d_s_tag,
None => cache.d_s_tag != zero_tag,
};
if ds_changed || coeffs.delta_s != cache.delta_s {
return true;
}
if coeffs.j_c.as_tagged().get_tag() != cache.j_c_tag {
return true;
}
let dc_changed = match coeffs.d_c {
Some(d) => d.as_tagged().get_tag() != cache.d_c_tag,
None => cache.d_c_tag != zero_tag,
};
if dc_changed || coeffs.delta_c != cache.delta_c {
return true;
}
if coeffs.j_d.as_tagged().get_tag() != cache.j_d_tag {
return true;
}
let dd_changed = match coeffs.d_d {
Some(d) => d.as_tagged().get_tag() != cache.d_d_tag,
None => cache.d_d_tag != zero_tag,
};
if dd_changed || coeffs.delta_d != cache.delta_d {
return true;
}
false
}
fn store_cache(&mut self, coeffs: &AugSysCoeffs<'_>) {
let zero_tag = Tag::NONE;
self.cache.w_tag = coeffs
.w
.map(|w| w.as_tagged().get_tag())
.unwrap_or(zero_tag);
self.cache.w_factor = coeffs.w_factor;
self.cache.d_x_tag = coeffs
.d_x
.map(|d| d.as_tagged().get_tag())
.unwrap_or(zero_tag);
self.cache.delta_x = coeffs.delta_x;
self.cache.d_s_tag = coeffs
.d_s
.map(|d| d.as_tagged().get_tag())
.unwrap_or(zero_tag);
self.cache.delta_s = coeffs.delta_s;
self.cache.j_c_tag = coeffs.j_c.as_tagged().get_tag();
self.cache.d_c_tag = coeffs
.d_c
.map(|d| d.as_tagged().get_tag())
.unwrap_or(zero_tag);
self.cache.delta_c = coeffs.delta_c;
self.cache.j_d_tag = coeffs.j_d.as_tagged().get_tag();
self.cache.d_d_tag = coeffs
.d_d
.map(|d| d.as_tagged().get_tag())
.unwrap_or(zero_tag);
self.cache.delta_d = coeffs.delta_d;
}
pub fn first_call(&self) -> bool {
self.first_call
}
pub fn cache(&self) -> &AugSysCache {
&self.cache
}
fn update_factorization(
&mut self,
lr_w: &LowRankUpdateSymMatrix,
coeffs: &AugSysCoeffs<'_>,
proto: &AugSysRhs<'_>,
check_neg_evals: bool,
num_neg_evals: Index,
) -> ESymSolverStatus {
let proto_x = downcast_dense(proto.rhs_x);
let proto_s = downcast_dense(proto.rhs_s);
let proto_c = downcast_dense(proto.rhs_c);
let proto_d = downcast_dense(proto.rhs_d);
let space_x = Rc::clone(proto_x.space());
let space_s = Rc::clone(proto_s.space());
let space_c = Rc::clone(proto_c.space());
let space_d = Rc::clone(proto_d.space());
let b0_dense: DenseVector = if coeffs.w_factor == 1.0 {
match lr_w.get_diag() {
Some(d) => clone_dense(downcast_dense(d.as_ref())),
None => zero_x_for(&space_x, lr_w),
}
} else {
zero_x_for(&space_x, lr_w)
};
let wdiag_diag: Rc<dyn Vector> = match (lr_w.p_lowrank(), lr_w.reduced_diag()) {
(Some(p_lm), true) => {
let mut fullx = space_x.make_new_dense();
p_lm.mult_vector(1.0, &b0_dense, 0.0, &mut fullx);
Rc::new(fullx) as Rc<dyn Vector>
}
_ => Rc::new(clone_dense(&b0_dense)) as Rc<dyn Vector>,
};
let mut wdiag = Box::new(DiagMatrix::new(space_x.dim()));
wdiag.set_diag(wdiag_diag);
self.factor.wdiag = Some(wdiag);
if coeffs.w_factor == 1.0 && lr_w.get_v().is_some() {
let v = Rc::clone(lr_w.get_v().unwrap());
let n_v = v.n_cols();
let v_x_space = MultiVectorMatrixSpace::new(n_v, Rc::clone(&space_x));
let mut v_x = v_x_space.make_new_multi_vector();
for k in 0..n_v {
let vk = Rc::clone(v.get_vector(k));
let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
Some(p_lm) => {
let mut fullx = space_x.make_new_dense();
p_lm.mult_vector(1.0, vk.as_ref(), 0.0, &mut fullx);
Rc::new(fullx) as Rc<dyn Vector>
}
None => vk,
};
v_x.set_vector(k, rhs_x_k);
}
let (vt_x, vt_s, vt_c, vt_d) = self.multi_solve_block(
&v_x,
coeffs,
&space_x,
&space_s,
&space_c,
&space_d,
check_neg_evals,
num_neg_evals,
);
let vt_x = match vt_x {
Ok(x) => x,
Err(status) => return status,
};
let m1_space = DenseSymMatrixSpace::new(n_v);
let mut m1 = m1_space.make_new_dense_sym();
m1.fill_identity(1.0);
m1.high_rank_update_transpose(1.0, &vt_x, &v_x, 1.0);
let j1_space = DenseGenMatrixSpace::new(n_v, n_v);
let mut j1 = j1_space.make_new_dense_gen();
if !j1.compute_cholesky_factor(&m1) {
self.num_neg_evals += 1;
return ESymSolverStatus::WrongInertia;
}
self.factor.vtilde1_x = Some(vt_x);
self.factor.vtilde1_s = Some(vt_s);
self.factor.vtilde1_c = Some(vt_c);
self.factor.vtilde1_d = Some(vt_d);
self.factor.j1 = Some(j1);
} else {
self.factor.vtilde1_x = None;
self.factor.vtilde1_s = None;
self.factor.vtilde1_c = None;
self.factor.vtilde1_d = None;
self.factor.j1 = None;
}
if coeffs.w_factor == 1.0 && lr_w.get_u().is_some() {
let u = Rc::clone(lr_w.get_u().unwrap());
let n_u = u.n_cols();
let u_x_space = MultiVectorMatrixSpace::new(n_u, Rc::clone(&space_x));
let mut u_x = u_x_space.make_new_multi_vector();
for k in 0..n_u {
let uk = Rc::clone(u.get_vector(k));
let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
Some(p_lm) => {
let mut fullx = space_x.make_new_dense();
p_lm.mult_vector(1.0, uk.as_ref(), 0.0, &mut fullx);
Rc::new(fullx) as Rc<dyn Vector>
}
None => uk,
};
u_x.set_vector(k, rhs_x_k);
}
let (mut ut_x, mut ut_s, mut ut_c, mut ut_d) = match self.multi_solve_block(
&u_x,
coeffs,
&space_x,
&space_s,
&space_c,
&space_d,
check_neg_evals,
num_neg_evals,
) {
(Ok(x), s, c, d) => (x, s, c, d),
(Err(status), _, _, _) => return status,
};
if self.factor.vtilde1_x.is_some() {
let vt1_x = self.factor.vtilde1_x.as_ref().unwrap();
let vt1_s = self.factor.vtilde1_s.as_ref().unwrap();
let vt1_c = self.factor.vtilde1_c.as_ref().unwrap();
let vt1_d = self.factor.vtilde1_d.as_ref().unwrap();
let n_v = vt1_x.n_cols();
let c_space = DenseGenMatrixSpace::new(n_v, n_u);
let mut c_mat = c_space.make_new_dense_gen();
{
let cv = c_mat.values_mut();
for j in 0..n_u as usize {
let uj = u_x.get_vector(j as Index).as_ref();
for i in 0..n_v as usize {
let vi = vt1_x.get_vector(i as Index).as_ref();
cv[i + j * n_v as usize] = vi.dot(uj);
}
}
}
self.factor
.j1
.as_ref()
.unwrap()
.cholesky_solve_matrix(&mut c_mat);
ut_x.add_right_mult_matrix(-1.0, vt1_x, &c_mat, 1.0);
ut_s.add_right_mult_matrix(-1.0, vt1_s, &c_mat, 1.0);
ut_c.add_right_mult_matrix(-1.0, vt1_c, &c_mat, 1.0);
ut_d.add_right_mult_matrix(-1.0, vt1_d, &c_mat, 1.0);
}
let m2_space = DenseSymMatrixSpace::new(n_u);
let mut m2 = m2_space.make_new_dense_sym();
m2.fill_identity(1.0);
m2.high_rank_update_transpose(-1.0, &ut_x, &u_x, 1.0);
let j2_space = DenseGenMatrixSpace::new(n_u, n_u);
let mut j2 = j2_space.make_new_dense_gen();
if !j2.compute_cholesky_factor(&m2) {
self.num_neg_evals += 1;
return ESymSolverStatus::WrongInertia;
}
self.factor.utilde2_x = Some(ut_x);
self.factor.utilde2_s = Some(ut_s);
self.factor.utilde2_c = Some(ut_c);
self.factor.utilde2_d = Some(ut_d);
self.factor.j2 = Some(j2);
} else {
self.factor.utilde2_x = None;
self.factor.utilde2_s = None;
self.factor.utilde2_c = None;
self.factor.utilde2_d = None;
self.factor.j2 = None;
}
ESymSolverStatus::Success
}
#[allow(clippy::too_many_arguments)]
fn multi_solve_block(
&mut self,
v_x: &MultiVectorMatrix,
coeffs: &AugSysCoeffs<'_>,
space_x: &Rc<DenseVectorSpace>,
space_s: &Rc<DenseVectorSpace>,
space_c: &Rc<DenseVectorSpace>,
space_d: &Rc<DenseVectorSpace>,
check_neg_evals: bool,
num_neg_evals: Index,
) -> (
Result<MultiVectorMatrix, ESymSolverStatus>,
MultiVectorMatrix,
MultiVectorMatrix,
MultiVectorMatrix,
) {
let n_cols = v_x.n_cols();
let mut out_x =
MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_x)).make_new_multi_vector();
let mut out_s =
MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_s)).make_new_multi_vector();
let mut out_c =
MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_c)).make_new_multi_vector();
let mut out_d =
MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_d)).make_new_multi_vector();
out_x.fill_with_new_vectors();
out_s.fill_with_new_vectors();
out_c.fill_with_new_vectors();
out_d.fill_with_new_vectors();
let mut rhs_s = space_s.make_new_dense();
rhs_s.set(0.0);
let mut rhs_c = space_c.make_new_dense();
rhs_c.set(0.0);
let mut rhs_d = space_d.make_new_dense();
rhs_d.set(0.0);
for k in 0..n_cols {
let rhs_x_dyn: &dyn Vector = v_x.get_vector(k).as_ref();
let inner_rhs = AugSysRhs {
rhs_x: rhs_x_dyn,
rhs_s: rhs_s.as_dyn_vector(),
rhs_c: rhs_c.as_dyn_vector(),
rhs_d: rhs_d.as_dyn_vector(),
};
let mut sol_x = space_x.make_new_dense();
let mut sol_s = space_s.make_new_dense();
let mut sol_c = space_c.make_new_dense();
let mut sol_d = space_d.make_new_dense();
sol_x.set(0.0);
sol_s.set(0.0);
sol_c.set(0.0);
sol_d.set(0.0);
let inner_coeffs = inner_coeffs(&self.factor, coeffs);
let status = {
let mut sol = AugSysSol {
sol_x: &mut sol_x,
sol_s: &mut sol_s,
sol_c: &mut sol_c,
sol_d: &mut sol_d,
};
self.inner.solve(
&inner_coeffs,
&inner_rhs,
&mut sol,
check_neg_evals,
num_neg_evals,
)
};
if self.inner.provides_inertia() {
self.num_neg_evals = self.inner.number_of_neg_evals();
}
if status != ESymSolverStatus::Success {
return (Err(status), out_s, out_c, out_d);
}
out_x.set_vector(k, Rc::new(sol_x) as Rc<dyn Vector>);
out_s.set_vector(k, Rc::new(sol_s) as Rc<dyn Vector>);
out_c.set_vector(k, Rc::new(sol_c) as Rc<dyn Vector>);
out_d.set_vector(k, Rc::new(sol_d) as Rc<dyn Vector>);
}
(Ok(out_x), out_s, out_c, out_d)
}
}
/// Builds the coefficients handed to the inner solver.
///
/// They are identical to `coeffs`, except that `W` is replaced by the stored
/// diagonal matrix with a factor of `1.0`.
///
/// # Panics
/// Panics if the diagonal matrix has not been built yet.
fn inner_coeffs<'b>(factor: &'b Factorization, coeffs: &AugSysCoeffs<'b>) -> AugSysCoeffs<'b> {
    let diag_w: &'b DiagMatrix = factor.wdiag.as_deref().expect("Wdiag unset");
    AugSysCoeffs {
        w: Some(diag_w as &dyn SymMatrix),
        w_factor: 1.0,
        ..*coeffs
    }
}
/// Views `v` as a [`DenseVector`].
///
/// # Panics
/// Panics if `v` is any other vector type.
fn downcast_dense(v: &dyn Vector) -> &DenseVector {
    let any = v.as_any();
    let dense = any.downcast_ref::<DenseVector>();
    dense.expect("LowRankAugSystemSolver currently requires DenseVector RHS/solutions")
}
/// Returns a new vector in the same space as `src`, holding a copy of its values.
fn clone_dense(src: &DenseVector) -> DenseVector {
    let mut copy = src.space().make_new_dense();
    let values = src.expanded_values();
    copy.set_values(&values);
    copy
}
/// Returns a zero vector in `space_x`. The low-rank matrix argument is
/// currently unused.
fn zero_x_for(space_x: &Rc<DenseVectorSpace>, _lr_w: &LowRankUpdateSymMatrix) -> DenseVector {
    let mut zeros = space_x.make_new_dense();
    zeros.set(0.0);
    zeros
}
impl AugSystemSolver for LowRankAugSystemSolver {
    fn provides_inertia(&self) -> bool {
        self.inner.provides_inertia()
    }

    /// Number of negative eigenvalues of the last system solved through this wrapper.
    fn number_of_neg_evals(&self) -> Index {
        // `num_neg_evals` is refreshed from the inner solver after every inner
        // solve (when it provides inertia). It is also incremented when a
        // Woodbury capacitance matrix fails its Cholesky factorization.
        // Returning the inner solver's count directly would drop that increment.
        self.num_neg_evals
    }

    fn increase_quality(&mut self) -> bool {
        self.inner.increase_quality()
    }

    fn last_solve_status(&self) -> ESymSolverStatus {
        self.inner.last_solve_status()
    }

    fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
        self.inner.set_timing_stats(timing);
    }

    /// Solves the augmented system.
    ///
    /// If `W` is not a [`LowRankUpdateSymMatrix`], the call is forwarded to the
    /// inner solver unchanged. Otherwise the system is solved with the
    /// diagonal part of `W` and then corrected with Sherman–Morrison–Woodbury
    /// terms. The Woodbury data is rebuilt on the first call, whenever the
    /// coefficients change, and after a rebuild that failed.
    ///
    /// Inertia checking is disabled when the inner solver cannot report inertia.
    fn solve(
        &mut self,
        coeffs: &AugSysCoeffs<'_>,
        rhs: &AugSysRhs<'_>,
        sol: &mut AugSysSol<'_>,
        check_neg_evals: bool,
        num_neg_evals: Index,
    ) -> ESymSolverStatus {
        let check_neg_evals = check_neg_evals && self.inner.provides_inertia();
        let lr_w_opt = coeffs
            .w
            .and_then(|w| w.as_any().downcast_ref::<LowRankUpdateSymMatrix>());
        let Some(lr_w) = lr_w_opt else {
            let status = self
                .inner
                .solve(coeffs, rhs, sol, check_neg_evals, num_neg_evals);
            if self.inner.provides_inertia() {
                self.num_neg_evals = self.inner.number_of_neg_evals();
            }
            return status;
        };
        if self.first_call || self.augmented_system_requires_change(coeffs) {
            let status =
                self.update_factorization(lr_w, coeffs, rhs, check_neg_evals, num_neg_evals);
            if status != ESymSolverStatus::Success {
                // The factorization may be partly overwritten (the diagonal
                // is replaced before the column solves) while the cache still
                // describes the previous coefficients. Force a rebuild so a
                // later call with those coefficients cannot reuse mixed data.
                self.first_call = true;
                return status;
            }
            self.store_cache(coeffs);
            self.first_call = false;
        }
        let ic = inner_coeffs(&self.factor, coeffs);
        let status = self
            .inner
            .solve(&ic, rhs, sol, check_neg_evals, num_neg_evals);
        if self.inner.provides_inertia() {
            self.num_neg_evals = self.inner.number_of_neg_evals();
        }
        if status != ESymSolverStatus::Success {
            return status;
        }
        // Both corrections use the original right-hand side, so their order
        // does not matter: `+` accounts for -U U^T, `-` for +V V^T.
        if self.factor.utilde2_x.is_some() {
            self.apply_smw(1.0, true, rhs, sol);
        }
        if self.factor.vtilde1_x.is_some() {
            self.apply_smw(-1.0, false, rhs, sol);
        }
        ESymSolverStatus::Success
    }
}
impl LowRankAugSystemSolver {
    /// Applies one Sherman–Morrison–Woodbury correction to `sol`.
    ///
    /// Computes `sol += sign * T * M^{-1} (T^T rhs)`. Here `T` is the compound
    /// multi-vector `Utilde2` (when `use_u`) or `Vtilde1`, and `M` is the matrix
    /// whose stored Cholesky factor is `j2` or `j1` respectively.
    ///
    /// # Panics
    /// Panics if the selected Woodbury data has not been computed.
    fn apply_smw(&self, sign: Number, use_u: bool, rhs: &AugSysRhs<'_>, sol: &mut AugSysSol<'_>) {
        let f = &self.factor;
        let (tx, ts, tc, td, chol) = if use_u {
            (&f.utilde2_x, &f.utilde2_s, &f.utilde2_c, &f.utilde2_d, &f.j2)
        } else {
            (&f.vtilde1_x, &f.vtilde1_s, &f.vtilde1_c, &f.vtilde1_d, &f.j1)
        };
        let tx = tx.as_ref().unwrap();
        let ts = ts.as_ref().unwrap();
        let tc = tc.as_ref().unwrap();
        let td = td.as_ref().unwrap();
        let chol = chol.as_ref().unwrap();

        // Projections T^T rhs over all four blocks of the compound vector.
        let n = tx.n_cols();
        let proj: Vec<Number> = (0..n)
            .map(|k| {
                tx.get_vector(k).dot(rhs.rhs_x)
                    + ts.get_vector(k).dot(rhs.rhs_s)
                    + tc.get_vector(k).dot(rhs.rhs_c)
                    + td.get_vector(k).dot(rhs.rhs_d)
            })
            .collect();
        let small_space = DenseVectorSpace::new(n);
        let mut weights = small_space.make_new_dense();
        weights.set_values(&proj);
        chol.cholesky_solve_vector(&mut weights);

        tx.mult_vector(sign, &weights, 1.0, sol.sol_x);
        ts.mult_vector(sign, &weights, 1.0, sol.sol_s);
        tc.mult_vector(sign, &weights, 1.0, sol.sol_c);
        td.mult_vector(sign, &weights, 1.0, sol.sol_d);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_linalg::dense_vector::DenseVectorSpace;
    use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
    use std::cell::Cell;

    /// Inner solver for the one-dimensional test problems.
    ///
    /// It ignores the constraint blocks and sets
    /// `sol_x[i] = rhs_x[i] / (Wdiag[i] + D_x[i] + delta_x)`.
    struct DiagInner {
        /// Number of `solve` calls. It is shared with the test so the test can
        /// count how many inner solves the wrapper performs.
        calls: Rc<Cell<usize>>,
    }

    impl AugSystemSolver for DiagInner {
        fn provides_inertia(&self) -> bool {
            true
        }
        fn number_of_neg_evals(&self) -> Index {
            0
        }
        fn increase_quality(&mut self) -> bool {
            true
        }
        fn last_solve_status(&self) -> ESymSolverStatus {
            ESymSolverStatus::Success
        }
        fn solve(
            &mut self,
            coeffs: &AugSysCoeffs<'_>,
            rhs: &AugSysRhs<'_>,
            sol: &mut AugSysSol<'_>,
            _check_neg_evals: bool,
            _num_neg_evals: Index,
        ) -> ESymSolverStatus {
            self.calls.set(self.calls.get() + 1);
            let wdiag = coeffs
                .w
                .expect("DiagInner requires W")
                .as_any()
                .downcast_ref::<DiagMatrix>()
                .expect("DiagInner requires W to be a DiagMatrix");
            let diag_rc = wdiag.get_diag().expect("Wdiag has no diag set").clone();
            let diag = downcast_dense(diag_rc.as_ref()).expanded_values();
            let rhs_x = downcast_dense(rhs.rhs_x).expanded_values();
            let dx_vals: Option<Vec<Number>> =
                coeffs.d_x.map(|d| downcast_dense(d).expanded_values());
            let out: Vec<Number> = rhs_x
                .iter()
                .enumerate()
                .map(|(i, r)| {
                    let dx_i = dx_vals.as_ref().map_or(0.0, |v| v[i]);
                    r / (diag[i] + dx_i + coeffs.delta_x)
                })
                .collect();
            let sol_x_dv = sol
                .sol_x
                .as_any_mut()
                .downcast_mut::<DenseVector>()
                .unwrap();
            sol_x_dv.set_values(&out);
            ESymSolverStatus::Success
        }
    }

    fn dvec(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> DenseVector {
        let mut v = space.make_new_dense();
        v.set_values(vals);
        v
    }

    fn dvec_rc(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> Rc<DenseVector> {
        Rc::new(dvec(space, vals))
    }

    /// Single-column multi-vector holding the scalar `val`.
    fn column(space: &Rc<DenseVectorSpace>, val: Number) -> MultiVectorMatrix {
        let mut m = MultiVectorMatrixSpace::new(1, Rc::clone(space)).make_new_multi_vector();
        m.set_vector(0, dvec_rc(space, &[val]) as Rc<dyn Vector>);
        m
    }

    /// A wrapper around a fresh [`DiagInner`], plus its shared call counter.
    fn counting_solver() -> (LowRankAugSystemSolver, Rc<Cell<usize>>) {
        let calls = Rc::new(Cell::new(0));
        let inner = DiagInner {
            calls: Rc::clone(&calls),
        };
        (LowRankAugSystemSolver::new(Box::new(inner)), calls)
    }

    fn assert_close(got: Number, want: Number) {
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    /// One-dimensional test problem: `W = b0 + v v^T - u u^T` on a 1-D
    /// x-space, with empty `s`, `c` and `d` blocks.
    struct Scalar {
        space_x: Rc<DenseVectorSpace>,
        space_zero: Rc<DenseVectorSpace>,
        w: LowRankUpdateSymMatrix,
        j_c: DenseGenMatrix,
        j_d: DenseGenMatrix,
    }

    impl Scalar {
        fn new(b0: Number, v: Option<Number>, u: Option<Number>) -> Self {
            let space_x = DenseVectorSpace::new(1);
            let space_zero = DenseVectorSpace::new(0);
            let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
            let mut w = lr_space.make_new_low_rank();
            w.set_diag(dvec_rc(&space_x, &[b0]));
            if let Some(v) = v {
                w.set_v(Rc::new(column(&space_x, v)));
            }
            if let Some(u) = u {
                w.set_u(Rc::new(column(&space_x, u)));
            }
            let j_c_space = DenseGenMatrixSpace::new(0, 1);
            let j_d_space = DenseGenMatrixSpace::new(0, 1);
            Scalar {
                j_c: j_c_space.make_new_dense_gen(),
                j_d: j_d_space.make_new_dense_gen(),
                space_x,
                space_zero,
                w,
            }
        }

        fn coeffs(&self, delta_x: Number) -> AugSysCoeffs<'_> {
            AugSysCoeffs {
                w: Some(&self.w as &dyn SymMatrix),
                w_factor: 1.0,
                d_x: None,
                delta_x,
                d_s: None,
                delta_s: 0.0,
                j_c: &self.j_c as &dyn Matrix,
                d_c: None,
                delta_c: 0.0,
                j_d: &self.j_d as &dyn Matrix,
                d_d: None,
                delta_d: 0.0,
            }
        }

        /// Solves with `rhs_x = [rhs_val]`; returns the status and `sol_x[0]`.
        fn solve(
            &self,
            solver: &mut LowRankAugSystemSolver,
            coeffs: &AugSysCoeffs<'_>,
            rhs_val: Number,
        ) -> (ESymSolverStatus, Number) {
            let rhs_x = dvec(&self.space_x, &[rhs_val]);
            let rhs_s = dvec(&self.space_zero, &[]);
            let rhs_c = dvec(&self.space_zero, &[]);
            let rhs_d = dvec(&self.space_zero, &[]);
            let rhs = AugSysRhs {
                rhs_x: &rhs_x,
                rhs_s: &rhs_s,
                rhs_c: &rhs_c,
                rhs_d: &rhs_d,
            };
            let mut sol_x = dvec(&self.space_x, &[0.0]);
            let mut sol_s = dvec(&self.space_zero, &[]);
            let mut sol_c = dvec(&self.space_zero, &[]);
            let mut sol_d = dvec(&self.space_zero, &[]);
            let status = {
                let mut sol = AugSysSol {
                    sol_x: &mut sol_x,
                    sol_s: &mut sol_s,
                    sol_c: &mut sol_c,
                    sol_d: &mut sol_d,
                };
                solver.solve(coeffs, &rhs, &mut sol, false, 0)
            };
            (status, sol_x.expanded_values()[0])
        }
    }

    #[test]
    fn smw_recovers_low_rank_inverse() {
        // W = 2 + 3 * 3 = 11.
        let p = Scalar::new(2.0, Some(3.0), None);
        let (mut solver, calls) = counting_solver();
        let (status, got) = p.solve(&mut solver, &p.coeffs(0.0), 5.0);
        assert_eq!(status, ESymSolverStatus::Success);
        assert_close(got, 5.0 / 11.0);
        // One inner solve for the single V column, one for the system itself.
        assert_eq!(calls.get(), 2);
    }

    #[test]
    fn smw_with_u_only_applies_positive_correction() {
        // W = 5 - 1.5 * 1.5 = 2.75.
        let p = Scalar::new(5.0, None, Some(1.5));
        let (mut solver, _calls) = counting_solver();
        let (status, got) = p.solve(&mut solver, &p.coeffs(0.0), 7.0);
        assert_eq!(status, ESymSolverStatus::Success);
        assert_close(got, 7.0 / 2.75);
    }

    #[test]
    fn smw_reports_wrong_inertia_on_indefinite_negative_update() {
        // W = 2 - 2 * 2 = -2, so the capacitance matrix I - U^T Utilde2 is negative.
        let p = Scalar::new(2.0, None, Some(2.0));
        let (mut solver, calls) = counting_solver();
        let (status, _) = p.solve(&mut solver, &p.coeffs(0.0), 1.0);
        assert_eq!(status, ESymSolverStatus::WrongInertia);
        // Only the U column solve ran; the failed factorization was not stored.
        assert_eq!(calls.get(), 1);
        assert!(solver.first_call());
    }

    #[test]
    fn smw_with_v_and_u_combines_corrections() {
        // W = 10 + 2 * 2 - 1 * 1 = 13.
        let p = Scalar::new(10.0, Some(2.0), Some(1.0));
        let (mut solver, calls) = counting_solver();
        let (status, got) = p.solve(&mut solver, &p.coeffs(0.0), 1.0);
        assert_eq!(status, ESymSolverStatus::Success);
        assert_close(got, 1.0 / 13.0);
        // V column + U column + the system solve.
        assert_eq!(calls.get(), 3);
    }

    #[test]
    fn unchanged_coeffs_skip_rebuild_after_first_call() {
        // The V term makes a rebuild observable: it costs one extra inner solve.
        let p = Scalar::new(2.0, Some(1.0), None);
        let (mut solver, calls) = counting_solver();
        let coeffs = p.coeffs(0.001);

        let (status, first) = p.solve(&mut solver, &coeffs, 1.0);
        assert_eq!(status, ESymSolverStatus::Success);
        assert_close(first, 1.0 / 3.001);
        assert_eq!(calls.get(), 2);
        assert!(!solver.first_call());
        assert!(!solver.augmented_system_requires_change(&coeffs));

        let (status, second) = p.solve(&mut solver, &coeffs, 1.0);
        assert_eq!(status, ESymSolverStatus::Success);
        // Only the system solve ran: the cached factorization was reused.
        assert_eq!(calls.get(), 3);
        assert_close(second, first);
    }
}