use crate::error::QpError;
use crate::kkt::KktTriplet;
use pounce_common::{Index, Number};
use pounce_linsol::status::ESymSolverStatus;
use pounce_linsol::{EMatrixFormat, SparseSymLinearSolverInterface};
/// Wrapper around a sparse symmetric linear-solver backend that remembers the
/// structure of the most recent successful factorization so that `resolve`
/// can reuse it for further right-hand sides.
pub struct LinearSolver {
/// Backend that performs the factorization and the solves.
backend: Box<dyn SparseSymLinearSolverInterface>,
/// Copy of the KKT `irn` array from the last successful `factorize_and_solve`;
/// `None` until one has succeeded.
cached_irn: Option<Vec<Index>>,
/// Copy of the KKT `jcn` array from the last successful `factorize_and_solve`;
/// `None` until one has succeeded.
cached_jcn: Option<Vec<Index>>,
/// KKT dimension recorded by the last successful `factorize_and_solve`
/// (0 initially); `resolve` checks the RHS length against it.
cached_dim: usize,
/// `true` only after a successful `factorize_and_solve`; it is cleared when a
/// new factorization attempt starts and gates `resolve`.
factored: bool,
}
impl LinearSolver {
    /// Creates a solver around `backend` with no cached factorization.
    pub fn new(backend: Box<dyn SparseSymLinearSolverInterface>) -> Self {
        Self {
            backend,
            cached_irn: None,
            cached_jcn: None,
            cached_dim: 0,
            factored: false,
        }
    }

