use crate::error::QpError;
use crate::factor::LinearSolver;
use crate::kkt::KktTriplet;
use crate::options::QpOptions;
use crate::problem::QpProblem;
use crate::working_set::WorkingSet;
use pounce_common::{Index, Number};
/// Incremental factorization state for the KKT matrix over a changing working set.
///
/// `K_max` has one row/column per primal variable (`0..n`) followed by one saddle
/// row/column per constraint slot (`n..n + m_total`): the `m` general constraints
/// first, then one bound slot per variable. A base matrix `K₀` is factored by
/// `reset`. Each later activation/deactivation is a symmetric rank-2 change
/// `ΔK = U Vᵀ` that is folded into the dense Schur matrix `S = I + Vᵀ K₀⁻¹ U`.
#[derive(Debug)]
pub struct SchurState {
    /// Number of primal variables.
    pub n: usize,
    /// Number of general (row-of-`A`) constraints.
    pub m: usize,
    /// Number of constraint slots: `m` general constraints plus `n` bounds.
    pub m_total: usize,
    /// Order of `K_max`: `n + m_total`.
    pub dim: usize,
    /// Slot activity pattern recorded by `reset` and used to build `K₀`.
    base_active: Vec<bool>,
    /// Columns of `U`, appended in pairs `(u, v)` per update.
    u_cols: Vec<Vec<Number>>,
    /// Columns of `V`, appended in pairs `(v, u)` per update.
    v_cols: Vec<Vec<Number>>,
    /// `K₀⁻¹` applied to each column of `U`.
    kinv_u_cols: Vec<Vec<Number>>,
    /// Row-major `s_dim × s_dim` Schur matrix `I + Vᵀ K₀⁻¹ U`.
    s_matrix: Vec<Number>,
    /// Current order of `S` (grows by two per applied update).
    s_dim: usize,
}
/// The two vectors of one symmetric rank-2 update `ΔK = u vᵀ + v uᵀ`.
struct UpdateVectors {
    /// Unit vector selecting the saddle row/column of the changed slot.
    u: Vec<Number>,
    /// Constraint coefficients (signed by direction) plus `∓½` on the slot's diagonal.
    v: Vec<Number>,
}
impl SchurState {
pub fn new(n: usize, m: usize) -> Self {
let m_total = m + n;
let dim = n + m_total;
Self {
n,
m,
m_total,
dim,
base_active: vec![false; m_total],
u_cols: Vec::new(),
v_cols: Vec::new(),
kinv_u_cols: Vec::new(),
s_matrix: Vec::new(),
s_dim: 0,
}
}
pub fn slot_active(working: &WorkingSet, slot: usize) -> bool {
let m = working.m();
if slot < m {
working.constraints[slot].is_active()
} else {
working.bounds[slot - m].is_active()
}
}
fn slot_is_general(&self, slot: usize) -> bool {
slot < self.m
}
fn build_k_max_triplet(&self, qp: &QpProblem, slot_active: &[bool]) -> KktTriplet {
let n_i = self.n as Index;
let mut irn = Vec::new();
let mut jcn = Vec::new();
let mut vals = Vec::new();
let h_irows = qp.h.irows();
let h_jcols = qp.h.jcols();
let h_vals = qp.h.values();
for k in 0..h_irows.len() {
let i = h_irows[k];
let j = h_jcols[k];
let (lo, hi) = if i >= j { (j, i) } else { (i, j) };
irn.push(hi);
jcn.push(lo);
vals.push(h_vals[k]);
}
let a_irows = qp.a.irows();
let a_jcols = qp.a.jcols();
let a_vals = qp.a.values();
for slot in 0..self.m_total {
let saddle_row = n_i + (slot as Index) + 1; if slot_active[slot] {
if self.slot_is_general(slot) {
let slot_1based = (slot + 1) as Index;
for k in 0..a_irows.len() {
if a_irows[k] == slot_1based {
irn.push(saddle_row);
jcn.push(a_jcols[k]);
vals.push(a_vals[k]);
}
}
} else {
let var = slot - self.m;
irn.push(saddle_row);
jcn.push((var + 1) as Index);
vals.push(1.0);
}
} else {
irn.push(saddle_row);
jcn.push(saddle_row);
vals.push(1.0);
}
}
KktTriplet {
dim: self.dim,
irn,
jcn,
vals,
}
}
fn build_update_vectors(
&self,
qp: &QpProblem,
slot: usize,
going_active: bool,
) -> UpdateVectors {
let p = self.n + slot; let mut u = vec![0.0; self.dim];
u[p] = 1.0;
let mut v = vec![0.0; self.dim];
if self.slot_is_general(slot) {
let row_1based = (slot + 1) as Index;
let sign = if going_active { 1.0 } else { -1.0 };
for k in 0..qp.a.irows().len() {
if qp.a.irows()[k] == row_1based {
let col_0 = (qp.a.jcols()[k] - 1) as usize;
v[col_0] += sign * qp.a.values()[k];
}
}
v[p] = if going_active { -0.5 } else { 0.5 };
} else {
let var = slot - self.m;
let sign = if going_active { 1.0 } else { -1.0 };
v[var] = sign;
v[p] = if going_active { -0.5 } else { 0.5 };
}
UpdateVectors { u, v }
}
pub fn reset(
&mut self,
linsol: &mut LinearSolver,
qp: &QpProblem,
working: &WorkingSet,
expected_neg: i32,
opts: &QpOptions,
) -> Result<(), QpError> {
for slot in 0..self.m_total {
self.base_active[slot] = Self::slot_active(working, slot);
}
let mut kkt = self.build_k_max_triplet(qp, &self.base_active);
let mut rhs = vec![0.0; self.dim];
let mut current = 0.0;
let mut last_err: Option<QpError>;
match linsol.factorize_and_solve(&kkt, &mut rhs, Some(expected_neg)) {
Ok(()) => {
self.u_cols.clear();
self.v_cols.clear();
self.kinv_u_cols.clear();
self.s_matrix.clear();
self.s_dim = 0;
return Ok(());
}
Err(QpError::LinearSolverFailure(ref msg))
if msg.contains("inertia") || msg.contains("singular") =>
{
last_err = Some(QpError::LinearSolverFailure(msg.clone()));
}
Err(e) => return Err(e),
}
let mut next = opts.inertia_shift_initial;
for _ in 0..opts.inertia_max_shifts {
kkt.add_h_diagonal_shift(self.n, next - current);
current = next;
let mut rhs_local = vec![0.0; self.dim];
match linsol.factorize_and_solve(&kkt, &mut rhs_local, Some(expected_neg)) {
Ok(()) => {
self.u_cols.clear();
self.v_cols.clear();
self.kinv_u_cols.clear();
self.s_matrix.clear();
self.s_dim = 0;
return Ok(());
}
Err(QpError::LinearSolverFailure(ref msg))
if msg.contains("inertia") || msg.contains("singular") =>
{
last_err = Some(QpError::LinearSolverFailure(msg.clone()));
next *= opts.inertia_shift_factor;
}
Err(e) => return Err(e),
}
}
Err(last_err.unwrap_or_else(|| {
QpError::LinearSolverFailure(format!(
"Schur reset: inertia control exhausted {} shifts (final δ = {:.3e})",
opts.inertia_max_shifts, current
))
}))
}
pub fn apply_change(
&mut self,
linsol: &mut LinearSolver,
qp: &QpProblem,
slot: usize,
going_active: bool,
) -> Result<(), QpError> {
let UpdateVectors { u, v } = self.build_update_vectors(qp, slot, going_active);
let mut kinv_u = u.clone();
linsol.resolve(&mut kinv_u)?;
let mut kinv_v = v.clone();
linsol.resolve(&mut kinv_v)?;
let old_dim = self.s_dim;
let new_dim = old_dim + 2;
let mut new_s = vec![0.0; new_dim * new_dim];
for i in 0..old_dim {
for j in 0..old_dim {
new_s[i * new_dim + j] = self.s_matrix[i * old_dim + j];
}
}
let v_new_rows: [&[Number]; 2] = [&v, &u];
let u_new_cols: [&[Number]; 2] = [&kinv_u, &kinv_v];
for i in 0..old_dim {
new_s[i * new_dim + old_dim] = dot(&self.v_cols[i], &kinv_u);
new_s[i * new_dim + old_dim + 1] = dot(&self.v_cols[i], &kinv_v);
}
for ii in 0..2 {
for j in 0..old_dim {
new_s[(old_dim + ii) * new_dim + j] = dot(v_new_rows[ii], &self.kinv_u_cols[j]);
}
}
for ii in 0..2 {
for jj in 0..2 {
let entry = dot(v_new_rows[ii], u_new_cols[jj]);
let identity = if ii == jj { 1.0 } else { 0.0 };
new_s[(old_dim + ii) * new_dim + old_dim + jj] = entry + identity;
}
}
self.u_cols.push(u);
self.u_cols.push(v.clone());
self.v_cols.push(v);
self.v_cols.push(self.u_cols[self.u_cols.len() - 2].clone()); self.kinv_u_cols.push(kinv_u);
self.kinv_u_cols.push(kinv_v);
self.s_matrix = new_s;
self.s_dim = new_dim;
Ok(())
}
pub fn solve(&self, linsol: &mut LinearSolver, rhs: &mut [Number]) -> Result<(), QpError> {
if rhs.len() != self.dim {
return Err(QpError::DimensionMismatch(format!(
"Schur solve RHS length {} but K_max dim is {}",
rhs.len(),
self.dim
)));
}
linsol.resolve(rhs)?;
if self.s_dim == 0 {
return Ok(());
}
let z = rhs.to_vec();
let mut w = vec![0.0; self.s_dim];
for j in 0..self.s_dim {
w[j] = dot(&self.v_cols[j], &z);
}
let y = small_dense_lu_solve(&self.s_matrix, self.s_dim, &w)?;
for j in 0..self.s_dim {
let y_j = y[j];
let kinv_uj = &self.kinv_u_cols[j];
for i in 0..self.dim {
rhs[i] -= y_j * kinv_uj[i];
}
}
Ok(())
}
pub fn needs_reset(&self, opts: &QpOptions) -> bool {
self.s_dim >= opts.max_schur_updates_before_refactor as usize
}
pub fn n_schur_updates(&self) -> u32 {
(self.s_dim / 2) as u32
}
}
/// Inner product `Σ lhs[i]·rhs[i]`; trailing elements of the longer slice are ignored.
fn dot(lhs: &[Number], rhs: &[Number]) -> Number {
    lhs.iter()
        .zip(rhs)
        .map(|(x, y)| x * y)
        .sum()
}
/// Solves the dense system `S x = b` by Gaussian elimination with partial (row) pivoting.
///
/// `s_in` holds the `dim × dim` matrix in row-major order and `b_in` the
/// right-hand side. Both are copied, so the inputs are untouched. The
/// elimination is redone on every call.
///
/// `s_in` must hold at least `dim * dim` entries and `b_in` at least `dim`.
/// Shorter inputs can cause an index panic.
///
/// # Errors
/// Returns [`QpError::LinearSolverFailure`] when a pivot column is exactly zero
/// or the selected pivot is NaN/infinite. Tiny but nonzero pivots are not
/// detected.
fn small_dense_lu_solve(
    s_in: &[Number],
    dim: usize,
    b_in: &[Number],
) -> Result<Vec<Number>, QpError> {
    let mut a = s_in.to_vec();
    let mut b = b_in.to_vec();

    for k in 0..dim {
        // Largest-magnitude entry on or below the diagonal; strict `>` keeps the
        // topmost row on ties.
        let mut piv = k;
        let mut piv_mag = a[k * dim + k].abs();
        for i in (k + 1)..dim {
            let mag = a[i * dim + k].abs();
            if mag > piv_mag {
                piv_mag = mag;
                piv = i;
            }
        }
        if piv_mag == 0.0 {
            return Err(QpError::LinearSolverFailure(format!(
                "Schur block is singular at column {k}"
            )));
        }
        // A NaN/∞ pivot would silently poison every remaining row and the
        // returned solution, so report it as a failure instead.
        if !piv_mag.is_finite() {
            return Err(QpError::LinearSolverFailure(format!(
                "Schur block is singular at column {k} (non-finite pivot)"
            )));
        }

        if piv != k {
            for j in 0..dim {
                a.swap(k * dim + j, piv * dim + j);
            }
            b.swap(k, piv);
        }

        let pivot = a[k * dim + k];
        for i in (k + 1)..dim {
            let factor = a[i * dim + k] / pivot;
            for j in k..dim {
                a[i * dim + j] -= factor * a[k * dim + j];
            }
            b[i] -= factor * b[k];
        }
    }

    // Back substitution on the upper-triangular factor.
    let mut x = vec![0.0; dim];
    for k in (0..dim).rev() {
        let mut s = b[k];
        for j in (k + 1)..dim {
            s -= a[k * dim + j] * x[j];
        }
        x[k] = s / a[k * dim + k];
    }
    Ok(x)
}