use crate::sparse::CscMatrix;
use std::time::Instant;
/// Default memory budget for a KKT factorization, in bytes (4 GiB).
///
/// NOTE(review): the value is 2^32, which does not fit a 32-bit `usize`;
/// confirm that only 64-bit targets are supported.
pub const DEFAULT_MEMORY_BUDGET_BYTES: usize = 4 * 1024 * 1024 * 1024;

/// Bytes accounted per stored entry of the factor `L` when a byte budget is
/// converted into an entry-count budget.
pub const BYTES_PER_L_ENTRY: usize = 16;

/// Returns the memory budget in bytes.
///
/// This is currently always [`DEFAULT_MEMORY_BUDGET_BYTES`].
pub fn memory_budget_bytes() -> usize {
    DEFAULT_MEMORY_BUDGET_BYTES
}

/// Returns the maximum number of `L` entries that fit in the memory budget,
/// i.e. the byte budget divided by [`BYTES_PER_L_ENTRY`].
pub fn max_l_nnz_from_budget() -> usize {
    let budget_bytes = memory_budget_bytes();
    budget_bytes / BYTES_PER_L_ENTRY
}
/// Options consumed by the `factorize_kkt_*` entry points in this file.
#[derive(Debug, Clone, Copy)]
pub struct KktConfig {
    /// When `true`, the factorization entry points try the `ldl_dd`
    /// factorization (yielding `KktFactor::DirectDd`) instead of the plain
    /// f64 LDL.
    pub dd_ldl: bool,
    /// Number of iterative-refinement steps handed to the MINRES fallback.
    pub minres_ir: usize,
    /// Budget on the number of nonzeros of `L`, passed to the budgeted LDL
    /// factorization.
    pub max_l_nnz: usize,
}

impl Default for KktConfig {
    /// f64 LDL, the default MINRES refinement step count, and an `L` budget
    /// derived from the default memory budget.
    fn default() -> Self {
        let max_l_nnz = DEFAULT_MEMORY_BUDGET_BYTES / BYTES_PER_L_ENTRY;
        Self {
            dd_ldl: false,
            minres_ir: MINRES_INEXACT_NEWTON_IR_STEPS,
            max_l_nnz,
        }
    }
}
/// Failure modes reported by the KKT solver backends in this file.
#[non_exhaustive]
#[derive(Debug)]
pub enum KktError {
    /// The caller-supplied deadline passed before the operation finished.
    DeadlineExceeded,
    /// The factorization reported a singular or indefinite matrix. The
    /// backends in this file also return it for a dimension mismatch and for
    /// a solve without a usable factorization.
    SingularOrIndefinite,
    /// The factorization would exceed the configured `L` nonzero budget.
    WouldExceedMemory,
    /// The iterative solver stopped without reporting convergence.
    DidNotConverge,
}

impl std::fmt::Display for KktError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::DeadlineExceeded => "KKT solver: deadline exceeded",
            Self::SingularOrIndefinite => "KKT solver: singular or indefinite",
            Self::WouldExceedMemory => "KKT solver: would exceed memory budget",
            Self::DidNotConverge => "KKT solver: did not converge",
        };
        f.write_str(message)
    }
}

impl std::error::Error for KktError {}
/// Common interface of the KKT linear-system backends in this file
/// (`DirectLdl`, `PreconditionedMinres`, `AutoKktSolver`).
///
/// Implementors must be `Send`.
pub trait KktSolver: Send {
/// Solves the system for `rhs` and writes the solution into `sol`.
///
/// `deadline` is an optional wall-clock limit; an implementation may ignore
/// it (the direct backend in this file does).
///
/// # Errors
/// Returns a [`KktError`] describing why no (converged) solution was
/// produced.
fn solve(
&self,
rhs: &[f64],
sol: &mut [f64],
deadline: Option<Instant>,
) -> Result<(), KktError>;
/// Updates the solver for a new matrix `k`.
///
/// The implementations in this file reject a `k` whose dimensions differ
/// from [`KktSolver::dim`].
///
/// # Errors
/// Returns a [`KktError`] when the update fails.
fn refactor(
&mut self,
k: &CscMatrix,
deadline: Option<Instant>,
) -> Result<(), KktError>;
/// Dimension of the (square) system this solver handles.
fn dim(&self) -> usize;
}
/// Direct KKT backend built on the AMD-ordered quasidefinite LDL
/// factorization.
pub struct DirectLdl {
    /// Current factorization; `None` before the first successful `refactor`
    /// and after a failed factorization attempt.
    factor: Option<crate::linalg::ldl::LdlFactorizationAmd>,
    /// Expected dimension of the square matrix.
    n: usize,
    /// Optional budget on the nonzeros of `L`; `None` means no budget is
    /// passed to the factorization.
    max_l_nnz: Option<usize>,
    /// Parallelism setting forwarded to the factorization.
    par: faer::Par,
}

impl DirectLdl {
    /// Creates an unfactored solver of dimension `n` with no `L` budget and
    /// sequential execution.
    pub fn new(n: usize) -> Self {
        Self {
            factor: None,
            n,
            max_l_nnz: None,
            par: faer::Par::Seq,
        }
    }

    /// Like [`DirectLdl::new`], but with a budget of `max_l_nnz` nonzeros
    /// for `L`.
    pub fn with_budget(n: usize, max_l_nnz: usize) -> Self {
        let mut solver = Self::new(n);
        solver.max_l_nnz = Some(max_l_nnz);
        solver
    }

    /// Builder-style setter for the parallelism used by later factorizations.
    pub fn with_par(self, par: faer::Par) -> Self {
        let mut solver = self;
        solver.par = par;
        solver
    }

    /// Creates a budget-free solver sized from `k.nrows` and factorizes `k`
    /// immediately.
    ///
    /// # Errors
    /// Propagates any error from the initial `refactor`.
    pub fn from_matrix(k: &CscMatrix, deadline: Option<Instant>) -> Result<Self, KktError> {
        let mut solver = Self::new(k.nrows);
        solver.refactor(k, deadline).map(|()| solver)
    }

    /// Number of nonzeros in the current `L`, or `0` when nothing is
    /// factorized.
    pub fn l_nnz(&self) -> usize {
        match &self.factor {
            Some(factor) => factor.nnz_l(),
            None => 0,
        }
    }
}
impl KktSolver for DirectLdl {
    /// Solves with the stored factorization; the deadline is ignored.
    ///
    /// # Errors
    /// Returns `KktError::SingularOrIndefinite` when no factorization is
    /// stored.
    fn solve(
        &self,
        rhs: &[f64],
        sol: &mut [f64],
        _deadline: Option<Instant>,
    ) -> Result<(), KktError> {
        match &self.factor {
            Some(factor) => {
                factor.solve(rhs, sol);
                Ok(())
            }
            None => Err(KktError::SingularOrIndefinite),
        }
    }

