use crate::factor::LinearSolver;
use crate::kkt::a_times_x;
use crate::options::QpOptions;
use crate::problem::{HessianInertia, QpProblem};
use crate::schur::SchurState;
use crate::working_set::{BoundStatus, ConsStatus, WorkingSet};
use pounce_feral::FeralSolverInterface;
use pounce_linalg::triplet::{GenTMatrix, GenTMatrixSpace, SymTMatrix, SymTMatrixSpace};
use std::rc::Rc;
/// Creates a new [`LinearSolver`] wrapping a `FeralSolverInterface` backend.
///
/// Every call constructs an independent solver instance, so each test owns
/// its own factorization state.
fn linsol() -> LinearSolver {
    let backend = Box::new(FeralSolverInterface::new());
    LinearSolver::new(backend)
}
/// Builds the data for a two-variable, one-constraint QP.
///
/// ```text
/// H = diag(1, 1),  g = (0, 0)
/// A = [1 1],       bl = bu = 1      (i.e. x₀ + x₁ = 1)
/// xl, xu = NLP_{LOWER,UPPER}_BOUND_INF on both variables
/// ```
///
/// Assuming the usual `½ xᵀHx + gᵀx` objective, this is
/// `min ½(x₀² + x₁²)` subject to `x₀ + x₁ = 1`. The variable bounds use the
/// "infinite" bound constants and are presumably treated as unbounded.
///
/// Returns `(h, a, g, bl, bu, xl, xu)` in that order.
fn tiny_qp() -> (
    SymTMatrix,
    GenTMatrix,
    [f64; 2],
    [f64; 1],
    [f64; 1],
    [f64; 2],
    [f64; 2],
) {
    // The triplet row/col indices look 1-based: index 2 appears in a 2×2
    // matrix. NOTE(review): confirm against `pounce_linalg::triplet`.
    let hess_space = SymTMatrixSpace::new(2, vec![1, 2], vec![1, 2]);
    let mut hess = SymTMatrix::new(Rc::clone(&hess_space));
    hess.set_values(&[1.0, 1.0]);

    // Single constraint row with coefficient 1 on both variables.
    let jac_space = GenTMatrixSpace::new(1, 2, vec![1, 1], vec![1, 2]);
    let mut jac = GenTMatrix::new(Rc::clone(&jac_space));
    jac.set_values(&[1.0, 1.0]);

    let free_lower = pounce_common::types::NLP_LOWER_BOUND_INF;
    let free_upper = pounce_common::types::NLP_UPPER_BOUND_INF;

    (
        hess,
        jac,
        [0.0; 2],
        [1.0],
        [1.0],
        [free_lower; 2],
        [free_upper; 2],
    )
}
/// `SchurState::new(n, m)` records the problem sizes and derives two more.
///
/// The derived sizes are the multiplier row count (`n + m`) and the full
/// system dimension (`n` primal rows plus all multiplier rows).
#[test]
fn schur_state_constructs_with_expected_dimensions() {
    let state = SchurState::new(3, 2);

    assert_eq!(state.n, 3);
    assert_eq!(state.m, 2);
    // One multiplier row per variable plus one per general constraint.
    assert_eq!(state.m_total, 3 + 2);
    // Primal block followed by every multiplier row.
    assert_eq!(state.dim, 3 + (3 + 2));
}
/// Checks a solve straight after `reset`, with no Schur updates applied.
///
/// Only the equality row is in the working set. Solving with right-hand side
/// `[0, 0, 1, 0, 0]` must give the KKT solution of `tiny_qp`:
/// `x = (0.5, 0.5)`, equality multiplier `-0.5`, and zero multipliers in the
/// two bound slots.
#[test]
fn schur_reset_factors_k_max_and_solves_with_no_updates() {
    let (h, a, g, bl, bu, xl, xu) = tiny_qp();
    let qp = QpProblem {
        n: 2,
        m: 1,
        h: &h,
        g: &g,
        a: &a,
        bl: &bl,
        bu: &bu,
        xl: &xl,
        xu: &xu,
        hessian_inertia: HessianInertia::Psd,
    };

    // The single general constraint is the only working-set member.
    let mut working = WorkingSet::cold(2, 1);
    working.constraints[0] = ConsStatus::Equality;

    let mut ls = linsol();
    let mut state = SchurState::new(2, 1);
    // NOTE(review): the `1` appears to be the active-row count (equality
    // only); confirm against `SchurState::reset`.
    state
        .reset(&mut ls, &qp, &working, 1, &QpOptions::default())
        .unwrap();
    assert_eq!(state.n_schur_updates(), 0);

    // Slots are [x[0], x[1], λ_eq, λ_b1, λ_b2]; `1` is the equality rhs.
    let mut rhs = vec![0.0, 0.0, 1.0, 0.0, 0.0];
    state.solve(&mut ls, &mut rhs).unwrap();

    let near = |got: f64, want: f64| (got - want).abs() < 1e-10;
    assert!(near(rhs[0], 0.5), "x[0] = {}", rhs[0]);
    assert!(near(rhs[1], 0.5), "x[1] = {}", rhs[1]);
    assert!(near(rhs[2], -0.5), "λ_eq = {}", rhs[2]);
    assert!(near(rhs[3], 0.0), "λ_b1 = {}", rhs[3]);
    assert!(near(rhs[4], 0.0), "λ_b2 = {}", rhs[4]);
}
/// Compares a one-row Schur update against a fresh factorization.
///
/// The updated path factors the equality-only working set, then adds change
/// index 1 with `apply_change`. The reference path factors from scratch with
/// `bounds[0]` marked `AtLower`. Both solves must agree entrywise to `1e-9`.
#[test]
fn schur_apply_change_to_activate_bound_matches_fresh_factor() {
    let (h, a, g, bl, bu, xl, xu) = tiny_qp();
    let qp = QpProblem {
        n: 2,
        m: 1,
        h: &h,
        g: &g,
        a: &a,
        bl: &bl,
        bu: &bu,
        xl: &xl,
        xu: &xu,
        hessian_inertia: HessianInertia::Psd,
    };

    let mut working = WorkingSet::cold(2, 1);
    working.constraints[0] = ConsStatus::Equality;

    // Path 1: factor with one active row, then add index 1 as a Schur update.
    let mut ls = linsol();
    let mut state = SchurState::new(2, 1);
    state
        .reset(&mut ls, &qp, &working, 1, &QpOptions::default())
        .unwrap();
    state.apply_change(&mut ls, &qp, 1, true).unwrap();
    assert_eq!(state.n_schur_updates(), 1);
    let mut rhs_schur = vec![0.0, 0.0, 1.0, 0.0, 0.0];
    state.solve(&mut ls, &mut rhs_schur).unwrap();

    // Path 2: the same working set plus bounds[0] at its lower bound, with a
    // brand-new solver and factorization (two active rows).
    let working_ref = {
        let mut ws = working.clone();
        ws.bounds[0] = BoundStatus::AtLower;
        ws
    };
    let mut ls_ref = linsol();
    let mut state_ref = SchurState::new(2, 1);
    state_ref
        .reset(&mut ls_ref, &qp, &working_ref, 2, &QpOptions::default())
        .unwrap();
    let mut rhs_ref = vec![0.0, 0.0, 1.0, 0.0, 0.0];
    state_ref.solve(&mut ls_ref, &mut rhs_ref).unwrap();

    // `a`/`b` shadow the fixture matrix here; they are the per-slot results.
    for (i, (&a, &b)) in rhs_schur.iter().zip(&rhs_ref).enumerate() {
        let diff = (a - b).abs();
        assert!(
            diff < 1e-9,
            "rhs[{i}]: schur={a}, fresh={b} (diff {})",
            diff,
        );
    }
}
/// Applies two successive Schur updates (change indices 1 and 2) and checks
/// that the update counter records both of them.
///
/// Despite its name, this test does **not** compare against a fresh
/// factorization. It assumes indices 1 and 2 select the bounds on `x[0]` and
/// `x[1]`; the single-flip test pairs index 1 with `bounds[0]`. Under that
/// assumption, the final working set has three active rows on two variables:
/// `x₀ + x₁ = 1`, `x₀`, and `x₁`. Those rows are linearly dependent, so the
/// KKT system is singular and has no unique solution to compare.
///
/// The solve is therefore only checked for "does not panic". Its `Result` is
/// deliberately ignored, because both `Ok` and a singularity `Err` are
/// acceptable here.
#[test]
fn schur_two_flips_match_fresh_factor() {
    let (h, a, g, bl, bu, xl, xu) = tiny_qp();
    let qp = QpProblem {
        n: 2,
        m: 1,
        h: &h,
        g: &g,
        a: &a,
        bl: &bl,
        bu: &bu,
        xl: &xl,
        xu: &xu,
        hessian_inertia: HessianInertia::Psd,
    };

    let mut working = WorkingSet::cold(2, 1);
    working.constraints[0] = ConsStatus::Equality;

    let mut state = SchurState::new(2, 1);
    let mut ls = linsol();
    state
        .reset(&mut ls, &qp, &working, 1, &QpOptions::default())
        .unwrap();

    state.apply_change(&mut ls, &qp, 1, true).unwrap();
    state.apply_change(&mut ls, &qp, 2, true).unwrap();
    assert_eq!(state.n_schur_updates(), 2);

    let mut rhs = vec![0.0, 0.0, 1.0, 0.0, 0.0];
    // Best-effort: a rank-deficient system may legitimately report an error,
    // so the outcome is intentionally discarded.
    let _ = state.solve(&mut ls, &mut rhs);
}
/// A full `reset` must discard any pending Schur updates.
///
/// After one `apply_change`, the test re-factors onto the equivalent working
/// set, with `bounds[0]` marked `AtLower`, and expects the update count to
/// drop back to zero.
#[test]
fn schur_reset_after_apply_change_clears_state() {
    let (h, a, g, bl, bu, xl, xu) = tiny_qp();
    let qp = QpProblem {
        n: 2,
        m: 1,
        h: &h,
        g: &g,
        a: &a,
        bl: &bl,
        bu: &bu,
        xl: &xl,
        xu: &xu,
        hessian_inertia: HessianInertia::Psd,
    };
    let opts = QpOptions::default();

    let mut working = WorkingSet::cold(2, 1);
    working.constraints[0] = ConsStatus::Equality;

    let mut ls = linsol();
    let mut state = SchurState::new(2, 1);
    state.reset(&mut ls, &qp, &working, 1, &opts).unwrap();
    state.apply_change(&mut ls, &qp, 1, true).unwrap();
    assert_eq!(state.n_schur_updates(), 1);

    // Re-factor from scratch with the bound folded into the working set.
    let mut working_after = working.clone();
    working_after.bounds[0] = BoundStatus::AtLower;
    state.reset(&mut ls, &qp, &working_after, 2, &opts).unwrap();
    assert_eq!(state.n_schur_updates(), 0);
}
/// Checks that the primal part of a solve satisfies `A x = 1`.
///
/// The test runs a plain `reset` + `solve` on `tiny_qp`, then evaluates the
/// result with the shared `a_times_x` helper. It verifies that the solve
/// output is consistent with the constraint that helper computes.
#[test]
fn schur_dot_helper_is_used_correctly() {
    let (h, a, g, bl, bu, xl, xu) = tiny_qp();
    let qp = QpProblem {
        n: 2,
        m: 1,
        h: &h,
        g: &g,
        a: &a,
        bl: &bl,
        bu: &bu,
        xl: &xl,
        xu: &xu,
        hessian_inertia: HessianInertia::Psd,
    };

    let mut working = WorkingSet::cold(2, 1);
    working.constraints[0] = ConsStatus::Equality;

    let mut ls = linsol();
    let mut state = SchurState::new(2, 1);
    state
        .reset(&mut ls, &qp, &working, 1, &QpOptions::default())
        .unwrap();

    let mut sol = vec![0.0, 0.0, 1.0, 0.0, 0.0];
    state.solve(&mut ls, &mut sol).unwrap();

    // The first n = 2 slots hold x; the remainder are multipliers.
    let (primal, _multipliers) = sol.split_at(2);
    let lhs = a_times_x(qp.a, primal, 1);
    assert!((lhs[0] - 1.0).abs() < 1e-10, "Ax = {}, want 1.0", lhs[0]);
}