    /// Converts a size into the backend's `Index` type, returning an error
    /// instead of silently truncating when the value does not fit.
    fn to_index(value: usize, what: &str) -> Result<Index, QpError> {
        <Index as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
            QpError::LinearSolverFailure(format!(
                "{what} {value} does not fit in the backend index type"
            ))
        })
    }

    /// Hands the KKT triplet to the backend, factorizes it and solves for `rhs`.
    ///
    /// `rhs` is passed mutably to the backend's `multi_solve` as a single
    /// right-hand side (presumably overwritten with the solution — confirm
    /// against the backend contract). When `expected_neg_evals` is `Some(e)`,
    /// the backend is asked to check the inertia against `e`.
    ///
    /// On success the triplet structure and dimension are cached so that
    /// [`resolve`](Self::resolve) can be used afterwards.
    ///
    /// # Errors
    ///
    /// * `QpError::LinearSolverFailure` if the backend does not use triplet
    ///   format, a size does not fit the backend index type, the backend's
    ///   values array has an unexpected length, or the backend reports any
    ///   non-success status.
    /// * `QpError::DimensionMismatch` if `rhs.len() != kkt.dim` or the
    ///   `irn`/`jcn`/`vals` arrays of `kkt` have different lengths.
    ///
    /// Input-validation errors are raised before the backend is touched, so an
    /// earlier cached factorization stays usable in that case.
    pub fn factorize_and_solve(
        &mut self,
        kkt: &KktTriplet,
        rhs: &mut [Number],
        expected_neg_evals: Option<i32>,
    ) -> Result<(), QpError> {
        let format = self.backend.matrix_format();
        if format != EMatrixFormat::TripletFormat {
            return Err(QpError::LinearSolverFailure(format!(
                "backend requires {format:?} but pounce-qp emits TripletFormat"
            )));
        }
        if rhs.len() != kkt.dim {
            return Err(QpError::DimensionMismatch(format!(
                "RHS length {} but KKT dim is {}",
                rhs.len(),
                kkt.dim
            )));
        }
        // A malformed triplet would otherwise surface as a panic in
        // `copy_from_slice` or as undefined backend behaviour.
        let entries = kkt.irn.len();
        if kkt.jcn.len() != entries || kkt.vals.len() != entries {
            return Err(QpError::DimensionMismatch(format!(
                "KKT triplet arrays disagree in length: irn {}, jcn {}, vals {}",
                entries,
                kkt.jcn.len(),
                kkt.vals.len()
            )));
        }
        let dim = Self::to_index(kkt.dim, "KKT dim")?;
        let nnz = Self::to_index(entries, "KKT nonzero count")?;

        // From here on the backend state is modified, so any previously cached
        // factorization can no longer be trusted.
        self.factored = false;

        let st = self
            .backend
            .initialize_structure(dim, nnz, &kkt.irn, &kkt.jcn);
        if st != ESymSolverStatus::Success {
            return Err(QpError::LinearSolverFailure(format!(
                "initialize_structure → {st:?}"
            )));
        }

        // `copy_from_slice` panics on a length mismatch; report it instead.
        let values = self.backend.values_array_mut();
        if values.len() != kkt.vals.len() {
            return Err(QpError::LinearSolverFailure(format!(
                "backend values array has {} entries but KKT has {} nonzeros",
                values.len(),
                kkt.vals.len()
            )));
        }
        values.copy_from_slice(&kkt.vals);

        let (check, expected) = match expected_neg_evals {
            Some(e) => (true, e),
            None => (false, 0),
        };
        // First flag `true`: presumably "new matrix, factorize" — confirm
        // against the backend's `multi_solve` documentation.
        let st = self
            .backend
            .multi_solve(true, &kkt.irn, &kkt.jcn, 1, rhs, check, expected);
        match st {
            ESymSolverStatus::Success => {
                // `clone_from` reuses the buffers of an earlier cache entry.
                self.cached_irn
                    .get_or_insert_with(Vec::new)
                    .clone_from(&kkt.irn);
                self.cached_jcn
                    .get_or_insert_with(Vec::new)
                    .clone_from(&kkt.jcn);
                self.cached_dim = kkt.dim;
                self.factored = true;
                Ok(())
            }
            ESymSolverStatus::Singular => Err(QpError::LinearSolverFailure(
                "KKT matrix is singular (LICQ violation or rank-deficient Jacobian)".into(),
            )),
            ESymSolverStatus::WrongInertia => Err(QpError::LinearSolverFailure(format!(
                "KKT inertia mismatch: expected {} negative eigenvalues, got {}",
                expected,
                self.backend.number_of_neg_evals()
            ))),
            ESymSolverStatus::CallAgain => Err(QpError::LinearSolverFailure(
                "backend requested re-call; not yet supported in pounce-qp".into(),
            )),
            ESymSolverStatus::FatalError => Err(QpError::LinearSolverFailure(
                "backend reported fatal error".into(),
            )),
        }
    }

    /// Solves for another right-hand side using the cached structure from the
    /// last successful [`factorize_and_solve`](Self::factorize_and_solve).
    ///
    /// # Errors
    ///
    /// * `QpError::LinearSolverFailure` if no successful factorization is
    ///   cached or the backend reports a non-success status.
    /// * `QpError::DimensionMismatch` if `rhs.len()` differs from the cached
    ///   dimension.
    ///
    /// # Panics
    ///
    /// Panics only if the internal invariant "factored implies cached
    /// structure" is broken.
    pub fn resolve(&mut self, rhs: &mut [Number]) -> Result<(), QpError> {
        if !self.factored {
            return Err(QpError::LinearSolverFailure(
                "resolve called before a successful factorize_and_solve".into(),
            ));
        }
        if rhs.len() != self.cached_dim {
            return Err(QpError::DimensionMismatch(format!(
                "resolve RHS length {} but cached factor has dim {}",
                rhs.len(),
                self.cached_dim
            )));
        }
        let irn = self.cached_irn.as_ref().expect("factored ⇒ cache present");
        let jcn = self.cached_jcn.as_ref().expect("factored ⇒ cache present");
        // First flag `false`: presumably "reuse the existing factorization";
        // no inertia check is requested.
        match self.backend.multi_solve(false, irn, jcn, 1, rhs, false, 0) {
            ESymSolverStatus::Success => Ok(()),
            other => Err(QpError::LinearSolverFailure(format!(
                "resolve backend status: {other:?}"
            ))),
        }
    }

    /// Returns the backend's negative-eigenvalue count, or `None` when the
    /// backend reports that it does not provide inertia.
    pub fn number_of_neg_evals(&self) -> Option<i32> {
        self.backend
            .provides_inertia()
            .then(|| self.backend.number_of_neg_evals())
    }

    /// Reports whether a successful factorization is cached, i.e. whether
    /// [`resolve`](Self::resolve) may be called.
    pub fn has_cached_factor(&self) -> bool {
        self.factored
    }
}