use crate::convergence::is_equality_constraint;
use crate::linear_solver::{KktMatrix, LinearSolver, SolverError, SparseSymmetricMatrix, SymmetricMatrix};
/// Assembled symmetric indefinite KKT system for one interior-point Newton
/// step, together with its right-hand side and scaling bookkeeping.
///
/// Layout (see `assemble_kkt`): the (1,1) block holds the Lagrangian Hessian
/// plus the primal barrier diagonal, the (2,1) block holds the constraint
/// Jacobian, and the (2,2) block holds `-1/sigma_s` (or a small `-delta_c`
/// regularization where no `sigma_s` entry was produced).
pub struct KktSystem {
    /// Total system dimension, `n + m`.
    pub dim: usize,
    /// Number of primal variables.
    pub n: usize,
    /// Number of general constraints.
    pub m: usize,
    /// The assembled KKT matrix (may later be perturbed and/or Ruiz-scaled).
    pub matrix: KktMatrix,
    /// Right-hand side of the Newton system, length `dim`.
    pub rhs: Vec<f64>,
    /// Per-constraint regularization applied at assembly time (length `m`;
    /// nonzero only for rows that received no `-1/sigma_s` entry).
    pub delta_c_diag: Vec<f64>,
    /// Cumulative per-row Ruiz scale factors, set once equilibration runs.
    pub scale_factors: Option<Vec<f64>>,
}
/// Assemble the full augmented KKT matrix and RHS for one interior-point
/// Newton step.
///
/// Structure (dimension `n + m`):
/// * (1,1): Lagrangian Hessian triplets plus the primal barrier diagonal
///   `sigma`;
/// * (2,1): constraint Jacobian `J`, rows offset by `n`;
/// * (2,2): `-1/sigma_s` for inequalities with (near-)feasible slack, a small
///   `-delta_c` regularization otherwise (equalities included).
///
/// The RHS carries the dual residual `-∇f + μ/(x−x_l) − μ/(x_u−x) − Jᵀy` in
/// the first `n` entries and the reduced constraint residual in the last `m`.
/// `use_sparse` selects sparse vs. dense storage. The `_z_l`, `_z_u`, `_v_l`,
/// `_v_u` parameters are accepted for signature compatibility but unused.
#[allow(clippy::too_many_arguments)]
pub fn assemble_kkt(
    n: usize,
    m: usize,
    hess_rows: &[usize],
    hess_cols: &[usize],
    hess_vals: &[f64],
    jac_rows: &[usize],
    jac_cols: &[usize],
    jac_vals: &[f64],
    sigma: &[f64],
    grad_f: &[f64],
    g: &[f64],
    g_l: &[f64],
    g_u: &[f64],
    y: &[f64],
    _z_l: &[f64],
    _z_u: &[f64],
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    mu: f64,
    use_sparse: bool,
    _v_l: &[f64],
    _v_u: &[f64],
) -> KktSystem {
    let dim = n + m;
    // Triplet capacity: Hessian + Jacobian + both diagonal blocks.
    let capacity = hess_rows.len() + jac_rows.len() + n + m;
    let mut matrix = if use_sparse {
        KktMatrix::zeros_sparse(dim, capacity)
    } else {
        KktMatrix::zeros_dense(dim)
    };
    let mut rhs = vec![0.0; dim];
    // (1,1) block: Hessian. Non-finite entries are only warned about and
    // still inserted, so the failure surfaces in the factorization.
    for (idx, (&row, &col)) in hess_rows.iter().zip(hess_cols.iter()).enumerate() {
        let v = hess_vals[idx];
        if v.is_nan() || v.is_infinite() {
            log::warn!("NaN/Inf in Hessian at ({}, {}): {}", row, col, v);
        }
        matrix.add(row, col, v);
    }
    // Primal barrier diagonal Σ.
    #[allow(clippy::needless_range_loop)]
    for i in 0..n {
        matrix.add(i, i, sigma[i]);
    }
    // (2,1) block: Jacobian, constraint rows shifted by n.
    for (idx, (&row, &col)) in jac_rows.iter().zip(jac_cols.iter()).enumerate() {
        matrix.add(n + row, col, jac_vals[idx]);
    }
    // Dual residual: -∇f plus the bound barrier terms...
    for i in 0..n {
        let mut rd = -grad_f[i];
        if x_l[i].is_finite() {
            rd += mu / (x[i] - x_l[i]);
        }
        if x_u[i].is_finite() {
            rd -= mu / (x_u[i] - x[i]);
        }
        rhs[i] = rd;
    }
    // ... minus Jᵀ y.
    for (idx, (&row, &col)) in jac_rows.iter().zip(jac_cols.iter()).enumerate() {
        rhs[col] -= jac_vals[idx] * y[row];
    }
    // Constraint rows: build the (2,2) diagonal and constraint RHS.
    let mut has_sigma_s = vec![false; m];
    for i in 0..m {
        if is_equality_constraint(g_l[i], g_u[i]) {
            // Equality: plain linearized residual; row is regularized below.
            rhs[n + i] = -(g[i] - g_l[i]);
            continue;
        }
        // Inequality: accumulate sigma_s = z_s/s from each finite,
        // (near-)feasible side; an infeasible side contributes its violation
        // directly to the RHS instead.
        let mut sigma_s = 0.0;
        let mut rhs_correction = y[i];
        let mut any_feasible = false;
        let mut rhs_infeasible = 0.0;
        if g_l[i].is_finite() {
            let slack = g[i] - g_l[i];
            if slack >= -1e-8 {
                // Guard the slack away from zero; take -y as the slack
                // multiplier estimate when it has the right sign, else μ/s.
                let safe_slack = slack.max(mu.max(1e-10));
                let z_sl = if y[i] < -1e-20 {
                    -y[i]
                } else {
                    mu / safe_slack
                };
                sigma_s += z_sl / safe_slack;
                rhs_correction += mu / safe_slack;
                any_feasible = true;
            } else {
                rhs_infeasible += -(g[i] - g_l[i]);
            }
        }
        if g_u[i].is_finite() {
            let slack = g_u[i] - g[i];
            if slack >= -1e-8 {
                let safe_slack = slack.max(mu.max(1e-10));
                let z_su = if y[i] > 1e-20 {
                    y[i]
                } else {
                    mu / safe_slack
                };
                sigma_s += z_su / safe_slack;
                rhs_correction -= mu / safe_slack;
                any_feasible = true;
            } else {
                rhs_infeasible += -(g[i] - g_u[i]);
            }
        }
        if any_feasible && sigma_s > 1e-20 {
            // (2,2) entry -1/sigma_s, capped at 1e20 to stay finite.
            let sigma_s_inv = (1.0 / sigma_s).min(1e20);
            matrix.add(n + i, n + i, -sigma_s_inv);
            has_sigma_s[i] = true;
            rhs[n + i] = sigma_s_inv * rhs_correction + rhs_infeasible;
        } else {
            rhs[n + i] = rhs_infeasible;
        }
    }
    // Rows without a -1/sigma_s entry (equalities, fully infeasible rows) get
    // a small μ-scaled regularization so the (2,2) block is never exactly 0.
    let delta_c_base = 1e-8 * mu.max(0.0).powf(0.25);
    let mut delta_c_diag = vec![0.0; m];
    for i in 0..m {
        if !has_sigma_s[i] {
            matrix.add(n + i, n + i, -delta_c_base);
            delta_c_diag[i] = delta_c_base;
        }
    }
    // Diagnostics only: report any non-finite RHS entries.
    if rhs.iter().any(|v| v.is_nan() || v.is_infinite()) {
        log::warn!("NaN/Inf in KKT RHS!");
        for (i, v) in rhs.iter().enumerate() {
            if v.is_nan() || v.is_infinite() {
                log::warn!(" rhs[{}] = {}", i, v);
            }
        }
    }
    KktSystem {
        dim,
        n,
        m,
        matrix,
        rhs,
        delta_c_diag,
        scale_factors: None,
    }
}
/// Compute the primal barrier diagonal `Σ = Z_L S_L⁻¹ + Z_U S_U⁻¹` for the
/// variable bounds.
///
/// Each variable with a finite lower bound contributes
/// `z_l[i] / (x[i] - x_l[i])`; each finite upper bound contributes
/// `z_u[i] / (x_u[i] - x[i])`. Slacks are clamped below at `1e-20` so a
/// variable sitting exactly on a bound yields a large-but-finite entry
/// rather than a division by zero.
pub fn compute_sigma(
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    z_l: &[f64],
    z_u: &[f64],
) -> Vec<f64> {
    x.iter()
        .enumerate()
        .map(|(i, &xi)| {
            let mut entry = 0.0;
            if x_l[i].is_finite() {
                entry += z_l[i] / (xi - x_l[i]).max(1e-20);
            }
            if x_u[i].is_finite() {
                entry += z_u[i] / (x_u[i] - xi).max(1e-20);
            }
            entry
        })
        .collect()
}
pub fn ruiz_equilibrate(matrix: &mut KktMatrix, rhs: &mut [f64]) -> Vec<f64> {
let dim = matrix.n();
let mut cumulative = vec![1.0; dim];
{
let norms = matrix.row_abs_max();
for k in 0..dim {
let norm_k = norms[k];
if norm_k > 1e-30 {
let s = 1.0 / norm_k.sqrt();
matrix.scale_row_col(k, s);
rhs[k] *= s;
cumulative[k] *= s;
}
}
}
for _ in 0..3 {
let norms = matrix.row_abs_sum();
for k in 0..dim {
let norm_k = norms[k];
if norm_k > 1e-30 {
let s = 1.0 / norm_k.sqrt();
matrix.scale_row_col(k, s);
rhs[k] *= s;
cumulative[k] *= s;
}
}
}
cumulative
}
/// Tunables and carried-over state for the inertia-correction loop in
/// `factor_with_inertia_correction`.
pub struct InertiaCorrectionParams {
    /// First `delta_w` tried when no previous correction is known.
    pub delta_w_init: f64,
    /// Base dual regularization; scaled by `mu^(1/4)` when applied.
    pub delta_c_base: f64,
    /// Growth factor for `delta_w` on the first escalation.
    pub delta_w_inc_fact_first: f64,
    /// Growth factor for subsequent escalations.
    pub delta_w_inc_fact: f64,
    /// Shrink factor applied to the last successful `delta_w` as a warm start.
    pub delta_w_dec_fact: f64,
    /// Escalation gives up once `delta_w` exceeds this.
    pub delta_w_max: f64,
    /// Lower clamp for the warm-started `delta_w`.
    pub delta_w_min: f64,
    /// Maximum escalation attempts per factorization.
    pub max_attempts: usize,
    /// `delta_w` used by the most recent correction (0.0 = none needed).
    pub delta_w_last: f64,
    /// Whether Ruiz equilibration is applied before factoring.
    pub use_scaling: bool,
    /// Consecutive factorizations that required a correction; reset to 0 on a
    /// clean (unperturbed) success.
    pub degeneracy_count: usize,
    /// Set after 3 consecutive corrections; skips the unperturbed attempts.
    pub structurally_degenerate: bool,
}
impl Default for InertiaCorrectionParams {
    /// Defaults comparable to the classic IPOPT inertia-correction constants
    /// (first escalation ×100, subsequent ×8, warm-start shrink ×1/3).
    fn default() -> Self {
        Self {
            delta_w_init: 1e-4,
            delta_c_base: 1e-8,
            delta_w_inc_fact_first: 100.0,
            delta_w_inc_fact: 8.0,
            delta_w_dec_fact: 1.0 / 3.0,
            delta_w_max: 1e20,
            delta_w_min: 1e-20,
            max_attempts: 30,
            delta_w_last: 0.0,
            use_scaling: false,
            degeneracy_count: 0,
            structurally_degenerate: false,
        }
    }
}
/// Convenience wrapper: run the backward-error check against the system's own
/// stored matrix and RHS.
fn check_factorization_backward_error(
    kkt: &KktSystem,
    solver: &mut dyn LinearSolver,
) -> bool {
    check_factorization_backward_error_with_matrix(&kkt.matrix, &kkt.rhs, solver)
}
/// Componentwise backward-error check of the current factorization.
///
/// Solves `K x = rhs` with the factored solver, forms the residual against
/// `matrix`, and accepts when
/// `max_i |rhs_i - (Kx)_i| / (max(||x||_inf, 1) + max(|rhs_i|, 1e-30)) <= 1e-4`.
/// A failed solve or any non-finite solution entry rejects the factorization.
fn check_factorization_backward_error_with_matrix(
    matrix: &KktMatrix,
    rhs: &[f64],
    solver: &mut dyn LinearSolver,
) -> bool {
    let dim = rhs.len();
    let mut solution = vec![0.0; dim];
    if solver.solve(rhs, &mut solution).is_err() {
        return false;
    }
    if solution.iter().any(|v| !v.is_finite()) {
        return false;
    }
    let mut residual = vec![0.0; dim];
    matrix.matvec(&solution, &mut residual);
    let x_norm = solution
        .iter()
        .fold(0.0f64, |acc, v| acc.max(v.abs()))
        .max(1.0);
    let max_berr = rhs
        .iter()
        .zip(residual.iter())
        .map(|(&b, &ax)| (b - ax).abs() / (x_norm + b.abs().max(1e-30)))
        .fold(0.0f64, f64::max);
    max_berr <= 1e-4
}
/// Factor the KKT matrix, applying regularization until the factorization has
/// the expected inertia (`n` positive, `m` negative, no zero eigenvalues) and
/// an acceptable backward error.
///
/// Attempt order:
/// 1. The unmodified matrix (skipped once the problem was flagged structurally
///    degenerate); on a backward-error failure retry once with Ruiz scaling.
/// 2. `m == 0` only: shift the primal diagonal just past the most negative
///    pivot reported by the solver.
/// 3. Ask the solver to raise its own quality setting and re-factor.
/// 4. Small negative-eigenvalue deficits with no singularity: try a
///    `delta_c`-only perturbation (up to 4 ×4 escalations).
/// 5. Classic escalating `delta_w` loop, warm-started from the last success.
///
/// Returns the `(delta_w, delta_c)` that was applied; the perturbed matrix
/// replaces `kkt.matrix` so subsequent solves and refinement use it. Solver
/// errors are propagated. After 3 consecutive corrections the problem is
/// marked structurally degenerate so future calls skip the clean attempts.
pub fn factor_with_inertia_correction(
    kkt: &mut KktSystem,
    solver: &mut dyn LinearSolver,
    params: &mut InertiaCorrectionParams,
    mu: f64,
) -> Result<(f64, f64), crate::linear_solver::SolverError> {
    let n = kkt.n;
    // Dual regularization scales with mu^(1/4).
    let delta_c_active = params.delta_c_base * mu.max(0.0).powf(0.25);
    let m = kkt.m;
    if params.use_scaling {
        let scale = ruiz_equilibrate(&mut kkt.matrix, &mut kkt.rhs);
        kkt.scale_factors = Some(scale);
    }
    if !params.structurally_degenerate {
        // Attempt 1: factor the matrix as assembled.
        let inertia = solver.factor(&kkt.matrix)?;
        if let Some(inertia) = inertia {
            let inertia_ok = inertia.positive == n && inertia.negative == m && inertia.zero == 0;
            if inertia_ok {
                if check_factorization_backward_error(kkt, solver) {
                    params.delta_w_last = 0.0;
                    params.degeneracy_count = 0;
                    return Ok((0.0, 0.0));
                }
                // Correct inertia but poor accuracy: try once more with Ruiz
                // equilibration before resorting to perturbations.
                if !params.use_scaling && kkt.scale_factors.is_none() {
                    params.use_scaling = true;
                    let scale = ruiz_equilibrate(&mut kkt.matrix, &mut kkt.rhs);
                    kkt.scale_factors = Some(scale);
                    let inertia2 = solver.factor(&kkt.matrix)?;
                    if let Some(inertia2) = inertia2 {
                        let inertia_ok2 = inertia2.positive == n && inertia2.negative == m && inertia2.zero == 0;
                        if inertia_ok2 && check_factorization_backward_error(kkt, solver) {
                            params.delta_w_last = 0.0;
                            params.degeneracy_count = 0;
                            return Ok((0.0, 0.0));
                        }
                    }
                }
            }
        }
        // Attempt 2 (unconstrained only): derive the shift directly from the
        // most negative diagonal pivot instead of searching for delta_w.
        if m == 0 {
            if let Some(min_d) = solver.min_diagonal() {
                if min_d < 0.0 {
                    let delta_w_direct = -min_d + 1e-8;
                    let mut perturbed = kkt.matrix.clone();
                    perturbed.add_diagonal_range(0, n, delta_w_direct);
                    let inertia = solver.factor(&perturbed)?;
                    if let Some(inertia) = inertia {
                        if inertia.positive == n && inertia.negative == 0 && inertia.zero == 0 {
                            kkt.matrix = perturbed;
                            params.delta_w_last = delta_w_direct;
                            params.degeneracy_count += 1;
                            if params.degeneracy_count >= 3 { params.structurally_degenerate = true; }
                            return Ok((delta_w_direct, 0.0));
                        }
                    }
                }
            }
        }
        // Attempt 3: a higher-quality factorization from the solver itself.
        if solver.increase_quality() {
            let inertia = solver.factor(&kkt.matrix)?;
            if let Some(inertia) = inertia {
                let inertia_ok = inertia.positive == n && inertia.negative == m && inertia.zero == 0;
                if inertia_ok && check_factorization_backward_error(kkt, solver) {
                    params.delta_w_last = 0.0;
                    return Ok((0.0, 0.0));
                }
            }
        }
        // Attempt 4: when only a few negative eigenvalues are missing and
        // nothing is singular, a delta_c-only perturbation may restore the
        // inertia without damping the primal block.
        if m > 0 {
            let last_inertia = solver.factor(&kkt.matrix)?;
            if let Some(inertia) = last_inertia {
                let total = inertia.positive + inertia.negative + inertia.zero;
                let deficit = m as isize - inertia.negative as isize;
                if deficit > 0
                    && deficit <= 5.max((m / 100) as isize)
                    && inertia.zero == 0
                    && (total as isize - (n + m) as isize).unsigned_abs() <= 2
                {
                    let mut delta_c = delta_c_active;
                    for _ in 0..4 {
                        let mut perturbed = kkt.matrix.clone();
                        perturbed.add_diagonal_range(n, n + m, -delta_c);
                        let inertia = solver.factor(&perturbed)?;
                        if let Some(inertia) = inertia {
                            let ok = inertia.positive == n && inertia.negative == m
                                && inertia.zero == 0;
                            if ok {
                                log::debug!(
                                    "Selective delta_c-only correction succeeded: delta_c={:.2e}",
                                    delta_c
                                );
                                kkt.matrix = perturbed;
                                params.delta_w_last = 0.0;
                                return Ok((0.0, delta_c));
                            }
                        }
                        delta_c *= 4.0;
                    }
                }
            }
        }
    }
    // Attempt 5: escalation loop. Warm-start below the last successful
    // delta_w so a smaller perturbation is tried first.
    let mut delta_w = if params.delta_w_last == 0.0 {
        params.delta_w_init
    } else {
        (params.delta_w_last * params.delta_w_dec_fact).max(params.delta_w_min)
    };
    let mut best_delta_w = delta_w;
    for attempt in 0..params.max_attempts {
        if delta_w > params.delta_w_max {
            log::debug!(
                "Inertia correction: delta_w={:.2e} exceeds delta_w_max={:.2e}, giving up",
                delta_w, params.delta_w_max
            );
            break;
        }
        let delta_c = delta_c_active;
        let mut perturbed = kkt.matrix.clone();
        perturbed.add_diagonal_range(0, n, delta_w);
        if m > 0 {
            perturbed.add_diagonal_range(n, n + m, -delta_c);
        }
        let inertia = solver.factor(&perturbed)?;
        if let Some(inertia) = inertia {
            let exact_ok = inertia.positive == n && inertia.negative == m && inertia.zero == 0;
            if exact_ok {
                if check_factorization_backward_error_with_matrix(&perturbed, &kkt.rhs, solver) {
                    kkt.matrix = perturbed;
                    params.delta_w_last = delta_w;
                    params.degeneracy_count += 1;
                    if params.degeneracy_count >= 3 { params.structurally_degenerate = true; }
                    return Ok((delta_w, delta_c));
                }
                log::debug!(
                    "Inertia correct at delta_w={:.2e} but backward error too large, increasing",
                    delta_w
                );
            }
        }
        best_delta_w = delta_w;
        // First escalation (or a big jump past the warm start) grows faster.
        let inc = if params.delta_w_last == 0.0
            || 1e5 * params.delta_w_last < delta_w
        {
            params.delta_w_inc_fact_first
        } else {
            params.delta_w_inc_fact
        };
        delta_w *= inc;
        log::debug!(
            "Inertia correction attempt {}: delta_w = {:.2e} (x{}), delta_c = {:.2e}, inertia = {:?}",
            attempt + 1,
            delta_w,
            inc,
            delta_c,
            inertia
        );
    }
    // Everything failed: keep the best perturbation anyway so the caller can
    // still produce an (approximate) step.
    let delta_c = delta_c_active;
    log::warn!(
        "Inertia correction failed after {} attempts (delta_w={:.2e}, delta_c={:.2e}), proceeding with approximate factorization",
        params.max_attempts, best_delta_w, delta_c
    );
    let mut perturbed = kkt.matrix.clone();
    perturbed.add_diagonal_range(0, n, best_delta_w);
    if m > 0 {
        perturbed.add_diagonal_range(n, n + m, -delta_c);
    }
    solver.factor(&perturbed)?;
    kkt.matrix = perturbed;
    params.delta_w_last = best_delta_w;
    params.degeneracy_count += 1;
    if params.degeneracy_count >= 3 { params.structurally_degenerate = true; }
    Ok((best_delta_w, delta_c))
}
/// Multiply by the ORIGINAL (unperturbed) KKT matrix.
///
/// `kkt.matrix` may carry the inertia-correction shifts (`+delta_w` on the
/// primal block, `-delta_c_ic` on the constraint block) plus the per-row
/// `delta_c_diag` added at assembly; this routine computes the matvec with
/// the stored matrix and then removes those diagonal contributions so that
/// `y = K_original * x`.
fn matvec_original(
    kkt: &KktSystem,
    x: &[f64],
    y: &mut [f64],
    delta_w: f64,
    delta_c_ic: f64,
) {
    kkt.matrix.matvec(x, y);
    // Undo the +delta_w * I shift on the first n rows.
    if delta_w > 0.0 {
        for (yj, &xj) in y.iter_mut().zip(x.iter()).take(kkt.n) {
            *yj -= delta_w * xj;
        }
    }
    // Undo the -(delta_c_diag[i] + delta_c_ic) shift on each constraint row.
    for (i, &dc_assembly) in kkt.delta_c_diag.iter().enumerate() {
        let total_dc = dc_assembly + delta_c_ic;
        if total_dc > 0.0 {
            y[kkt.n + i] += total_dc * x[kkt.n + i];
        }
    }
}
/// Solve the factored KKT system for the search direction `(dx, dy)`.
///
/// Runs up to 10 rounds of iterative refinement against the stored
/// (perturbed) matrix, stopping on a tiny residual or stagnation, then
/// applies two guards:
/// * scaled residual ratio > 1e-5 → `PretendSingular` (caller should
///   re-factor with stronger regularization);
/// * ||solution||_inf more than 1e10 × ||rhs||_inf → `PretendSingular`
///   (rank-deficiency guard).
///
/// When Ruiz scaling was applied, the solution is unscaled before being split
/// into primal and dual parts.
///
/// `delta_w` / `delta_c_ic` would let refinement target the ORIGINAL
/// (unperturbed) matrix via `matvec_original`; that path is currently
/// disabled (`use_ic_refinement = false`), so both are unused.
pub fn solve_for_direction(
    kkt: &KktSystem,
    solver: &mut dyn LinearSolver,
    delta_w: f64,
    delta_c_ic: f64,
) -> Result<(Vec<f64>, Vec<f64>), crate::linear_solver::SolverError> {
    let dim = kkt.dim;
    let _ = (delta_w, delta_c_ic);
    let use_ic_refinement = false;
    if kkt.rhs.iter().any(|v| v.is_nan() || v.is_infinite()) {
        return Err(crate::linear_solver::SolverError::NumericalFailure(
            "KKT RHS contains NaN/Inf".to_string(),
        ));
    }
    let mut solution = vec![0.0; dim];
    solver.solve(&kkt.rhs, &mut solution)?;
    if solution.iter().any(|v| v.is_nan() || v.is_infinite()) {
        return Err(crate::linear_solver::SolverError::NumericalFailure(
            "KKT solution contains NaN/Inf".to_string(),
        ));
    }
    let max_refinements = 10;
    let mut residual = vec![0.0; dim];
    let mut prev_res_norm = f64::MAX;
    for _ref_iter in 0..max_refinements {
        if use_ic_refinement {
            matvec_original(kkt, &solution, &mut residual, delta_w, delta_c_ic);
        } else {
            kkt.matrix.matvec(&solution, &mut residual);
        }
        // residual := rhs - K*solution, tracking the infinity norm.
        let mut res_norm: f64 = 0.0;
        for i in 0..dim {
            residual[i] = kkt.rhs[i] - residual[i];
            res_norm = res_norm.max(residual[i].abs());
        }
        if res_norm < 1e-12 {
            break;
        }
        // Stop once refinement no longer makes meaningful progress.
        let stagnation_factor = if use_ic_refinement { 0.9 } else { 1.0 - 1e-6 };
        if res_norm > stagnation_factor * prev_res_norm {
            break;
        }
        prev_res_norm = res_norm;
        let mut correction = vec![0.0; dim];
        if solver.solve(&residual, &mut correction).is_err() {
            break;
        }
        for i in 0..dim {
            solution[i] += correction[i];
        }
    }
    // Final quality guards on the refined solution.
    {
        if use_ic_refinement {
            matvec_original(kkt, &solution, &mut residual, delta_w, delta_c_ic);
        } else {
            kkt.matrix.matvec(&solution, &mut residual);
        }
        let nrm_rhs: f64 = kkt.rhs.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        let nrm_res: f64 = solution.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
        let nrm_resid: f64 = (0..dim).map(|i| (kkt.rhs[i] - residual[i]).abs()).fold(0.0f64, f64::max);
        // The min() term keeps a huge solution norm from masking a bad
        // residual in the denominator.
        let max_cond = 1e6;
        let residual_ratio = if nrm_rhs + nrm_res == 0.0 {
            nrm_resid
        } else {
            nrm_resid / (nrm_res.min(max_cond * nrm_rhs) + nrm_rhs)
        };
        if residual_ratio > 1e-5 {
            log::debug!("KKT residual ratio {:.2e} > 1e-5 — pretend singular", residual_ratio);
            return Err(SolverError::PretendSingular);
        }
        if residual_ratio > 1e-10 {
            log::debug!("KKT residual ratio {:.2e} (above target 1e-10, proceeding)", residual_ratio);
        }
        let magnitude_ratio = nrm_res / nrm_rhs.max(1.0);
        if magnitude_ratio > 1e10 {
            log::debug!(
                "KKT ||sol||={:.2e} vs ||rhs||={:.2e} (ratio {:.2e}) — pretend singular (rank-def guard)",
                nrm_res, nrm_rhs, magnitude_ratio,
            );
            return Err(SolverError::PretendSingular);
        }
    }
    // Undo Ruiz scaling, if it was applied at factorization time.
    if let Some(ref scale) = kkt.scale_factors {
        for i in 0..dim {
            solution[i] *= scale[i];
        }
    }
    let dx = solution[..kkt.n].to_vec();
    let dy = solution[kkt.n..].to_vec();
    Ok((dx, dy))
}
/// Recover the bound-multiplier steps `(dz_l, dz_u)` from the primal step.
///
/// For each finite bound with slack `s` (clamped at 1e-20):
/// `dz_l = (mu - z_l*s_l)/s_l - (z_l/s_l)*dx` and
/// `dz_u = (mu - z_u*s_u)/s_u + (z_u/s_u)*dx`; entries with no finite bound
/// are zero.
pub fn recover_dz(
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    z_l: &[f64],
    z_u: &[f64],
    dx: &[f64],
    mu: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = x.len();
    let dz_l: Vec<f64> = (0..n)
        .map(|i| {
            if !x_l[i].is_finite() {
                return 0.0;
            }
            let s_l = (x[i] - x_l[i]).max(1e-20);
            (mu - z_l[i] * s_l) / s_l - (z_l[i] / s_l) * dx[i]
        })
        .collect();
    let dz_u: Vec<f64> = (0..n)
        .map(|i| {
            if !x_u[i].is_finite() {
                return 0.0;
            }
            let s_u = (x_u[i] - x[i]).max(1e-20);
            (mu - z_u[i] * s_u) / s_u + (z_u[i] / s_u) * dx[i]
        })
        .collect();
    (dz_l, dz_u)
}
/// Build the affine-predictor RHS from the centered RHS by removing the
/// barrier contribution: subtract `mu/s_l` and add back `mu/s_u` for every
/// finite bound (slacks clamped at 1e-20). Entries without finite bounds are
/// returned unchanged.
pub fn affine_predictor_rhs(
    rhs: &[f64],
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    mu: f64,
) -> Vec<f64> {
    let mut rhs_aff = rhs.to_vec();
    for (i, &xi) in x.iter().enumerate() {
        if x_l[i].is_finite() {
            rhs_aff[i] -= mu / (xi - x_l[i]).max(1e-20);
        }
        if x_u[i].is_finite() {
            rhs_aff[i] += mu / (x_u[i] - xi).max(1e-20);
        }
    }
    rhs_aff
}
/// Retarget an existing Newton RHS from barrier parameter `mu_old` to
/// `mu_new` without reassembling: only the bound barrier terms change, by
/// `(mu_new - mu_old)/s_l` (added) and `(mu_new - mu_old)/s_u` (subtracted)
/// for every finite bound, with slacks clamped at 1e-20.
pub fn rebuild_rhs_with_mu(
    rhs: &[f64],
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    mu_old: f64,
    mu_new: f64,
) -> Vec<f64> {
    let delta_mu = mu_new - mu_old;
    let mut rhs_new = rhs.to_vec();
    for (i, &xi) in x.iter().enumerate() {
        if x_l[i].is_finite() {
            rhs_new[i] += delta_mu / (xi - x_l[i]).max(1e-20);
        }
        if x_u[i].is_finite() {
            rhs_new[i] -= delta_mu / (x_u[i] - xi).max(1e-20);
        }
    }
    rhs_new
}
/// Build the Mehrotra corrector RHS: first retarget the barrier parameter
/// from `mu_old` to `mu_new` (via `rebuild_rhs_with_mu`), then subtract the
/// second-order terms `dx_aff * dz_aff / s` for every finite bound (slacks
/// clamped at 1e-20).
pub fn mehrotra_corrector_rhs(
    rhs: &[f64],
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    dx_aff: &[f64],
    dz_l_aff: &[f64],
    dz_u_aff: &[f64],
    mu_old: f64,
    mu_new: f64,
) -> Vec<f64> {
    let mut rhs_new = rebuild_rhs_with_mu(rhs, x, x_l, x_u, mu_old, mu_new);
    for (i, &xi) in x.iter().enumerate() {
        let step = dx_aff[i];
        if x_l[i].is_finite() {
            let s_l = (xi - x_l[i]).max(1e-20);
            rhs_new[i] -= step * dz_l_aff[i] / s_l;
        }
        if x_u[i].is_finite() {
            let s_u = (x_u[i] - xi).max(1e-20);
            rhs_new[i] -= step * dz_u_aff[i] / s_u;
        }
    }
    rhs_new
}
/// Recover the bound-multiplier steps for a Mehrotra corrector step: same as
/// `recover_dz` but with the second-order term `dx_aff * dz_aff` folded into
/// the complementarity target (subtracted on the lower side, added on the
/// upper side). Entries without a finite bound are zero.
pub fn recover_dz_mehrotra(
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    z_l: &[f64],
    z_u: &[f64],
    dx: &[f64],
    dx_aff: &[f64],
    dz_l_aff: &[f64],
    dz_u_aff: &[f64],
    mu: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = x.len();
    let dz_l: Vec<f64> = (0..n)
        .map(|i| {
            if !x_l[i].is_finite() {
                return 0.0;
            }
            let s_l = (x[i] - x_l[i]).max(1e-20);
            (mu - z_l[i] * s_l - dx_aff[i] * dz_l_aff[i]) / s_l - (z_l[i] / s_l) * dx[i]
        })
        .collect();
    let dz_u: Vec<f64> = (0..n)
        .map(|i| {
            if !x_u[i].is_finite() {
                return 0.0;
            }
            let s_u = (x_u[i] - x[i]).max(1e-20);
            (mu - z_u[i] * s_u + dx_aff[i] * dz_u_aff[i]) / s_u + (z_u[i] / s_u) * dx[i]
        })
        .collect();
    (dz_l, dz_u)
}
/// Solve the already-factored system for an arbitrary RHS, without iterative
/// refinement (no matrix is available to compute residuals against).
/// Returns the solution split into the first `n` entries and the rest.
pub fn solve_with_custom_rhs(
    n: usize,
    dim: usize,
    solver: &mut dyn LinearSolver,
    rhs: &[f64],
) -> Result<(Vec<f64>, Vec<f64>), SolverError> {
    solve_with_custom_rhs_impl(n, dim, solver, rhs, None)
}
/// Solve the already-factored system for an arbitrary RHS, with iterative
/// refinement against the supplied `matrix`.
/// Returns the solution split into the first `n` entries and the rest.
pub fn solve_with_custom_rhs_refined(
    matrix: &KktMatrix,
    n: usize,
    dim: usize,
    solver: &mut dyn LinearSolver,
    rhs: &[f64],
) -> Result<(Vec<f64>, Vec<f64>), SolverError> {
    solve_with_custom_rhs_impl(n, dim, solver, rhs, Some(matrix))
}
/// Shared implementation for the custom-RHS solves: validate the RHS, solve
/// once with the factored solver, optionally run up to 5 rounds of iterative
/// refinement against `refine_against`, and split the result at index `n`.
fn solve_with_custom_rhs_impl(
    n: usize,
    dim: usize,
    solver: &mut dyn LinearSolver,
    rhs: &[f64],
    refine_against: Option<&KktMatrix>,
) -> Result<(Vec<f64>, Vec<f64>), SolverError> {
    if rhs.iter().any(|v| !v.is_finite()) {
        return Err(SolverError::NumericalFailure(
            "Custom RHS contains NaN/Inf".to_string(),
        ));
    }
    let mut solution = vec![0.0; dim];
    solver.solve(rhs, &mut solution)?;
    if solution.iter().any(|v| !v.is_finite()) {
        return Err(SolverError::NumericalFailure(
            "Custom solve solution contains NaN/Inf".to_string(),
        ));
    }
    if let Some(matrix) = refine_against {
        let mut residual = vec![0.0; dim];
        let mut prev_res_norm = f64::MAX;
        for _ in 0..5 {
            // residual := rhs - K*solution, tracking the infinity norm.
            matrix.matvec(&solution, &mut residual);
            let mut res_norm: f64 = 0.0;
            for (r, &b) in residual.iter_mut().zip(rhs.iter()) {
                *r = b - *r;
                res_norm = res_norm.max(r.abs());
            }
            // Stop on convergence or stagnation.
            if res_norm < 1e-10 {
                break;
            }
            if res_norm > (1.0 - 1e-6) * prev_res_norm {
                break;
            }
            prev_res_norm = res_norm;
            let mut correction = vec![0.0; dim];
            if solver.solve(&residual, &mut correction).is_err() {
                break;
            }
            for (s, &c) in solution.iter_mut().zip(correction.iter()) {
                *s += c;
            }
        }
    }
    let tail = solution.split_off(n);
    Ok((solution, tail))
}
/// Dense condensed (Schur-complement) KKT system over the primal variables:
/// `matrix = H + Σ + Jᵀ(−D_c)⁻¹J`, `rhs = r_p + Jᵀ(−D_c)⁻¹ r_c`.
/// The Jacobian triplets and the uncondensed residuals are retained so the
/// constraint multipliers can be recovered after the primal solve.
pub struct CondensedKktSystem {
    /// Dense condensed matrix (n × n).
    pub matrix: SymmetricMatrix,
    /// Condensed right-hand side (length n).
    pub rhs: Vec<f64>,
    /// Number of primal variables.
    pub n: usize,
    /// Number of constraints.
    pub m: usize,
    /// Constraint-block diagonal D_c (≤ 0; left at 0 for equality rows —
    /// a near-zero guard substitutes -1e-16 wherever it is used).
    pub d_c: Vec<f64>,
    /// Uncondensed primal residual r_p.
    pub rhs_primal: Vec<f64>,
    /// Uncondensed constraint residual r_c.
    pub rhs_constraint: Vec<f64>,
    /// Jacobian COO row indices.
    pub jac_rows: Vec<usize>,
    /// Jacobian COO column indices.
    pub jac_cols: Vec<usize>,
    /// Jacobian COO values.
    pub jac_vals: Vec<f64>,
}
/// Assemble the dense condensed KKT system by eliminating the constraint
/// block: `(H + Σ + Jᵀ(−D_c)⁻¹J) dx = r_p + Jᵀ(−D_c)⁻¹ r_c`.
///
/// `D_c` and `r_c` are computed per constraint exactly as in `assemble_kkt`;
/// the Schur term is built from a dense m×n copy of the Jacobian (duplicate
/// triplets are accumulated first). `_z_l`, `_z_u`, `_v_l`, `_v_u` are unused.
#[allow(clippy::too_many_arguments)]
pub fn assemble_condensed_kkt(
    n: usize,
    m: usize,
    hess_rows: &[usize],
    hess_cols: &[usize],
    hess_vals: &[f64],
    jac_rows: &[usize],
    jac_cols: &[usize],
    jac_vals: &[f64],
    sigma: &[f64],
    grad_f: &[f64],
    g: &[f64],
    g_l: &[f64],
    g_u: &[f64],
    y: &[f64],
    _z_l: &[f64],
    _z_u: &[f64],
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    mu: f64,
    _v_l: &[f64],
    _v_u: &[f64],
) -> CondensedKktSystem {
    // Hessian plus the primal barrier diagonal Σ.
    let mut matrix = SymmetricMatrix::zeros(n);
    for (idx, (&row, &col)) in hess_rows.iter().zip(hess_cols.iter()).enumerate() {
        matrix.add(row, col, hess_vals[idx]);
    }
    for i in 0..n {
        matrix.add(i, i, sigma[i]);
    }
    // Dual residual r_p = -∇f + μ/(x−x_l) − μ/(x_u−x) − Jᵀy.
    let mut rhs_primal = vec![0.0; n];
    for i in 0..n {
        let mut rd = -grad_f[i];
        if x_l[i].is_finite() {
            rd += mu / (x[i] - x_l[i]);
        }
        if x_u[i].is_finite() {
            rd -= mu / (x_u[i] - x[i]);
        }
        rhs_primal[i] = rd;
    }
    for (idx, (&row, &col)) in jac_rows.iter().zip(jac_cols.iter()).enumerate() {
        rhs_primal[col] -= jac_vals[idx] * y[row];
    }
    // Per-constraint diagonal D_c and constraint RHS (same slack-based logic
    // as the full KKT assembly).
    let mut d_c = vec![0.0; m];
    let mut rhs_constraint = vec![0.0; m];
    for i in 0..m {
        if is_equality_constraint(g_l[i], g_u[i]) {
            // Equality row: d_c stays 0, so the -1e-16 fallback below acts as
            // an implicit ~1e16 penalty in the Schur term.
            // NOTE(review): assemble_sparse_condensed_kkt instead sets
            // d_c = -max(mu, 1e-8) for equalities — confirm this asymmetry
            // between the dense and sparse paths is intentional.
            rhs_constraint[i] = -(g[i] - g_l[i]);
            continue;
        }
        let mut sigma_s = 0.0;
        let mut rhs_correction = y[i];
        let mut any_feasible = false;
        let mut rhs_infeasible = 0.0;
        if g_l[i].is_finite() {
            let slack = g[i] - g_l[i];
            if slack >= -1e-8 {
                // Guard the slack; use -y as the multiplier estimate when it
                // has the right sign, else μ/s.
                let safe_slack = slack.max(mu.max(1e-10));
                let z_sl = if y[i] < -1e-20 { -y[i] } else { mu / safe_slack };
                sigma_s += z_sl / safe_slack;
                rhs_correction += mu / safe_slack;
                any_feasible = true;
            } else {
                rhs_infeasible += -(g[i] - g_l[i]);
            }
        }
        if g_u[i].is_finite() {
            let slack = g_u[i] - g[i];
            if slack >= -1e-8 {
                let safe_slack = slack.max(mu.max(1e-10));
                let z_su = if y[i] > 1e-20 { y[i] } else { mu / safe_slack };
                sigma_s += z_su / safe_slack;
                rhs_correction -= mu / safe_slack;
                any_feasible = true;
            } else {
                rhs_infeasible += -(g[i] - g_u[i]);
            }
        }
        if any_feasible && sigma_s > 1e-20 {
            let sigma_s_inv = (1.0 / sigma_s).min(1e20);
            d_c[i] = -sigma_s_inv;
            rhs_constraint[i] = sigma_s_inv * rhs_correction + rhs_infeasible;
        } else {
            rhs_constraint[i] = rhs_infeasible;
        }
    }
    // Densify J (accumulating duplicate triplets) for the Schur product.
    let mut j_dense = vec![0.0; m * n];
    for (idx, (&row, &col)) in jac_rows.iter().zip(jac_cols.iter()).enumerate() {
        j_dense[row * n + col] += jac_vals[idx];
    }
    // Schur term Jᵀ(−D_c)⁻¹J, lower triangle only (q <= p).
    for i in 0..m {
        let d_c_eff = if d_c[i].abs() < 1e-20 {
            // Near-zero pivot guard (also covers equality rows).
            -1e-16
        } else {
            d_c[i]
        };
        let inv_neg_dc = 1.0 / (-d_c_eff);
        for p in 0..n {
            let jp = j_dense[i * n + p];
            if jp == 0.0 {
                continue;
            }
            for q in 0..=p {
                let jq = j_dense[i * n + q];
                if jq != 0.0 {
                    matrix.add(p, q, inv_neg_dc * jp * jq);
                }
            }
        }
    }
    // Condensed RHS: r_p + Jᵀ(−D_c)⁻¹ r_c.
    let mut rhs = rhs_primal.clone();
    for i in 0..m {
        let d_c_eff = if d_c[i].abs() < 1e-20 {
            -1e-16
        } else {
            d_c[i]
        };
        let inv_neg_dc = 1.0 / (-d_c_eff);
        let scaled_rp = inv_neg_dc * rhs_constraint[i];
        for p in 0..n {
            let jp = j_dense[i * n + p];
            if jp != 0.0 {
                rhs[p] += jp * scaled_rp;
            }
        }
    }
    CondensedKktSystem {
        matrix,
        rhs,
        n,
        m,
        d_c,
        rhs_primal,
        rhs_constraint,
        jac_rows: jac_rows.to_vec(),
        jac_cols: jac_cols.to_vec(),
        jac_vals: jac_vals.to_vec(),
    }
}
/// Solve the dense condensed system for `(dx, dy)`.
///
/// Solves the n×n condensed matrix for `dx` with up to 5 iterative-refinement
/// passes, then recovers the multipliers via
/// `dy_i = (J dx − r_c)_i / (−d_c_i)`. A final magnitude guard rejects
/// solutions larger than 1e10 × the RHS norm as rank-deficient so the caller
/// can fall back to the full KKT path.
pub fn solve_condensed(
    condensed: &CondensedKktSystem,
    solver: &mut dyn LinearSolver,
) -> Result<(Vec<f64>, Vec<f64>), SolverError> {
    let n = condensed.n;
    let m = condensed.m;
    let mut dx = vec![0.0; n];
    solver.solve(&condensed.rhs, &mut dx)?;
    // Iterative refinement against the condensed matrix.
    let max_refinements = 5;
    let mut residual = vec![0.0; n];
    let mut prev_res_norm = f64::MAX;
    for _ in 0..max_refinements {
        condensed.matrix.matvec(&dx, &mut residual);
        let mut res_norm: f64 = 0.0;
        for i in 0..n {
            residual[i] = condensed.rhs[i] - residual[i];
            res_norm = res_norm.max(residual[i].abs());
        }
        if res_norm < 1e-10 {
            break;
        }
        // Stop on stagnation.
        if res_norm > (1.0 - 1e-6) * prev_res_norm {
            break;
        }
        prev_res_norm = res_norm;
        let mut correction = vec![0.0; n];
        if solver.solve(&residual, &mut correction).is_err() {
            break;
        }
        for i in 0..n {
            dx[i] += correction[i];
        }
    }
    // J * dx from the COO triplets.
    let mut jdx = vec![0.0; m];
    for (idx, (&row, &col)) in condensed
        .jac_rows
        .iter()
        .zip(condensed.jac_cols.iter())
        .enumerate()
    {
        jdx[row] += condensed.jac_vals[idx] * dx[col];
    }
    // Multiplier recovery, with the same near-zero pivot guard as assembly.
    let mut dy = vec![0.0; m];
    for i in 0..m {
        let d_c_eff = if condensed.d_c[i].abs() < 1e-20 {
            -1e-16
        } else {
            condensed.d_c[i]
        };
        dy[i] = (jdx[i] - condensed.rhs_constraint[i]) / (-d_c_eff);
    }
    // Rank-deficiency guard: a huge solution relative to the uncondensed RHS
    // signals a (numerically) singular condensed matrix.
    let nrm_rhs: f64 = condensed
        .rhs_primal
        .iter()
        .chain(condensed.rhs_constraint.iter())
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max);
    let nrm_sol: f64 = dx
        .iter()
        .chain(dy.iter())
        .map(|v| v.abs())
        .fold(0.0_f64, f64::max);
    let magnitude_ratio = nrm_sol / nrm_rhs.max(1.0);
    if magnitude_ratio > 1e10 {
        log::debug!(
            "Condensed ||sol||={:.2e} vs ||rhs||={:.2e} (ratio {:.2e}) — rank-def, falling back to full KKT",
            nrm_sol, nrm_rhs, magnitude_ratio,
        );
        return Err(SolverError::NumericalFailure(format!(
            "Condensed solution magnitude {:.2e} exceeds {:.2e}×RHS (rank-deficient)",
            nrm_sol, nrm_rhs.max(1.0) * 1e10,
        )));
    }
    Ok((dx, dy))
}
/// Solve the dense condensed system for a second-order-correction step.
///
/// Reuses the existing factorization and primal residual but replaces the
/// constraint residual with `c_soc`; only the primal direction is returned.
pub fn solve_condensed_soc(
    condensed: &CondensedKktSystem,
    solver: &mut dyn LinearSolver,
    c_soc: &[f64],
) -> Result<Vec<f64>, SolverError> {
    let m = condensed.m;
    // (-c_soc) scaled by (−D_c)⁻¹, with the assembly-time pivot guard.
    let scaled: Vec<f64> = (0..m)
        .map(|i| {
            let dc = condensed.d_c[i];
            let d_c_eff = if dc.abs() < 1e-20 { -1e-16 } else { dc };
            (-c_soc[i]) / (-d_c_eff)
        })
        .collect();
    // rhs = r_p + Jᵀ * scaled, accumulated from the COO triplets.
    let mut rhs = condensed.rhs_primal.clone();
    for (idx, (&row, &col)) in condensed
        .jac_rows
        .iter()
        .zip(condensed.jac_cols.iter())
        .enumerate()
    {
        rhs[col] += condensed.jac_vals[idx] * scaled[row];
    }
    let mut dx = vec![0.0; condensed.n];
    solver.solve(&rhs, &mut dx)?;
    Ok(dx)
}
/// Sparse variant of the condensed KKT system: same contents as
/// `CondensedKktSystem`, but the condensed matrix uses sparse symmetric
/// storage.
pub struct SparseCondensedKktSystem {
    /// Sparse condensed matrix (n × n).
    pub matrix: SparseSymmetricMatrix,
    /// Condensed right-hand side (length n).
    pub rhs: Vec<f64>,
    /// Number of primal variables.
    pub n: usize,
    /// Number of constraints.
    pub m: usize,
    /// Constraint-block diagonal D_c (≤ 0 entries; a near-zero guard
    /// substitutes -1e-16 wherever it is used).
    pub d_c: Vec<f64>,
    /// Uncondensed primal residual r_p.
    pub rhs_primal: Vec<f64>,
    /// Uncondensed constraint residual r_c.
    pub rhs_constraint: Vec<f64>,
    /// Jacobian COO row indices.
    pub jac_rows: Vec<usize>,
    /// Jacobian COO column indices.
    pub jac_cols: Vec<usize>,
    /// Jacobian COO values.
    pub jac_vals: Vec<f64>,
}
/// Assemble the condensed KKT system `(H + Σ + Jᵀ(−D_c)⁻¹J) dx = r_p +
/// Jᵀ(−D_c)⁻¹ r_c` using sparse storage.
///
/// Mirrors `assemble_condensed_kkt` but never densifies the Jacobian: the
/// Schur term is built row-by-row from a CSR-style ordering of the COO
/// triplets. Unlike the dense path, equality constraints receive an explicit
/// relaxation `d_c = -max(mu, 1e-8)` so `(−D_c)⁻¹` stays bounded.
/// `_z_l`, `_z_u`, `_v_l`, `_v_u` are unused.
///
/// Fix vs. previous revision: the Schur loop had an `if a == b` conditional
/// whose two branches were identical, and two distinct triplets aliasing the
/// same (row, column) contributed only `va*vb` to the diagonal cross term —
/// the dense assembler accumulates duplicates into `j_dense` first, which
/// yields `(va+vb)²` and therefore needs `2*va*vb` for the cross pair. The
/// behavior is unchanged for duplicate-free triplet input.
#[allow(clippy::too_many_arguments)]
pub fn assemble_sparse_condensed_kkt(
    n: usize,
    m: usize,
    hess_rows: &[usize],
    hess_cols: &[usize],
    hess_vals: &[f64],
    jac_rows: &[usize],
    jac_cols: &[usize],
    jac_vals: &[f64],
    sigma: &[f64],
    grad_f: &[f64],
    g: &[f64],
    g_l: &[f64],
    g_u: &[f64],
    y: &[f64],
    _z_l: &[f64],
    _z_u: &[f64],
    x: &[f64],
    x_l: &[f64],
    x_u: &[f64],
    mu: f64,
    _v_l: &[f64],
    _v_u: &[f64],
) -> SparseCondensedKktSystem {
    // Build CSR-style row pointers over the Jacobian triplets so each
    // constraint row's nonzeros can be visited contiguously.
    let mut row_start = vec![0usize; m + 1];
    for &r in jac_rows {
        row_start[r + 1] += 1;
    }
    for i in 0..m {
        row_start[i + 1] += row_start[i];
    }
    let jac_nnz = jac_rows.len();
    let mut jac_order = vec![0usize; jac_nnz];
    let mut row_count = vec![0usize; m];
    for k in 0..jac_nnz {
        let r = jac_rows[k];
        jac_order[row_start[r] + row_count[r]] = k;
        row_count[r] += 1;
    }
    // Upper bound on Schur fill: each constraint row contributes a dense
    // clique over its nonzero columns.
    let mut schur_nnz = 0;
    for i in 0..m {
        let k = row_start[i + 1] - row_start[i];
        schur_nnz += k * (k + 1) / 2;
    }
    let total_nnz = hess_rows.len() + n + schur_nnz;
    let mut matrix = SparseSymmetricMatrix::with_capacity(n, total_nnz);
    // Hessian plus the primal barrier diagonal Σ.
    for (idx, (&row, &col)) in hess_rows.iter().zip(hess_cols.iter()).enumerate() {
        matrix.add(row, col, hess_vals[idx]);
    }
    for i in 0..n {
        matrix.add(i, i, sigma[i]);
    }
    // Dual residual r_p = -∇f + μ/(x−x_l) − μ/(x_u−x) − Jᵀy.
    let mut rhs_primal = vec![0.0; n];
    for i in 0..n {
        let mut rd = -grad_f[i];
        if x_l[i].is_finite() {
            rd += mu / (x[i] - x_l[i]);
        }
        if x_u[i].is_finite() {
            rd -= mu / (x_u[i] - x[i]);
        }
        rhs_primal[i] = rd;
    }
    for (idx, (&row, &col)) in jac_rows.iter().zip(jac_cols.iter()).enumerate() {
        rhs_primal[col] -= jac_vals[idx] * y[row];
    }
    // Per-constraint diagonal D_c and constraint RHS, same slack-based logic
    // as the full KKT assembly.
    let mut d_c = vec![0.0; m];
    let mut rhs_constraint = vec![0.0; m];
    for i in 0..m {
        if is_equality_constraint(g_l[i], g_u[i]) {
            // Relax equalities slightly so -1/d_c stays bounded in the Schur
            // term below.
            let delta_c = mu.max(1e-8);
            d_c[i] = -delta_c;
            rhs_constraint[i] = -(g[i] - g_l[i]);
            continue;
        }
        let mut sigma_s = 0.0;
        let mut rhs_correction = y[i];
        let mut any_feasible = false;
        let mut rhs_infeasible = 0.0;
        if g_l[i].is_finite() {
            let slack = g[i] - g_l[i];
            if slack >= -1e-8 {
                // Guard the slack; use -y as the multiplier estimate when it
                // has the right sign, else μ/s.
                let safe_slack = slack.max(mu.max(1e-10));
                let z_sl = if y[i] < -1e-20 { -y[i] } else { mu / safe_slack };
                sigma_s += z_sl / safe_slack;
                rhs_correction += mu / safe_slack;
                any_feasible = true;
            } else {
                rhs_infeasible += -(g[i] - g_l[i]);
            }
        }
        if g_u[i].is_finite() {
            let slack = g_u[i] - g[i];
            if slack >= -1e-8 {
                let safe_slack = slack.max(mu.max(1e-10));
                let z_su = if y[i] > 1e-20 { y[i] } else { mu / safe_slack };
                sigma_s += z_su / safe_slack;
                rhs_correction -= mu / safe_slack;
                any_feasible = true;
            } else {
                rhs_infeasible += -(g[i] - g_u[i]);
            }
        }
        if any_feasible && sigma_s > 1e-20 {
            let sigma_s_inv = (1.0 / sigma_s).min(1e20);
            d_c[i] = -sigma_s_inv;
            rhs_constraint[i] = sigma_s_inv * rhs_correction + rhs_infeasible;
        } else {
            rhs_constraint[i] = rhs_infeasible;
        }
    }
    // Schur term Jᵀ(−D_c)⁻¹J, one constraint row at a time.
    for i in 0..m {
        let d_c_eff = if d_c[i].abs() < 1e-20 { -1e-16 } else { d_c[i] };
        let inv_neg_dc = 1.0 / (-d_c_eff);
        let start = row_start[i];
        let end = row_start[i + 1];
        for a in start..end {
            let ka = jac_order[a];
            let ca = jac_cols[ka];
            let va = jac_vals[ka];
            for b in a..end {
                let kb = jac_order[b];
                let cb = jac_cols[kb];
                let vb = jac_vals[kb];
                let (p, q) = if ca <= cb { (ca, cb) } else { (cb, ca) };
                // Two distinct triplets hitting the same column form a
                // diagonal cross term; it needs the full 2*va*vb to match the
                // dense assembler, which accumulates duplicates before
                // squaring. All other pairs contribute once.
                let val = if a != b && ca == cb {
                    2.0 * inv_neg_dc * va * vb
                } else {
                    inv_neg_dc * va * vb
                };
                matrix.add(p, q, val);
            }
        }
    }
    // Condensed RHS: r_p + Jᵀ(−D_c)⁻¹ r_c.
    let mut rhs = rhs_primal.clone();
    for i in 0..m {
        let d_c_eff = if d_c[i].abs() < 1e-20 { -1e-16 } else { d_c[i] };
        let inv_neg_dc = 1.0 / (-d_c_eff);
        let scaled_rp = inv_neg_dc * rhs_constraint[i];
        let start = row_start[i];
        let end = row_start[i + 1];
        for a in start..end {
            let ka = jac_order[a];
            rhs[jac_cols[ka]] += jac_vals[ka] * scaled_rp;
        }
    }
    SparseCondensedKktSystem {
        matrix,
        rhs,
        n,
        m,
        d_c,
        rhs_primal,
        rhs_constraint,
        jac_rows: jac_rows.to_vec(),
        jac_cols: jac_cols.to_vec(),
        jac_vals: jac_vals.to_vec(),
    }
}
/// Solve the sparse condensed system for `(dx, dy)`.
///
/// Solves the condensed matrix for `dx`, computes `J dx` from the COO
/// triplets, and back-substitutes the multipliers
/// `dy_i = (J dx − r_c)_i / (−d_c_i)` with the assembly-time pivot guard.
pub fn solve_sparse_condensed(
    condensed: &SparseCondensedKktSystem,
    solver: &mut dyn LinearSolver,
) -> Result<(Vec<f64>, Vec<f64>), SolverError> {
    let mut dx = vec![0.0; condensed.n];
    solver.solve(&condensed.rhs, &mut dx)?;
    // J * dx accumulated from the triplets.
    let mut jdx = vec![0.0; condensed.m];
    for (idx, &row) in condensed.jac_rows.iter().enumerate() {
        jdx[row] += condensed.jac_vals[idx] * dx[condensed.jac_cols[idx]];
    }
    // Multiplier recovery.
    let dy: Vec<f64> = (0..condensed.m)
        .map(|i| {
            let dc = condensed.d_c[i];
            let d_c_eff = if dc.abs() < 1e-20 { -1e-16 } else { dc };
            (jdx[i] - condensed.rhs_constraint[i]) / (-d_c_eff)
        })
        .collect();
    Ok((dx, dy))
}
/// Solve the sparse condensed system for a second-order-correction step.
///
/// Reuses the existing factorization and primal residual but replaces the
/// constraint residual with `c_soc`; only the primal direction is returned.
pub fn solve_sparse_condensed_soc(
    condensed: &SparseCondensedKktSystem,
    solver: &mut dyn LinearSolver,
    c_soc: &[f64],
) -> Result<Vec<f64>, SolverError> {
    let m = condensed.m;
    // (-c_soc) scaled by (−D_c)⁻¹, with the assembly-time pivot guard.
    let scaled: Vec<f64> = (0..m)
        .map(|i| {
            let dc = condensed.d_c[i];
            let d_c_eff = if dc.abs() < 1e-20 { -1e-16 } else { dc };
            (-c_soc[i]) / (-d_c_eff)
        })
        .collect();
    // rhs = r_p + Jᵀ * scaled.
    let mut rhs = condensed.rhs_primal.clone();
    for (idx, (&row, &col)) in condensed
        .jac_rows
        .iter()
        .zip(condensed.jac_cols.iter())
        .enumerate()
    {
        rhs[col] += condensed.jac_vals[idx] * scaled[row];
    }
    let mut dx = vec![0.0; condensed.n];
    solver.solve(&rhs, &mut dx)?;
    Ok(dx)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linear_solver::dense::DenseLdl;
#[test]
fn test_compute_sigma_no_bounds() {
let x = vec![1.0, 2.0];
let x_l = vec![f64::NEG_INFINITY, f64::NEG_INFINITY];
let x_u = vec![f64::INFINITY, f64::INFINITY];
let z_l = vec![0.0, 0.0];
let z_u = vec![0.0, 0.0];
let sigma = compute_sigma(&x, &x_l, &x_u, &z_l, &z_u);
assert!((sigma[0]).abs() < 1e-15);
assert!((sigma[1]).abs() < 1e-15);
}
#[test]
fn test_compute_sigma_lower_bound_only() {
let x = vec![1.5];
let x_l = vec![1.0];
let x_u = vec![f64::INFINITY];
let z_l = vec![2.0];
let z_u = vec![0.0];
let sigma = compute_sigma(&x, &x_l, &x_u, &z_l, &z_u);
assert!((sigma[0] - 4.0).abs() < 1e-12);
}
#[test]
fn test_compute_sigma_both_bounds() {
let x = vec![1.5];
let x_l = vec![1.0];
let x_u = vec![2.0];
let z_l = vec![2.0];
let z_u = vec![3.0];
let sigma = compute_sigma(&x, &x_l, &x_u, &z_l, &z_u);
assert!((sigma[0] - 10.0).abs() < 1e-12);
}
#[test]
fn test_compute_sigma_at_bound_clamped() {
let x = vec![1.0]; let x_l = vec![1.0];
let x_u = vec![f64::INFINITY];
let z_l = vec![1.0];
let z_u = vec![0.0];
let sigma = compute_sigma(&x, &x_l, &x_u, &z_l, &z_u);
assert!(sigma[0] > 1e19);
}
#[test]
fn test_assemble_kkt_unconstrained() {
let n = 2;
let m = 0;
let hess_rows = vec![0, 1];
let hess_cols = vec![0, 1];
let hess_vals = vec![2.0, 3.0];
let sigma = vec![1.0, 2.0];
let grad_f = vec![0.5, 0.5];
let x = vec![1.0, 2.0];
let x_l = vec![f64::NEG_INFINITY; 2];
let x_u = vec![f64::INFINITY; 2];
let z_l = vec![0.0; 2];
let z_u = vec![0.0; 2];
let kkt = assemble_kkt(
n, m, &hess_rows, &hess_cols, &hess_vals,
&[], &[], &[], &sigma, &grad_f,
&[], &[], &[], &[], &z_l, &z_u,
&x, &x_l, &x_u, 0.1, false, &[], &[],
);
assert_eq!(kkt.dim, 2);
assert!((kkt.matrix.get(0, 0) - 3.0).abs() < 1e-12);
assert!((kkt.matrix.get(1, 1) - 5.0).abs() < 1e-12);
}
#[test]
fn test_assemble_kkt_equality_constraint() {
let n = 2;
let m = 1;
let hess_rows = vec![0, 1];
let hess_cols = vec![0, 1];
let hess_vals = vec![2.0, 2.0];
let jac_rows = vec![0, 0];
let jac_cols = vec![0, 1];
let jac_vals = vec![1.0, 1.0];
let sigma = vec![0.0; 2];
let grad_f = vec![1.0, 1.0];
let g = vec![0.7]; let g_l = vec![1.0];
let g_u = vec![1.0];
let y = vec![0.5];
let x = vec![0.3, 0.4];
let x_l = vec![f64::NEG_INFINITY; 2];
let x_u = vec![f64::INFINITY; 2];
let z_l = vec![0.0; 2];
let z_u = vec![0.0; 2];
let v_l = vec![0.0; m];
let v_u = vec![0.0; m];
let kkt = assemble_kkt(
n, m, &hess_rows, &hess_cols, &hess_vals,
&jac_rows, &jac_cols, &jac_vals, &sigma, &grad_f,
&g, &g_l, &g_u, &y, &z_l, &z_u,
&x, &x_l, &x_u, 0.1, false, &v_l, &v_u,
);
assert_eq!(kkt.dim, 3);
assert!((kkt.matrix.get(2, 0) - 1.0).abs() < 1e-12);
assert!((kkt.matrix.get(2, 1) - 1.0).abs() < 1e-12);
let expected_dc = 1e-8 * 0.1f64.powf(0.25);
assert!((kkt.matrix.get(2, 2) - (-expected_dc)).abs() < 1e-15);
assert!((kkt.delta_c_diag[0] - expected_dc).abs() < 1e-15);
assert!((kkt.rhs[2] - 0.3).abs() < 1e-12);
}
#[test]
fn test_assemble_kkt_rhs_sign_convention() {
    // For an unbounded variable the dual residual is -grad_f - J^T y:
    // -3 - 2*1 = -5. Guards against sign flips in the RHS assembly.
    let n = 1;
    let m = 1;
    let hess_rows: [usize; 1] = [0];
    let hess_cols: [usize; 1] = [0];
    let hess_vals = [1.0];
    let jac_rows: [usize; 1] = [0];
    let jac_cols: [usize; 1] = [0];
    let jac_vals = [2.0];
    let sigma = [0.0];
    let grad_f = [3.0];
    let g = [1.0];
    let g_l = [1.0];
    let g_u = [1.0];
    let y = [1.0];
    let x = [1.0];
    let x_l = [f64::NEG_INFINITY];
    let x_u = [f64::INFINITY];
    let z_l = [0.0];
    let z_u = [0.0];
    let v_l = vec![0.0; m];
    let v_u = vec![0.0; m];
    let kkt = assemble_kkt(
        n, m, &hess_rows, &hess_cols, &hess_vals,
        &jac_rows, &jac_cols, &jac_vals, &sigma, &grad_f,
        &g, &g_l, &g_u, &y, &z_l, &z_u,
        &x, &x_l, &x_u, 0.1, false, &v_l, &v_u,
    );
    assert!(
        (kkt.rhs[0] - (-5.0)).abs() < 1e-12,
        "RHS sign convention: expected -5.0, got {}",
        kkt.rhs[0]
    );
}
#[test]
fn test_assemble_kkt_inequality_constraint() {
    // One-sided inequality g >= 1 evaluated at g = 2 (strictly feasible).
    // Only checks the sign of the (n, n) slack-block entry: it must be
    // strictly negative.
    let n = 1;
    let m = 1;
    let mu = 0.1;
    let hess_rows: [usize; 1] = [0];
    let hess_cols: [usize; 1] = [0];
    let hess_vals = [1.0];
    let jac_rows: [usize; 1] = [0];
    let jac_cols: [usize; 1] = [0];
    let jac_vals = [1.0];
    let sigma = [0.0];
    let grad_f = [0.0];
    let g = [2.0];
    let g_l = [1.0];
    let g_u = [f64::INFINITY];
    let y = [0.0];
    let x = [2.0];
    let x_l = [f64::NEG_INFINITY];
    let x_u = [f64::INFINITY];
    let z_l = [0.0];
    let z_u = [0.0];
    let v_l = vec![0.0; m];
    let v_u = vec![0.0; m];
    let kkt = assemble_kkt(
        n, m, &hess_rows, &hess_cols, &hess_vals,
        &jac_rows, &jac_cols, &jac_vals, &sigma, &grad_f,
        &g, &g_l, &g_u, &y, &z_l, &z_u,
        &x, &x_l, &x_u, mu, false, &v_l, &v_u,
    );
    assert!(
        kkt.matrix.get(1, 1) < 0.0,
        "Inequality (2,2) block should be negative, got {}",
        kkt.matrix.get(1, 1)
    );
}
#[test]
fn test_factor_with_inertia_correction_good() {
    // A small saddle-point system that already factors with acceptable
    // inertia: no perturbation (neither delta_w nor delta_c) may be applied.
    let n = 2;
    let m = 1;
    let mut mat = SymmetricMatrix::zeros(3);
    mat.set(0, 0, 2.0);
    mat.set(1, 1, 2.0);
    mat.set(2, 0, 1.0);
    mat.set(2, 1, 1.0);
    let mut kkt = KktSystem {
        dim: 3,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 2.0, 3.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams::default();
    let (delta_w, delta_c) =
        factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4).unwrap();
    assert!((delta_w).abs() < 1e-15, "Good inertia should need no delta_w");
    assert!((delta_c).abs() < 1e-15, "Good inertia should need no delta_c");
}
#[test]
fn test_factor_with_inertia_correction_needs_perturbation() {
    // The +1.0 entry at (2,2) is intended to give this system the wrong
    // inertia, so the correction loop must end up with some delta_w > 0.
    let n = 2;
    let m = 1;
    let mut mat = SymmetricMatrix::zeros(3);
    mat.set(0, 0, 2.0);
    mat.set(1, 1, 2.0);
    mat.set(2, 2, 1.0);
    mat.set(2, 0, 0.1);
    mat.set(2, 1, 0.1);
    let mut kkt = KktSystem {
        dim: 3,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 2.0, 3.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams::default();
    let (delta_w, _delta_c) =
        factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4).unwrap();
    assert!(delta_w > 0.0, "Wrong inertia should require delta_w > 0");
}
#[test]
fn test_factor_with_inertia_correction_warm_start() {
    // Warm start: with delta_w_last > 0, the first trial perturbation should
    // be delta_w_last * delta_w_dec_fact rather than delta_w_init, so the
    // accepted delta_w can never fall below that warm-start value.
    let n = 2;
    let m = 1;
    let mut matrix = SymmetricMatrix::zeros(3);
    matrix.set(0, 0, 2.0);
    matrix.set(1, 1, 2.0);
    matrix.set(2, 2, 1.0);
    matrix.set(2, 0, 0.1);
    matrix.set(2, 1, 0.1);
    let mut kkt = KktSystem {
        dim: 3,
        n,
        m,
        matrix: KktMatrix::Dense(matrix),
        rhs: vec![1.0, 2.0, 3.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut solver = DenseLdl::new();
    // Struct-update initialization instead of mutating a fresh default()
    // (fixes clippy::field_reassign_with_default); behavior is identical.
    let mut params = InertiaCorrectionParams {
        delta_w_last: 1.0,
        ..Default::default()
    };
    let (delta_w, _) =
        factor_with_inertia_correction(&mut kkt, &mut solver, &mut params, 1e-4).unwrap();
    let warm_start = 1.0 * params.delta_w_dec_fact;
    assert!(
        delta_w >= warm_start - 1e-12,
        "Warm-start should begin from delta_w_last * dec_fact ({}); got {}",
        warm_start, delta_w
    );
}
#[test]
fn test_factor_with_inertia_correction_growth_sequence() {
    // Negative (1,1) block forces at least one correction step. Afterwards
    // the accepted delta_w must be >= delta_w_init, be recorded in
    // delta_w_last, and bump degeneracy_count to exactly 1.
    let n = 2;
    let m = 1;
    let mut mat = SymmetricMatrix::zeros(3);
    mat.set(0, 0, -1.0);
    mat.set(1, 1, -1.0);
    mat.set(2, 0, 1.0);
    mat.set(2, 1, 1.0);
    let mut kkt = KktSystem {
        dim: 3,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 2.0, 3.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams::default();
    let (delta_w, _dc) =
        factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4).unwrap();
    assert!(
        delta_w >= params.delta_w_init,
        "delta_w should be at least delta_w_init, got {}",
        delta_w
    );
    assert!(
        (params.delta_w_last - delta_w).abs() < 1e-15,
        "delta_w_last should equal the successful delta_w"
    );
    assert_eq!(params.degeneracy_count, 1);
}
#[test]
fn test_inertia_first_inc_factor_used_when_cold() {
    // Cold start (delta_w_last == 0): if the correction has to escalate past
    // the initial trial, it should jump by delta_w_inc_fact_first (100x),
    // not the warm-start growth factor (8x).
    let n = 2;
    let m = 2;
    let mut mat = SymmetricMatrix::zeros(4);
    for i in 0..4 {
        mat.set(i, i, 1.0);
    }
    mat.set(2, 0, 0.5);
    mat.set(3, 1, 0.5);
    let mut kkt = KktSystem {
        dim: 4,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 2.0, 3.0, 4.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams::default();
    let result = factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1.0);
    assert!(result.is_ok());
    let (delta_w, _) = result.unwrap();
    // NOTE(review): the assertion only fires when escalation actually
    // happened; when delta_w stays small this test is vacuous by design.
    if delta_w > params.delta_w_init * 1.5 {
        assert!(
            delta_w >= params.delta_w_init * params.delta_w_inc_fact_first - 1e-12,
            "Cold escalation should jump by 100x, not 8x; got delta_w={:.3e}",
            delta_w
        );
    }
}
#[test]
fn test_inertia_dec_fact_warm_shrinks_by_one_third() {
    // A warm start shrinks the previous perturbation by delta_w_dec_fact
    // (1/3), not by the deprecated "divide by the growth factor (8)" rule;
    // both bounds below pin that behavior.
    let n = 2;
    let m = 1;
    let mut matrix = SymmetricMatrix::zeros(3);
    matrix.set(0, 0, -1.0);
    matrix.set(1, 1, -1.0);
    matrix.set(2, 2, 1.0);
    matrix.set(2, 0, 0.1);
    matrix.set(2, 1, 0.1);
    let mut kkt2 = KktSystem {
        dim: 3,
        n,
        m,
        matrix: KktMatrix::Dense(matrix),
        rhs: vec![1.0, 2.0, 3.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut solver = DenseLdl::new();
    // Struct-update initialization instead of mutating a fresh default()
    // (fixes clippy::field_reassign_with_default); behavior is identical.
    let mut params = InertiaCorrectionParams {
        delta_w_last: 0.9,
        ..Default::default()
    };
    let (delta_w, _) =
        factor_with_inertia_correction(&mut kkt2, &mut solver, &mut params, 1e-4).unwrap();
    let expected_initial = 0.9 * (1.0 / 3.0);
    assert!(
        delta_w >= expected_initial - 1e-12,
        "Warm-shrink should start at delta_w_last * 1/3 = {:.3e}; got {:.3e}",
        expected_initial, delta_w
    );
    assert!(
        delta_w > 0.9 / 8.0,
        "delta_w should not match the deprecated /growth shrink"
    );
}
#[test]
fn test_inertia_max_perturbation_cap() {
    // With a tiny delta_w_max and a generous attempt budget, hitting the cap
    // must still return Ok (graceful give-up), never an Err.
    let n = 2;
    let m = 2;
    let mut mat = SymmetricMatrix::zeros(4);
    for i in 0..4 {
        mat.set(i, i, 1.0);
    }
    mat.set(2, 0, 0.1);
    mat.set(3, 1, 0.1);
    let mut kkt = KktSystem {
        dim: 4,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 2.0, 3.0, 4.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams {
        delta_w_max: 1e-2,
        max_attempts: 100,
        ..Default::default()
    };
    let result = factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4);
    assert!(result.is_ok(), "delta_w_max give-up must return Ok");
}
#[test]
fn test_factor_with_inertia_correction_max_attempts_cap() {
    // Exhausting max_attempts must still return Ok with non-negative
    // perturbations rather than propagating an error.
    let n = 2;
    let m = 2;
    let mut mat = SymmetricMatrix::zeros(4);
    for i in 0..4 {
        mat.set(i, i, 1.0);
    }
    mat.set(2, 0, 0.1);
    mat.set(3, 1, 0.1);
    let mut kkt = KktSystem {
        dim: 4,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 2.0, 3.0, 4.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams {
        max_attempts: 3,
        ..Default::default()
    };
    let result = factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4);
    assert!(
        result.is_ok(),
        "max-attempts path must return Ok, not error: {:?}",
        result.err()
    );
    let (delta_w, delta_c) = result.unwrap();
    assert!(delta_w >= 0.0, "delta_w must be non-negative, got {}", delta_w);
    assert!(delta_c >= 0.0, "delta_c must be non-negative, got {}", delta_c);
}
#[test]
fn test_factor_with_inertia_correction_degeneracy_count_sets_structural_flag() {
    // Three consecutive perturbed factorizations sharing the same params
    // should drive degeneracy_count to >= 3 and latch the
    // structurally_degenerate flag on.
    let n = 2;
    let m = 1;
    let mut params = InertiaCorrectionParams::default();
    for _ in 0..3 {
        // Fresh indefinite system each round; only `params` carries state.
        let mut mat = SymmetricMatrix::zeros(3);
        mat.set(0, 0, -1.0);
        mat.set(1, 1, -1.0);
        mat.set(2, 0, 1.0);
        mat.set(2, 1, 1.0);
        let mut kkt = KktSystem {
            dim: 3,
            n,
            m,
            matrix: KktMatrix::Dense(mat),
            rhs: vec![1.0, 2.0, 3.0],
            delta_c_diag: vec![0.0; m],
            scale_factors: None,
        };
        let mut ldl = DenseLdl::new();
        factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4).unwrap();
    }
    assert!(
        params.structurally_degenerate,
        "after 3 perturbations, structurally_degenerate must latch on"
    );
    assert!(params.degeneracy_count >= 3);
}
#[test]
fn test_factor_with_inertia_correction_unconstrained_uses_min_diagonal_path() {
    // m = 0: an indefinite Hessian block still needs delta_w > 0, but with
    // no constraint rows delta_c must remain exactly zero.
    let n = 2;
    let m = 0;
    let mut mat = SymmetricMatrix::zeros(2);
    mat.set(0, 0, -2.0);
    mat.set(1, 1, 3.0);
    let mut kkt = KktSystem {
        dim: 2,
        n,
        m,
        matrix: KktMatrix::Dense(mat),
        rhs: vec![1.0, 1.0],
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    let mut params = InertiaCorrectionParams::default();
    let (delta_w, delta_c) =
        factor_with_inertia_correction(&mut kkt, &mut ldl, &mut params, 1e-4).unwrap();
    assert!(delta_w > 0.0, "indefinite unconstrained needs delta_w > 0, got {}", delta_w);
    assert_eq!(delta_c, 0.0, "unconstrained (m=0) must have delta_c = 0");
}
#[test]
fn test_solve_for_direction_simple() {
    // Factor a small dense KKT system directly, then check the recovered
    // (dx, dy) actually satisfies K * [dx; dy] = rhs entrywise.
    let n = 2;
    let m = 1;
    let mut mat = SymmetricMatrix::zeros(3);
    mat.set(0, 0, 2.0);
    mat.set(1, 1, 2.0);
    mat.set(2, 0, 1.0);
    mat.set(2, 1, 1.0);
    let rhs = vec![1.0, 2.0, 0.5];
    let kkt = KktSystem {
        dim: 3,
        n,
        m,
        matrix: KktMatrix::Dense(mat.clone()),
        rhs: rhs.clone(),
        delta_c_diag: vec![0.0; m],
        scale_factors: None,
    };
    let mut ldl = DenseLdl::new();
    ldl.factor(&KktMatrix::Dense(mat.clone())).unwrap();
    let (dx, dy) = solve_for_direction(&kkt, &mut ldl, 0.0, 0.0).unwrap();
    assert_eq!(dx.len(), 2);
    assert_eq!(dy.len(), 1);
    // Stack the solution back together and form K * sol.
    let mut sol = vec![0.0; 3];
    sol[..2].copy_from_slice(&dx);
    sol[2] = dy[0];
    let mut ax = vec![0.0; 3];
    mat.matvec(&sol, &mut ax);
    for i in 0..3 {
        assert!(
            (ax[i] - rhs[i]).abs() < 1e-8,
            "KKT*solution mismatch at {}: {} vs {}",
            i, ax[i], rhs[i]
        );
    }
}
#[test]
fn test_recover_dz_no_bounds() {
    // With both bounds infinite, the recovered bound-multiplier steps must
    // all be (numerically) zero.
    let x = vec![1.0, 2.0];
    let x_l = vec![f64::NEG_INFINITY; 2];
    let x_u = vec![f64::INFINITY; 2];
    let z_l = vec![0.0; 2];
    let z_u = vec![0.0; 2];
    let dx = vec![0.1, 0.2];
    let (dz_l, dz_u) = recover_dz(&x, &x_l, &x_u, &z_l, &z_u, &dx, 0.1);
    for (lo, hi) in dz_l.iter().zip(dz_u.iter()) {
        assert!(lo.abs() < 1e-15);
        assert!(hi.abs() < 1e-15);
    }
}
#[test]
fn test_recover_dz_lower_bound() {
    // Active lower bound with slack x - x_l = 0.5. The expected -2.2 is
    // consistent with dz_l = mu/slack - z_l - (z_l/slack)*dx
    // = 0.2 - 2.0 - 0.4 (presumed recover_dz formula — matches assertion).
    let mu = 0.1;
    let x = vec![1.5];
    let x_l = vec![1.0];
    let x_u = vec![f64::INFINITY];
    let z_l = vec![2.0];
    let z_u = vec![0.0];
    let dx = vec![0.1];
    let (dz_l, _) = recover_dz(&x, &x_l, &x_u, &z_l, &z_u, &dx, mu);
    assert!((dz_l[0] - (-2.2)).abs() < 1e-12);
}
#[test]
fn test_condensed_kkt_matches_full() {
    // The condensed KKT path must produce the same step (dx, dy) as
    // factoring and solving the full augmented system.
    let n = 2;
    let m = 3;
    let mu = 0.01;
    let hess_rows = vec![0, 1];
    let hess_cols = vec![0, 1];
    let hess_vals = vec![2.0, 2.0];
    let jac_rows = vec![0, 0, 1, 2];
    let jac_cols = vec![0, 1, 0, 1];
    let jac_vals = vec![1.0, 1.0, 1.0, 1.0];
    let x = vec![0.6, 0.7];
    let x_l = vec![f64::NEG_INFINITY; 2];
    let x_u = vec![f64::INFINITY; 2];
    let z_l = vec![0.0; 2];
    let z_u = vec![0.0; 2];
    let sigma = compute_sigma(&x, &x_l, &x_u, &z_l, &z_u);
    let grad_f = vec![1.2, 1.4];
    let g = vec![1.3, 0.6, 0.7];
    let g_l = vec![1.0, 0.2, 0.3];
    let g_u = vec![f64::INFINITY; 3];
    let y = vec![0.1, 0.05, 0.05];
    let v_l = vec![0.0; m];
    let v_u = vec![0.0; m];
    // Reference direction: full augmented system + inertia correction.
    let mut full_kkt = assemble_kkt(
        n, m, &hess_rows, &hess_cols, &hess_vals,
        &jac_rows, &jac_cols, &jac_vals, &sigma, &grad_f,
        &g, &g_l, &g_u, &y, &z_l, &z_u, &x, &x_l, &x_u, mu, false,
        &v_l, &v_u,
    );
    let mut full_solver = DenseLdl::new();
    let mut params = InertiaCorrectionParams::default();
    let (dw, dc) =
        factor_with_inertia_correction(&mut full_kkt, &mut full_solver, &mut params, 1e-4)
            .unwrap();
    let (dx_full, dy_full) = solve_for_direction(&full_kkt, &mut full_solver, dw, dc).unwrap();
    // Candidate direction: condensed system, factored directly.
    let condensed = assemble_condensed_kkt(
        n, m, &hess_rows, &hess_cols, &hess_vals,
        &jac_rows, &jac_cols, &jac_vals, &sigma, &grad_f,
        &g, &g_l, &g_u, &y, &z_l, &z_u, &x, &x_l, &x_u, mu,
        &v_l, &v_u,
    );
    let mut cond_solver = DenseLdl::new();
    cond_solver
        .factor(&KktMatrix::Dense(condensed.matrix.clone()))
        .unwrap();
    let (dx_cond, dy_cond) = solve_condensed(&condensed, &mut cond_solver).unwrap();
    for i in 0..n {
        assert!(
            (dx_full[i] - dx_cond[i]).abs() < 1e-6,
            "dx mismatch at {}: full={}, condensed={}",
            i, dx_full[i], dx_cond[i]
        );
    }
    for i in 0..m {
        assert!(
            (dy_full[i] - dy_cond[i]).abs() < 1e-6,
            "dy mismatch at {}: full={}, condensed={}",
            i, dy_full[i], dy_cond[i]
        );
    }
}
}