    /// Factorizes `k`, replacing the stored factorization.
    ///
    /// A dimension mismatch is rejected up front and leaves the stored
    /// factorization untouched. Any factorization failure clears it.
    ///
    /// # Errors
    /// `SingularOrIndefinite` for a dimension mismatch; otherwise the LDL
    /// error translated to the matching [`KktError`] variant.
    fn refactor(
        &mut self,
        k: &CscMatrix,
        deadline: Option<Instant>,
    ) -> Result<(), KktError> {
        let dims_match = k.nrows == self.n && k.ncols == self.n;
        if !dims_match {
            return Err(KktError::SingularOrIndefinite);
        }

        let outcome = crate::linalg::ldl::factorize_quasidefinite_with_amd_budget_par(
            k,
            deadline,
            self.max_l_nnz,
            self.par,
        );
        match outcome {
            Ok(factor) => {
                self.factor = Some(factor);
                Ok(())
            }
            Err(ldl_err) => {
                // A failed attempt invalidates whatever was stored before.
                self.factor = None;
                Err(match ldl_err {
                    crate::linalg::ldl::LdlError::DeadlineExceeded => KktError::DeadlineExceeded,
                    crate::linalg::ldl::LdlError::SingularOrIndefinite => {
                        KktError::SingularOrIndefinite
                    }
                    crate::linalg::ldl::LdlError::WouldExceedBudget { .. } => {
                        KktError::WouldExceedMemory
                    }
                })
            }
        }
    }

    /// Dimension given at construction.
    fn dim(&self) -> usize {
        self.n
    }
}
/// Iterative KKT backend: MINRES with a diagonal preconditioner.
pub struct PreconditionedMinres {
// Owned copy of the system matrix; applied via `matvec_sym_upper` in `solve`
// (presumably stored as the upper triangle of a symmetric matrix — confirm).
k: CscMatrix,
// Entrywise inverse of the diagonal preconditioner, one value per row.
m_inv_diag: Vec<f64>,
// Which preconditioner `m_inv_diag` was built with; reused on `refactor`.
kind: PreconditionerKind,
// Iteration cap passed to MINRES.
max_iter: usize,
// Tolerance passed to MINRES.
tol: f64,
// Number of iterative-refinement passes run after the initial solve.
ir_steps: usize,
}
/// Lower bound applied to preconditioner diagonal entries before inversion,
/// so the reciprocal stays finite.
const MIN_DIAG: f64 = 1e-12;
/// Selects how the diagonal preconditioner of [`PreconditionedMinres`] is
/// built.
#[derive(Debug, Clone, Copy)]
pub enum PreconditionerKind {
/// Inverse of the absolute value of the matrix diagonal.
Jacobi,
/// The first `n_top` rows use the absolute diagonal; the remaining rows use
/// a diagonal estimate that also folds in the coupling entries to the top
/// block (see `compute_block_diag_inv`).
BlockDiag { n_top: usize },
}
/// Factor applied to the outer tolerance `eps` to obtain the inner solver
/// tolerance.
pub(crate) const IPM_OUTER_VS_INNER_RATIO: f64 = 0.1;
/// Smallest inner tolerance ever returned by [`inexact_eta_for_eps`].
pub(crate) const IPM_INEXACT_ETA_FLOOR: f64 = 1e-13;

/// Derives the inner (inexact) solve tolerance from `eps`.
///
/// Returns `eps * IPM_OUTER_VS_INNER_RATIO`, but never less than
/// `IPM_INEXACT_ETA_FLOOR`. A NaN `eps` yields the floor, since `f64::max`
/// returns its non-NaN argument.
pub fn inexact_eta_for_eps(eps: f64) -> f64 {
    let scaled = eps * IPM_OUTER_VS_INNER_RATIO;
    f64::max(scaled, IPM_INEXACT_ETA_FLOOR)
}
/// MINRES tolerance used when the `factorize_kkt_*` entry points fall back to
/// the iterative backend.
pub(crate) const MINRES_INEXACT_NEWTON_ETA: f64 = 1e-7;
/// Default number of iterative-refinement steps (`KktConfig::default`).
pub(crate) const MINRES_INEXACT_NEWTON_IR_STEPS: usize = 0;
/// Tolerance used by `PreconditionedMinres::new`.
const MINRES_DEFAULT_TOL: f64 = 1e-9;
/// The MINRES iteration cap is this multiple of the matrix dimension.
const MINRES_MAX_ITER_MULTIPLIER: usize = 2;
impl PreconditionedMinres {
pub fn set_inexact_tol(&mut self, tol: f64) {
self.tol = tol;
}
pub fn new(k: CscMatrix) -> Self {
let kind = PreconditionerKind::Jacobi;
let m_inv_diag = compute_inv_diag(&k, kind);
let n = k.nrows;
Self { k, m_inv_diag, kind, max_iter: MINRES_MAX_ITER_MULTIPLIER * n, tol: MINRES_DEFAULT_TOL, ir_steps: 0 }
}
pub fn with_block_diag_inexact(k: CscMatrix, n_top: usize, eta: f64, ir: usize) -> Self {
let kind = PreconditionerKind::BlockDiag { n_top };
let m_inv_diag = compute_inv_diag(&k, kind);
let n = k.nrows;
Self {
k,
m_inv_diag,
kind,
max_iter: 2 * n,
tol: eta,
ir_steps: ir,
}
}
pub fn new_inexact(k: CscMatrix, eta: f64, ir: usize) -> Self {
let kind = PreconditionerKind::Jacobi;
let m_inv_diag = compute_inv_diag(&k, kind);
let n = k.nrows;
Self {
k,
m_inv_diag,
kind,
max_iter: 2 * n,
tol: eta,
ir_steps: ir,
}
}
}
/// Builds the inverse diagonal preconditioner for `k`, dispatching on `kind`.
fn compute_inv_diag(k: &CscMatrix, kind: PreconditionerKind) -> Vec<f64> {
    match kind {
        PreconditionerKind::BlockDiag { n_top: top_block_len } => {
            compute_block_diag_inv(k, top_block_len)
        }
        PreconditionerKind::Jacobi => compute_jacobi_inv_diag(k),
    }
}
/// Jacobi preconditioner: `1 / max(|K[j][j]|, MIN_DIAG)` for every column `j`.
///
/// A column without a stored diagonal entry gets `1 / MIN_DIAG`. Only the
/// first diagonal entry found in a column is used.
fn compute_jacobi_inv_diag(k: &CscMatrix) -> Vec<f64> {
    (0..k.nrows)
        .map(|col| {
            let diag_abs = (k.col_ptr[col]..k.col_ptr[col + 1])
                .find(|&pos| k.row_ind[pos] == col)
                .map_or(MIN_DIAG, |pos| k.values[pos].abs().max(MIN_DIAG));
            1.0 / diag_abs
        })
        .collect()
}
/// Block-diagonal preconditioner, returned as an entrywise inverse.
///
/// * Rows `0..n_top`: `max(|K[j][j]|, MIN_DIAG)`, as for Jacobi.
/// * Rows `n_top..`: for column `c`, the sum over stored entries `(r, c)` of
///   `val^2 / top_diag[r]` when `r < n_top`, plus `|val|` for the diagonal
///   entry `r == c`, floored at `MIN_DIAG`. Entries with `n_top <= r` and
///   `r != c` are ignored.
///
/// Precondition: `n_top <= k.nrows` (only checked in debug builds).
fn compute_block_diag_inv(k: &CscMatrix, n_top: usize) -> Vec<f64> {
    let n_total = k.nrows;
    debug_assert!(n_top <= n_total);
    let m_bot = n_total - n_top;

    let top_diag: Vec<f64> = (0..n_top)
        .map(|col| {
            (k.col_ptr[col]..k.col_ptr[col + 1])
                .find(|&pos| k.row_ind[pos] == col)
                .map_or(MIN_DIAG, |pos| k.values[pos].abs().max(MIN_DIAG))
        })
        .collect();

    let bot_diag: Vec<f64> = (0..m_bot)
        .map(|offset| {
            let col = n_top + offset;
            let mut estimate = 0.0_f64;
            for pos in k.col_ptr[col]..k.col_ptr[col + 1] {
                let (row, val) = (k.row_ind[pos], k.values[pos]);
                if row < n_top {
                    // Coupling to the top block, scaled by that block's diagonal.
                    estimate += (val * val) / top_diag[row];
                } else if row == col {
                    estimate += val.abs();
                }
            }
            estimate.max(MIN_DIAG)
        })
        .collect();

    top_diag
        .iter()
        .chain(bot_diag.iter())
        .map(|&d| 1.0 / d)
        .collect()
}
impl KktSolver for PreconditionedMinres {
    /// Solves `K * sol = rhs` with preconditioned MINRES from a zero initial
    /// guess, followed by up to `ir_steps` iterative-refinement passes.
    ///
    /// Each refinement pass solves for the current residual and adds the
    /// correction to `sol`. Refinement stops early once the deadline has
    /// passed or the relative residual drops to `1e-14`.
    ///
    /// # Errors
    /// The status reflects only the *initial* MINRES pass; refinement results
    /// do not change it. If that pass did not converge, returns
    /// `DeadlineExceeded` when the deadline has passed by the time of the
    /// check, otherwise `DidNotConverge`. `sol` still holds the best iterate
    /// in both cases.
    fn solve(
        &self,
        rhs: &[f64],
        sol: &mut [f64],
        deadline: Option<Instant>,
    ) -> Result<(), KktError> {
        /// Relative residual below which refinement stops.
        const IR_REL_RESIDUAL_STOP: f64 = 1e-14;

        sol.fill(0.0);
        let k = &self.k;
        let m_inv = &self.m_inv_diag;
        let n = k.nrows;
        let do_minres = |sol: &mut [f64], rhs: &[f64]| {
            crate::linalg::minres::pminres(
                |v, y| crate::linalg::minres::matvec_sym_upper(k, v, y),
                |r, z| {
                    for i in 0..r.len() {
                        z[i] = r[i] * m_inv[i];
                    }
                },
                rhs,
                sol,
                self.tol,
                self.max_iter,
                || deadline.is_some_and(|d| Instant::now() >= d),
            )
        };
        let stats = do_minres(sol, rhs);

        if self.ir_steps > 0 {
            // ||rhs|| does not change between refinement passes: compute it once.
            let rhs_norm = rhs.iter().fold(0.0_f64, |a, &v| a + v * v).sqrt();
            let mut residual = vec![0.0_f64; n];
            let mut delta = vec![0.0_f64; n];
            for _ in 0..self.ir_steps {
                if deadline.is_some_and(|d| Instant::now() >= d) {
                    break;
                }
                // residual <- rhs - K * sol
                crate::linalg::minres::matvec_sym_upper(k, sol, &mut residual);
                let mut r_norm_sq = 0.0_f64;
                for (res, &b) in residual.iter_mut().zip(&rhs[..n]) {
                    *res = b - *res;
                    r_norm_sq += *res * *res;
                }
                let r_norm = r_norm_sq.sqrt();
                if rhs_norm > 0.0 && r_norm <= IR_REL_RESIDUAL_STOP * rhs_norm {
                    break;
                }
                // Solve K * delta = residual from a zero guess and correct sol.
                delta.fill(0.0);
                do_minres(&mut delta, &residual);
                for (x, &d) in sol[..n].iter_mut().zip(&delta) {
                    *x += d;
                }
            }
        }

        if stats.converged {
            Ok(())
        } else if deadline.is_some_and(|d| Instant::now() >= d) {
            Err(KktError::DeadlineExceeded)
        } else {
            Err(KktError::DidNotConverge)
        }
    }

    /// Replaces the stored matrix with a clone of `k` and rebuilds the
    /// preconditioner with the kind chosen at construction. The deadline is
    /// ignored.
    ///
    /// # Errors
    /// `SingularOrIndefinite` when `k` is not `dim() x dim()`; the solver is
    /// left unchanged in that case.
    fn refactor(
        &mut self,
        k: &CscMatrix,
        _deadline: Option<Instant>,
    ) -> Result<(), KktError> {
        if k.nrows != self.dim() || k.ncols != self.dim() {
            return Err(KktError::SingularOrIndefinite);
        }
        self.m_inv_diag = compute_inv_diag(k, self.kind);
        self.k = k.clone();
        Ok(())
    }

    /// Row count of the stored matrix.
    fn dim(&self) -> usize {
        self.k.nrows
    }
}
/// Result of a KKT factorization: one of the direct factorizations or the
/// iterative fallback.
pub enum KktFactor {
    /// f64 LDL factorization.
    Direct(crate::linalg::ldl::LdlFactorizationAmd),
    /// Factorization produced by the `ldl_dd` module.
    DirectDd(crate::linalg::ldl_dd::LdlFactorizationDdAmd),
    /// Preconditioned MINRES, used when the direct factor exceeds the budget.
    Iterative(PreconditionedMinres),
}

impl KktFactor {
    /// Sets the MINRES tolerance; does nothing for the direct variants.
    pub fn set_iterative_tol(&mut self, tol: f64) {
        match self {
            Self::Iterative(minres) => minres.set_inexact_tol(tol),
            Self::Direct(_) | Self::DirectDd(_) => {}
        }
    }

    /// Solves for `rhs` into `sol` without a deadline.
    pub fn solve(&self, rhs: &[f64], sol: &mut [f64]) {
        self.solve_with_deadline(rhs, sol, None)
    }

    /// Solves for `rhs` into `sol`.
    ///
    /// `deadline` only reaches the iterative variant. A MINRES error
    /// (non-convergence or deadline) is deliberately discarded here: `sol`
    /// keeps whatever iterate MINRES produced and the caller gets no status.
    pub fn solve_with_deadline(
        &self,
        rhs: &[f64],
        sol: &mut [f64],
        deadline: Option<Instant>,
    ) {
        match self {
            Self::Iterative(minres) => {
                let _ = minres.solve(rhs, sol, deadline);
            }
            Self::Direct(ldl) => ldl.solve(rhs, sol),
            Self::DirectDd(ldl_dd) => ldl_dd.solve(rhs, sol),
        }
    }

    /// `true` for the MINRES variant.
    pub fn is_iterative(&self) -> bool {
        match self {
            Self::Iterative(_) => true,
            Self::Direct(_) | Self::DirectDd(_) => false,
        }
    }
}
/// Factorizes the KKT matrix `k` with a cached permutation `perm`, falling
/// back to preconditioned MINRES when the direct factor would exceed
/// `cfg.max_l_nnz`.
///
/// The budgeted f64 LDL factorization always runs first:
/// * budget exceeded: returns `KktFactor::Iterative` (block-diagonal
///   preconditioner when `n_top` is `Some(n)` with `n <= k.nrows`, Jacobi
///   otherwise);
/// * success with `cfg.dd_ldl == false`: returns `KktFactor::Direct`;
/// * success with `cfg.dd_ldl == true`: the f64 factor served only as a
///   budget probe and is dropped *before* the `ldl_dd` factorization runs,
///   so the two factors are never held in memory at the same time. The
///   result is `KktFactor::DirectDd`, or the MINRES fallback if `ldl_dd`
///   reports a budget overflow.
///
/// # Errors
/// `DeadlineExceeded` or `SingularOrIndefinite`, as reported by either
/// factorization.
pub fn factorize_kkt_with_cached_perm_par(
    k: &CscMatrix,
    perm: &[usize],
    deadline: Option<Instant>,
    cfg: &KktConfig,
    n_top: Option<usize>,
    par: faer::Par,
) -> Result<KktFactor, KktError> {
    // Single place that builds the iterative fallback.
    let iterative_fallback = || {
        let eta = MINRES_INEXACT_NEWTON_ETA;
        let ir = cfg.minres_ir;
        let minres = match n_top {
            Some(n) if n <= k.nrows => {
                PreconditionedMinres::with_block_diag_inexact(k.clone(), n, eta, ir)
            }
            _ => PreconditionedMinres::new_inexact(k.clone(), eta, ir),
        };
        KktFactor::Iterative(minres)
    };

    let f64_factor = match crate::linalg::ldl::factorize_quasidefinite_with_cached_perm_budget_par(
        k,
        perm,
        deadline,
        Some(cfg.max_l_nnz),
        par,
    ) {
        Ok(factor) => factor,
        Err(crate::linalg::ldl::LdlError::WouldExceedBudget { .. }) => {
            return Ok(iterative_fallback());
        }
        Err(crate::linalg::ldl::LdlError::DeadlineExceeded) => {
            return Err(KktError::DeadlineExceeded);
        }
        Err(crate::linalg::ldl::LdlError::SingularOrIndefinite) => {
            return Err(KktError::SingularOrIndefinite);
        }
    };

    if !cfg.dd_ldl {
        return Ok(KktFactor::Direct(f64_factor));
    }

    // The f64 factor was only a budget probe: free it before building the
    // second factor so both are not alive at once.
    drop(f64_factor);

    match crate::linalg::ldl_dd::factorize_quasidefinite_with_cached_perm_dd(k, perm, deadline) {
        Ok(dd_factor) => Ok(KktFactor::DirectDd(dd_factor)),
        Err(crate::linalg::ldl::LdlError::WouldExceedBudget { .. }) => Ok(iterative_fallback()),
        Err(crate::linalg::ldl::LdlError::DeadlineExceeded) => Err(KktError::DeadlineExceeded),
        Err(crate::linalg::ldl::LdlError::SingularOrIndefinite) => {
            Err(KktError::SingularOrIndefinite)
        }
    }
}
/// Factorizes an already-permuted KKT matrix, optionally reusing a cached
/// symbolic factorization.
///
/// With `cfg.dd_ldl` set, this delegates to
/// `factorize_kkt_with_cached_perm_par` on `unpermuted_k`, and
/// `pre_permuted_k` and `cached_symbolic` are not used. Otherwise it runs the
/// budgeted f64 LDL on `pre_permuted_k` and, on budget overflow, falls back
/// to MINRES built from a clone of `unpermuted_k` (block-diagonal
/// preconditioner when `n_top` is `Some(n)` with `n <= unpermuted_k.nrows`,
/// Jacobi otherwise).
///
/// # Errors
/// `DeadlineExceeded` or `SingularOrIndefinite`, as reported by the
/// factorization.
pub fn factorize_kkt_pre_permuted_cached_par(
    pre_permuted_k: &CscMatrix,
    unpermuted_k: &CscMatrix,
    perm: &[usize],
    deadline: Option<Instant>,
    cfg: &KktConfig,
    n_top: Option<usize>,
    cached_symbolic: Option<std::sync::Arc<faer::sparse::linalg::cholesky::SymbolicCholesky<usize>>>,
    par: faer::Par,
) -> Result<KktFactor, KktError> {
    if cfg.dd_ldl {
        return factorize_kkt_with_cached_perm_par(unpermuted_k, perm, deadline, cfg, n_top, par);
    }

    let direct = crate::linalg::ldl::factorize_quasidefinite_pre_permuted_cached_par(
        pre_permuted_k,
        perm,
        deadline,
        Some(cfg.max_l_nnz),
        cached_symbolic,
        par,
    );
    let ldl_err = match direct {
        Ok(factor) => return Ok(KktFactor::Direct(factor)),
        Err(err) => err,
    };

    match ldl_err {
        crate::linalg::ldl::LdlError::DeadlineExceeded => Err(KktError::DeadlineExceeded),
        crate::linalg::ldl::LdlError::SingularOrIndefinite => Err(KktError::SingularOrIndefinite),
        crate::linalg::ldl::LdlError::WouldExceedBudget { .. } => {
            let tol = MINRES_INEXACT_NEWTON_ETA;
            let refine_steps = cfg.minres_ir;
            let matrix = unpermuted_k.clone();
            let minres = match n_top {
                Some(n) if n <= unpermuted_k.nrows => {
                    PreconditionedMinres::with_block_diag_inexact(matrix, n, tol, refine_steps)
                }
                _ => PreconditionedMinres::new_inexact(matrix, tol, refine_steps),
            };
            Ok(KktFactor::Iterative(minres))
        }
    }
}
impl KktFactor {
    /// Shared handle to the symbolic factorization of the `Direct` variant;
    /// `None` for every other variant.
    pub fn symbolic_arc(&self) -> Option<std::sync::Arc<faer::sparse::linalg::cholesky::SymbolicCholesky<usize>>> {
        if let KktFactor::Direct(factor) = self {
            Some(factor.symbolic_arc())
        } else {
            None
        }
    }
}
/// Which backend an [`AutoKktSolver`] selected on its last successful
/// `refactor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KktBackend {
/// Direct LDL factorization.
Direct,
/// Preconditioned MINRES.
Iterative,
}
/// KKT solver that prefers the direct LDL backend and switches to MINRES
/// once the direct factor would exceed its budget.
pub struct AutoKktSolver {
    /// Expected dimension of the square matrix.
    n: usize,
    /// Direct backend; set to `None` after its first budget overflow.
    direct: Option<DirectLdl>,
    /// Iterative backend; created lazily on the first fallback.
    iterative: Option<PreconditionedMinres>,
    /// Backend chosen by the last successful `refactor`, if any.
    last_used: Option<KktBackend>,
}

impl AutoKktSolver {
    /// Solver of dimension `n` whose `L` budget is derived from the default
    /// memory budget.
    pub fn new(n: usize) -> Self {
        Self::with_budget(n, DEFAULT_MEMORY_BUDGET_BYTES / BYTES_PER_L_ENTRY)
    }

    /// Solver of dimension `n` with an explicit budget of `max_l_nnz`
    /// nonzeros for `L`.
    pub fn with_budget(n: usize, max_l_nnz: usize) -> Self {
        let direct = DirectLdl::with_budget(n, max_l_nnz);
        Self {
            n,
            direct: Some(direct),
            iterative: None,
            last_used: None,
        }
    }

    /// Backend selected by the last successful `refactor`; `None` before any.
    pub fn last_backend(&self) -> Option<KktBackend> {
        self.last_used
    }
}
impl KktSolver for AutoKktSolver {
    /// Forwards the solve to the backend recorded by the last successful
    /// `refactor`.
    ///
    /// # Errors
    /// `SingularOrIndefinite` when no backend has been selected yet or the
    /// recorded backend is absent; otherwise whatever that backend returns.
    fn solve(
        &self,
        rhs: &[f64],
        sol: &mut [f64],
        deadline: Option<Instant>,
    ) -> Result<(), KktError> {
        let backend = self.last_used.ok_or(KktError::SingularOrIndefinite)?;
        match backend {
            KktBackend::Direct => match &self.direct {
                Some(direct) => direct.solve(rhs, sol, deadline),
                None => Err(KktError::SingularOrIndefinite),
            },
            KktBackend::Iterative => match &self.iterative {
                Some(minres) => minres.solve(rhs, sol, deadline),
                None => Err(KktError::SingularOrIndefinite),
            },
        }
    }

    /// Refactors with the direct backend while it is still available, and
    /// switches permanently to MINRES on its first budget overflow.
    ///
    /// Any other direct-backend error is returned as is, and `last_backend`
    /// keeps its previous value.
    ///
    /// # Errors
    /// `SingularOrIndefinite` for a dimension mismatch; otherwise the error
    /// of the backend that was tried.
    fn refactor(
        &mut self,
        k: &CscMatrix,
        deadline: Option<Instant>,
    ) -> Result<(), KktError> {
        let dims_match = k.nrows == self.n && k.ncols == self.n;
        if !dims_match {
            return Err(KktError::SingularOrIndefinite);
        }

        let direct_outcome = self
            .direct
            .as_mut()
            .map(|direct| direct.refactor(k, deadline));
        match direct_outcome {
            Some(Ok(())) => {
                self.last_used = Some(KktBackend::Direct);
                return Ok(());
            }
            // Over budget: give up on the direct backend for good.
            Some(Err(KktError::WouldExceedMemory)) => self.direct = None,
            Some(Err(other)) => return Err(other),
            None => {}
        }

        if let Some(minres) = self.iterative.as_mut() {
            minres.refactor(k, deadline)?;
        } else {
            self.iterative = Some(PreconditionedMinres::new(k.clone()));
        }
        self.last_used = Some(KktBackend::Iterative);
        Ok(())
    }

    /// Dimension given at construction.
    fn dim(&self) -> usize {
        self.n
    }
}
// Unit tests for the KKT backends. Matrices are built from (row, col, value)
// triplets; the fixtures only list entries with row <= col.
#[cfg(test)]
#[allow(clippy::print_stdout, clippy::print_stderr)]
mod tests {
use super::*;
use crate::sparse::CscMatrix;
// Entries (0,0)=2, (0,1)=1, (1,1)=-1 with rhs [3, 0]; expected solution [1, 1].
#[test]
fn directldl_2x2_solve_matches_hand_calc() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let solver = DirectLdl::from_matrix(&k, None).expect("factorize");
let mut sol = vec![0.0; 2];
solver.solve(&[3.0, 0.0], &mut sol, None).expect("solve");
assert!((sol[0] - 1.0).abs() < 1e-10, "u[0]≈1, got {}", sol[0]);
assert!((sol[1] - 1.0).abs() < 1e-10, "u[1]≈1, got {}", sol[1]);
assert_eq!(solver.dim(), 2);
assert!(solver.l_nnz() > 0, "L should have nonzeros after factorize");
}
// Refactoring with a scaled diagonal matrix must be picked up by later solves.
#[test]
fn directldl_refactor_changes_k() {
let k1 = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, -1.0], 2, 2).unwrap();
let k2 = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[4.0, -2.0], 2, 2).unwrap();
let mut solver = DirectLdl::from_matrix(&k1, None).expect("factorize 1");
let mut sol = vec![0.0; 2];
solver.solve(&[2.0, -1.0], &mut sol, None).expect("solve 1");
assert!((sol[0] - 1.0).abs() < 1e-10);
assert!((sol[1] - 1.0).abs() < 1e-10);
solver.refactor(&k2, None).expect("refactor");
solver.solve(&[4.0, -2.0], &mut sol, None).expect("solve 2");
assert!((sol[0] - 1.0).abs() < 1e-10);
assert!((sol[1] - 1.0).abs() < 1e-10);
}
// A deadline one second in the past must surface as DeadlineExceeded.
#[test]
fn directldl_past_deadline_returns_deadline_exceeded() {
let k = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, -1.0], 2, 2).unwrap();
let past = Instant::now() - std::time::Duration::from_secs(1);
let result = DirectLdl::from_matrix(&k, Some(past));
assert!(
matches!(result, Err(KktError::DeadlineExceeded)),
"past deadline should yield DeadlineExceeded, got {:?}",
result.err()
);
}
// A solver sized for 3 must reject a 2x2 matrix.
#[test]
fn directldl_dim_mismatch_returns_err() {
let k = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, -1.0], 2, 2).unwrap();
let mut solver = DirectLdl::new(3); let result = solver.refactor(&k, None);
assert!(result.is_err(), "dim mismatch should yield Err");
}
// A budget of a single L entry cannot hold the factor of this 5x5 matrix.
#[test]
fn directldl_with_tight_budget_returns_would_exceed_memory() {
let k = CscMatrix::from_triplets(
&[0, 0, 1, 0, 1, 2, 3, 3, 4],
&[0, 1, 1, 2, 2, 2, 3, 4, 4],
&[1.0, 0.1, 1.0, 0.1, 0.1, 1.0, -1.0, 0.1, -1.0],
5, 5,
).unwrap();
let mut solver = DirectLdl::with_budget(5, 1);
let result = solver.refactor(&k, None);
assert!(
matches!(result, Err(KktError::WouldExceedMemory)),
"tight budget (1 entry) should trigger WouldExceedMemory, got {:?}",
result.err()
);
}
// `DirectLdl::new` sets no budget, so factorization is not budget-limited.
#[test]
fn directldl_without_budget_is_unconstrained() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let mut solver = DirectLdl::new(2); solver.refactor(&k, None).expect("no-budget refactor should always succeed for valid K");
}
// The budget accessor returns the compile-time default.
#[test]
fn memory_budget_returns_static_default() {
let budget = memory_budget_bytes();
assert_eq!(budget, DEFAULT_MEMORY_BUDGET_BYTES, "must equal 4 GiB default");
}
// Entry budget = byte budget / bytes per L entry.
#[test]
fn max_l_nnz_from_budget_conversion() {
let l = max_l_nnz_from_budget();
assert_eq!(l, DEFAULT_MEMORY_BUDGET_BYTES / BYTES_PER_L_ENTRY);
}
// IpmOptions' byte budget maps to an L-entry budget through BYTES_PER_L_ENTRY.
#[test]
fn ipm_opts_kkt_budget_controls_max_l_nnz() {
use crate::options::IpmOptions;
let opts_default = IpmOptions::default();
assert_eq!(opts_default.effective_max_l_nnz(), DEFAULT_MEMORY_BUDGET_BYTES / BYTES_PER_L_ENTRY);
let opts_small = IpmOptions { kkt_memory_budget_bytes: Some(1600), ..Default::default() };
assert_eq!(opts_small.effective_max_l_nnz(), 100, "1600 / 16 = 100 entries");
}
// DirectLdl must be usable behind `Box<dyn KktSolver>`.
#[test]
fn kkt_solver_works_as_trait_object() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let solver: Box<dyn KktSolver> = Box::new(
DirectLdl::from_matrix(&k, None).expect("factorize"),
);
let mut sol = vec![0.0; 2];
solver.solve(&[3.0, 0.0], &mut sol, None).expect("solve via trait");
assert!((sol[0] - 1.0).abs() < 1e-10);
assert!((sol[1] - 1.0).abs() < 1e-10);
}
// Same 2x2 system as the direct test, solved with MINRES to a looser tolerance.
#[test]
fn minres_kkt_2x2_indefinite() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let solver = PreconditionedMinres::new(k);
let mut sol = vec![0.0; 2];
solver.solve(&[3.0, 0.0], &mut sol, None).expect("MINRES solve");
assert!((sol[0] - 1.0).abs() < 1e-7);
assert!((sol[1] - 1.0).abs() < 1e-7);
assert_eq!(solver.dim(), 2);
}
// MINRES and direct LDL must agree on a 5x5 system (3 positive and 2 negative
// diagonal entries).
#[test]
fn minres_kkt_5x5_matches_direct_ldl() {
let entries = [
(0, 0, 4.0), (0, 1, 0.5), (1, 1, 4.0), (1, 2, 0.5), (2, 2, 4.0),
(0, 3, 0.3),
(3, 3, -2.0), (3, 4, 0.4), (4, 4, -2.0),
];
let rows: Vec<usize> = entries.iter().map(|(r, _, _)| *r).collect();
let cols: Vec<usize> = entries.iter().map(|(_, c, _)| *c).collect();
let vals: Vec<f64> = entries.iter().map(|(_, _, v)| *v).collect();
let k = CscMatrix::from_triplets(&rows, &cols, &vals, 5, 5).unwrap();
let b = vec![1.0, 2.0, -1.0, 0.5, -0.5];
let mut x_ldl = vec![0.0; 5];
let ldl_solver = DirectLdl::from_matrix(&k, None).unwrap();
ldl_solver.solve(&b, &mut x_ldl, None).unwrap();
let mut x_minres = vec![0.0; 5];
let minres_solver = PreconditionedMinres::new(k);
minres_solver.solve(&b, &mut x_minres, None).expect("MINRES solve");
for i in 0..5 {
assert!(
(x_ldl[i] - x_minres[i]).abs() < 1e-6,
"x[{}]: LDL={}, MINRES={}", i, x_ldl[i], x_minres[i]
);
}
}
// MINRES refactor must swap in the new matrix and its preconditioner.
#[test]
fn minres_kkt_refactor() {
let k1 = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, -1.0], 2, 2).unwrap();
let k2 = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[4.0, -2.0], 2, 2).unwrap();
let mut solver = PreconditionedMinres::new(k1);
let mut sol = vec![0.0; 2];
solver.solve(&[2.0, -1.0], &mut sol, None).unwrap();
assert!((sol[0] - 1.0).abs() < 1e-7);
assert!((sol[1] - 1.0).abs() < 1e-7);
solver.refactor(&k2, None).expect("refactor");
solver.solve(&[4.0, -2.0], &mut sol, None).unwrap();
assert!((sol[0] - 1.0).abs() < 1e-7);
assert!((sol[1] - 1.0).abs() < 1e-7);
}
// An already-expired deadline must be reported as DeadlineExceeded by MINRES.
#[test]
fn minres_kkt_past_deadline() {
let k = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, -1.0], 2, 2).unwrap();
let solver = PreconditionedMinres::new(k);
let mut sol = vec![0.0; 2];
let past = Instant::now() - std::time::Duration::from_secs(1);
let result = solver.solve(&[1.0, 1.0], &mut sol, Some(past));
assert!(
matches!(result, Err(KktError::DeadlineExceeded)),
"past deadline should yield DeadlineExceeded, got {:?}", result.err()
);
}
// PreconditionedMinres must be usable behind `Box<dyn KktSolver>`.
#[test]
fn minres_kkt_works_as_trait_object() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let solver: Box<dyn KktSolver> = Box::new(PreconditionedMinres::new(k));
let mut sol = vec![0.0; 2];
solver.solve(&[3.0, 0.0], &mut sol, None).expect("solve via trait");
assert!((sol[0] - 1.0).abs() < 1e-7);
assert!((sol[1] - 1.0).abs() < 1e-7);
}
// With a generous budget the auto solver stays on the direct backend.
#[test]
fn auto_uses_direct_when_budget_sufficient() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let mut solver = AutoKktSolver::with_budget(2, 10000);
solver.refactor(&k, None).expect("refactor");
assert_eq!(solver.last_backend(), Some(KktBackend::Direct));
let mut sol = vec![0.0; 2];
solver.solve(&[3.0, 0.0], &mut sol, None).expect("solve");
assert!((sol[0] - 1.0).abs() < 1e-9);
assert!((sol[1] - 1.0).abs() < 1e-9);
}
// With a budget of one entry the auto solver must fall back to MINRES and
// still match an unbudgeted direct solve.
#[test]
fn auto_falls_back_to_iterative_when_budget_exceeded() {
let k = CscMatrix::from_triplets(
&[0, 0, 1, 1, 2, 2, 3, 3, 4],
&[0, 1, 1, 2, 2, 3, 3, 4, 4],
&[4.0, 0.5, 4.0, 0.5, 4.0, 0.3, -2.0, 0.4, -2.0],
5, 5,
).unwrap();
let mut solver = AutoKktSolver::with_budget(5, 1);
solver.refactor(&k, None).expect("refactor (iterative)");
assert_eq!(solver.last_backend(), Some(KktBackend::Iterative));
let b = vec![1.0, 2.0, -1.0, 0.5, -0.5];
let mut sol = vec![0.0; 5];
solver.solve(&b, &mut sol, None).expect("solve");
let factor = crate::linalg::ldl::factorize_quasidefinite_with_amd(&k, None).unwrap();
let mut sol_ldl = vec![0.0; 5];
factor.solve(&b, &mut sol_ldl);
for i in 0..5 {
assert!(
(sol[i] - sol_ldl[i]).abs() < 1e-6,
"auto[{}]={} vs ldl[{}]={}", i, sol[i], i, sol_ldl[i]
);
}
}
// After the first overflow the direct backend is discarded, so later
// refactors stay on MINRES.
#[test]
fn auto_remembers_iterative_after_first_overflow() {
let k1 = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, -1.0], 2, 2).unwrap();
let k2 = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[4.0, -2.0], 2, 2).unwrap();
let mut solver = AutoKktSolver::with_budget(2, 1); solver.refactor(&k1, None).unwrap();
assert_eq!(solver.last_backend(), Some(KktBackend::Iterative));
solver.refactor(&k2, None).unwrap();
assert_eq!(solver.last_backend(), Some(KktBackend::Iterative),
"should stay iterative after first overflow");
}
// factorize_kkt_* returns the Direct variant when the budget is large enough.
#[test]
fn factorize_kkt_chooses_direct_when_budget_sufficient() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let perm = crate::linalg::amd::amd_with_deadline(2, &k.col_ptr, &k.row_ind, None);
let cfg = KktConfig { max_l_nnz: 1000, ..Default::default() };
let factor = factorize_kkt_with_cached_perm_par(&k, &perm, None, &cfg, None, faer::Par::Seq)
.expect("factor should succeed");
assert!(matches!(factor, KktFactor::Direct(_)));
assert!(!factor.is_iterative());
let mut sol = vec![0.0; 2];
factor.solve(&[3.0, 0.0], &mut sol);
assert!((sol[0] - 1.0).abs() < 1e-9);
assert!((sol[1] - 1.0).abs() < 1e-9);
}
// With a budget of one entry factorize_kkt_* returns the Iterative variant;
// its solution is compared against an unbudgeted direct solve.
#[test]
fn factorize_kkt_chooses_iterative_when_budget_exceeded() {
let k = CscMatrix::from_triplets(
&[0, 0, 1, 1, 2], &[0, 1, 1, 2, 2],
&[4.0, 0.5, 4.0, 0.5, -2.0], 3, 3
).unwrap();
let perm = crate::linalg::amd::amd_with_deadline(3, &k.col_ptr, &k.row_ind, None);
let cfg = KktConfig { max_l_nnz: 1, ..Default::default() };
let factor = factorize_kkt_with_cached_perm_par(&k, &perm, None, &cfg, None, faer::Par::Seq)
.expect("factor should succeed (fallback)");
assert!(matches!(factor, KktFactor::Iterative(_)));
assert!(factor.is_iterative());
let b = vec![1.0, 2.0, -1.0];
let mut sol = vec![0.0; 3];
factor.solve(&b, &mut sol);
let factor_ldl = crate::linalg::ldl::factorize_quasidefinite_with_amd(&k, None).unwrap();
let mut sol_ldl = vec![0.0; 3];
factor_ldl.solve(&b, &mut sol_ldl);
let b_inf = b.iter().map(|v: &f64| v.abs()).fold(0.0_f64, f64::max);
let tol = MINRES_INEXACT_NEWTON_ETA * b_inf;
for i in 0..3 {
assert!(
(sol[i] - sol_ldl[i]).abs() < tol.max(1e-6),
"MINRES[{}]={} vs LDL[{}]={} (tol={})", i, sol[i], i, sol_ldl[i], tol
);
}
}
// Builds a 30x30 system (20 + 10) with a near-identity top block, pseudo-random
// coupling entries and a tiny negative bottom diagonal, then checks that two
// refinement passes at a loose tolerance (0.1) shrink the relative residual
// by more than 5x compared with no refinement.
#[test]
fn minres_ir_actually_reduces_residual() {
let n = 20usize;
let m = 10usize;
let dim = n + m;
let mut rows = Vec::new();
let mut cols = Vec::new();
let mut vals = Vec::new();
for i in 0..n {
rows.push(i); cols.push(i);
vals.push(1.0 + 1e-3 * (i as f64).sqrt());
}
for k in 0..m {
for j in 0..n {
let v = ((k * 7 + j * 13) % 17) as f64 / 17.0 - 0.5;
if v.abs() > 0.1 {
rows.push(j); cols.push(n + k); vals.push(v);
}
}
}
for i in 0..m {
rows.push(n + i); cols.push(n + i); vals.push(-1e-6 * (1.0 + i as f64));
}
let k = CscMatrix::from_triplets(&rows, &cols, &vals, dim, dim).unwrap();
let rhs: Vec<f64> = (0..dim).map(|i| ((i * 11) % 7) as f64 - 3.0).collect();
let rhs_norm = rhs.iter().fold(0.0_f64, |a, &v| a + v * v).sqrt();
let solver_no_ir = PreconditionedMinres::with_block_diag_inexact(k.clone(), n, 0.1, 0);
let mut sol_no_ir = vec![0.0_f64; dim];
let _ = solver_no_ir.solve(&rhs, &mut sol_no_ir, None);
let mut residual_no_ir = vec![0.0_f64; dim];
crate::linalg::minres::matvec_sym_upper(&k, &sol_no_ir, &mut residual_no_ir);
let r_no_ir: f64 = (0..dim)
.map(|i| (rhs[i] - residual_no_ir[i]).powi(2))
.sum::<f64>()
.sqrt();
let rel_no_ir = r_no_ir / rhs_norm;
let solver_ir2 = PreconditionedMinres::with_block_diag_inexact(k.clone(), n, 0.1, 2);
let mut sol_ir2 = vec![0.0_f64; dim];
let _ = solver_ir2.solve(&rhs, &mut sol_ir2, None);
let mut residual_ir2 = vec![0.0_f64; dim];
crate::linalg::minres::matvec_sym_upper(&k, &sol_ir2, &mut residual_ir2);
let r_ir2: f64 = (0..dim)
.map(|i| (rhs[i] - residual_ir2[i]).powi(2))
.sum::<f64>()
.sqrt();
let rel_ir2 = r_ir2 / rhs_norm;
eprintln!(
"MINRES IR check: n={} dim={} ||rhs||={:.3e} no_ir_rel={:.3e} ir2_rel={:.3e} ratio={:.3e}",
n, dim, rhs_norm, rel_no_ir, rel_ir2, rel_ir2 / rel_no_ir.max(1e-300)
);
assert!(
rel_ir2 < rel_no_ir / 5.0,
"MINRES IR is not reducing residual: no_ir={:.3e} ir2={:.3e} (ratio {:.2e}, expected < 0.2)",
rel_no_ir, rel_ir2, rel_ir2 / rel_no_ir.max(1e-300)
);
}
// AutoKktSolver must be usable behind `Box<dyn KktSolver>`.
#[test]
fn auto_works_as_trait_object() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let mut solver: Box<dyn KktSolver> = Box::new(AutoKktSolver::new(2));
solver.refactor(&k, None).unwrap();
let mut sol = vec![0.0; 2];
solver.solve(&[3.0, 0.0], &mut sol, None).unwrap();
assert!((sol[0] - 1.0).abs() < 1e-9);
assert!((sol[1] - 1.0).abs() < 1e-9);
}
// Guards against drift between KktConfig::default and IpmOptions::default.
#[test]
fn kkt_config_default_matches_ipm_options_default() {
use crate::options::IpmOptions;
let o = IpmOptions::default();
let cfg = KktConfig {
dd_ldl: o.dd_ldl,
minres_ir: o.effective_minres_ir(),
max_l_nnz: o.effective_max_l_nnz(),
};
let dflt = KktConfig::default();
assert_eq!(cfg.dd_ldl, dflt.dd_ldl);
assert_eq!(cfg.minres_ir, dflt.minres_ir);
assert_eq!(cfg.max_l_nnz, dflt.max_l_nnz);
}
// dd_ldl = true with an ample budget must yield the DirectDd variant.
#[test]
fn factorize_kkt_dd_ldl_true_returns_direct_dd() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let perm = crate::linalg::amd::amd_with_deadline(2, &k.col_ptr, &k.row_ind, None);
let cfg = KktConfig { dd_ldl: true, max_l_nnz: 1_000_000, ..Default::default() };
let factor = factorize_kkt_with_cached_perm_par(&k, &perm, None, &cfg, None, faer::Par::Seq)
.expect("DD-LDL factor should succeed");
assert!(matches!(factor, KktFactor::DirectDd(_)), "expected DirectDd with dd_ldl=true");
let mut sol = vec![0.0; 2];
factor.solve(&[3.0, 0.0], &mut sol);
assert!((sol[0] - 1.0).abs() < 1e-9, "sol[0]={}", sol[0]);
assert!((sol[1] - 1.0).abs() < 1e-9, "sol[1]={}", sol[1]);
}
// The default config (dd_ldl = false) must yield the plain Direct variant.
#[test]
fn factorize_kkt_dd_ldl_false_returns_direct() {
let k = CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, -1.0], 2, 2).unwrap();
let perm = crate::linalg::amd::amd_with_deadline(2, &k.col_ptr, &k.row_ind, None);
let cfg = KktConfig::default();
assert!(!cfg.dd_ldl);
let factor = factorize_kkt_with_cached_perm_par(&k, &perm, None, &cfg, None, faer::Par::Seq)
.expect("f64 LDL factor should succeed");
assert!(matches!(factor, KktFactor::Direct(_)), "expected Direct with dd_ldl=false");
}
}