#![allow(non_snake_case)]
use super::ldlsolvers::config::LDLConfiguration;
use super::*;
use crate::solver::core::kktsolvers::{HasLinearSolverInfo, KKTSolver, LinearSolverInfo};
use crate::solver::core::{cones::*, CoreSettings};
use std::iter::zip;
/// Heap-allocated, dynamically dispatched `DirectLDLSolver<T>` backend.
/// The `Send + Sync` bounds are part of the trait-object type, so any
/// backend stored behind this alias must satisfy both.
pub(crate) type BoxedDirectLDLSolver<T> = Box<dyn DirectLDLSolver<T> + Send + Sync>;
/// KKT solver that keeps its own copy of the KKT matrix and delegates
/// factorization and linear solves to a boxed `DirectLDLSolver` backend.
pub struct DirectLDLKKTSolver<T> {
// Problem dimensions as passed to `new`. `setrhs` writes `rhsx` into the
// first `n` entries of `b` and `rhsz` into the following `m` entries.
m: usize,
n: usize,
// Taken from `map.sparse_maps.pdim()` in `new`; the trailing `p` entries of
// `b` are zeroed by `setrhs`.
p: usize,
// Solution buffer of length `n + m + p`, handed to the LDL backend in
// `solve` and read back by `getlhs`.
x: Vec<T>,
// Right-hand-side buffer of length `n + m + p`, filled by `setrhs`.
b: Vec<T>,
// Scratch buffers of length `n + m + p`, shared between
// `regularize_and_refactor` and `iterative_refinement`.
work1: Vec<T>,
work2: Vec<T>,
// Index map returned alongside the matrix by `assemble_kkt_matrix`.
map: LDLDataMap,
// One sign per KKT row, populated by `_fill_signs`; passed to the LDL
// backend constructor and used to pick the direction of the diagonal shift
// in `regularize_and_refactor`.
dsigns: Vec<i8>,
// Buffer filled by `cones.get_Hs` in `update`, negated in place, and then
// written into the KKT entries listed in `map.Hsblocks`.
Hsblocks: Vec<T>,
// Local copy of the assembled KKT matrix.
KKT: CscMatrix<T>,
// Matrix shape reported by `T::get_ldlsolver_config`; passed to `KKT.sym`
// in `iterative_refinement`.
KKTuplo: MatrixTriangle,
// LDL backend built by the constructor returned from
// `T::get_ldlsolver_config`.
ldlsolver: BoxedDirectLDLSolver<T>,
// Most recent shift computed in `regularize_and_refactor`; starts at zero.
// Only ever written within this file's visible code.
diagonal_regularizer: T,
}
impl<T> DirectLDLKKTSolver<T>
where
    T: FloatT,
{
    /// Builds a KKT solver for the problem data `P`, `A` and `cones`.
    ///
    /// The KKT matrix shape and the LDL backend constructor are both obtained
    /// from `T::get_ldlsolver_config(settings)`. Every work vector is sized
    /// `n + m + p`, where `p` is reported by the sparse maps produced while
    /// assembling the KKT matrix. The sign vector is filled before the LDL
    /// backend is constructed because the backend constructor receives it.
    pub fn new(
        P: &CscMatrix<T>,
        A: &CscMatrix<T>,
        cones: &CompositeCone<T>,
        m: usize,
        n: usize,
        settings: &CoreSettings<T>,
    ) -> Self {
        let (kktshape, ldl_ctor) = T::get_ldlsolver_config(settings);
        let (KKT, map) = assemble_kkt_matrix(P, A, cones, kktshape);

        // Common length of every buffer owned by the solver.
        let p = map.sparse_maps.pdim();
        let dim = n + m + p;
        let zeros = || vec![T::zero(); dim];

        let mut dsigns = vec![1_i8; dim];
        _fill_signs(&mut dsigns, m, n, &map);

        let Hsblocks = allocate_kkt_Hsblocks::<T, T>(cones);
        let ldlsolver = ldl_ctor(&KKT, &dsigns, settings, None);

        Self {
            m,
            n,
            p,
            x: zeros(),
            b: zeros(),
            work1: zeros(),
            work2: zeros(),
            map,
            dsigns,
            Hsblocks,
            KKT,
            KKTuplo: kktshape,
            ldlsolver,
            diagonal_regularizer: T::zero(),
        }
    }
}
// Linear-solver metadata is not tracked here; it is reported by the
// wrapped LDL backend.
impl<T> HasLinearSolverInfo for DirectLDLKKTSolver<T>
where
T: FloatT,
{
// Forwards the query unchanged to the boxed LDL solver.
fn linear_solver_info(&self) -> LinearSolverInfo {
self.ldlsolver.linear_solver_info()
}
}
impl<T> KKTSolver<T> for DirectLDLKKTSolver<T>
where
    T: FloatT,
{
    /// Refreshes the cone-dependent KKT entries and refactors.
    ///
    /// The Hs blocks are fetched from `cones`, negated in place and written
    /// to the positions listed in `map.Hsblocks`. Each sparse-expandable cone
    /// then updates its own entries, consuming one sparse map per such cone
    /// in iteration order. Returns the result of `regularize_and_refactor`.
    fn update(&mut self, cones: &CompositeCone<T>, settings: &CoreSettings<T>) -> bool {
        let data_map = &self.map;

        cones.get_Hs(&mut self.Hsblocks);
        let hs_values = &mut self.Hsblocks;
        hs_values.negate();
        _update_values(
            &mut self.ldlsolver,
            &mut self.KKT,
            &data_map.Hsblocks,
            hs_values,
        );

        let mut pending_maps = data_map.sparse_maps.iter();
        let solver = &mut self.ldlsolver;
        let kkt = &mut self.KKT;
        for cone in cones.iter() {
            if !cone.is_sparse_expandable() {
                continue;
            }
            let expansion = cone.to_sparse_expansion().unwrap();
            let cone_map = pending_maps.next().unwrap();
            expansion.csc_update_sparsecone(cone_map, solver, kkt, _update_values, _scale_values);
        }

        self.regularize_and_refactor(settings)
    }

    /// Loads the right-hand side: `rhsx` into the first `n` entries of `b`,
    /// `rhsz` into the next `m`, and zeros into the trailing `p` entries.
    fn setrhs(&mut self, rhsx: &[T], rhsz: &[T]) {
        let x_end = self.n;
        let z_end = x_end + self.m;
        let p_end = z_end + self.p;

        self.b[..x_end].copy_from(rhsx);
        self.b[x_end..z_end].copy_from(rhsz);
        self.b[z_end..p_end].fill(T::zero());
    }

    /// Solves with the current factorization and right-hand side.
    ///
    /// When iterative refinement is enabled its result decides success;
    /// otherwise success means the raw solution is finite. The outputs
    /// `lhsx` / `lhsz` are only written on success.
    fn solve(
        &mut self,
        lhsx: Option<&mut [T]>,
        lhsz: Option<&mut [T]>,
        settings: &CoreSettings<T>,
    ) -> bool {
        self.ldlsolver.solve(&self.KKT, &mut self.x, &mut self.b);

        let solved_ok = match settings.iterative_refinement_enable {
            true => self.iterative_refinement(settings),
            false => self.x.is_finite(),
        };
        if !solved_ok {
            return false;
        }

        self.getlhs(lhsx, lhsz);
        true
    }

    /// Overwrites the KKT entries mapped by `map.P` with the nonzeros of `P`.
    fn update_P(&mut self, P: &CscMatrix<T>) {
        let target_entries = &self.map.P;
        _update_values(&mut self.ldlsolver, &mut self.KKT, target_entries, &P.nzval);
    }

    /// Overwrites the KKT entries mapped by `map.A` with the nonzeros of `A`.
    fn update_A(&mut self, A: &CscMatrix<T>) {
        let target_entries = &self.map.A;
        _update_values(&mut self.ldlsolver, &mut self.KKT, target_entries, &A.nzval);
    }
}
impl<T> DirectLDLKKTSolver<T>
where
T: FloatT,
{
fn getlhs(&self, lhsx: Option<&mut [T]>, lhsz: Option<&mut [T]>) {
let x = &self.x;
let (m, n) = (self.m, self.n);
if let Some(v) = lhsx {
v.copy_from(&x[0..n]);
}
if let Some(v) = lhsz {
v.copy_from(&x[n..(n + m)]);
}
}
fn regularize_and_refactor(&mut self, settings: &CoreSettings<T>) -> bool {
let map = &self.map;
let KKT = &mut self.KKT;
let dsigns = &self.dsigns;
let diag_kkt = &mut self.work1;
let diag_shifted = &mut self.work2;
if settings.static_regularization_enable {
for (d, idx) in zip(&mut *diag_kkt, &map.diag_full) {
*d = KKT.nzval[*idx];
}
let eps = _compute_regularizer(diag_kkt, settings);
diag_shifted.copy_from(diag_kkt);
zip(&mut *diag_shifted, dsigns).for_each(|(shift, &sign)| {
if sign == 1 {
*shift += eps;
} else {
*shift -= eps;
}
});
_update_values(&mut self.ldlsolver, KKT, &map.diag_full, diag_shifted);
self.diagonal_regularizer = eps;
}
let is_success = self.ldlsolver.refactor(KKT);
if settings.static_regularization_enable {
_update_values_KKT(KKT, &map.diag_full, diag_kkt);
}
is_success
}
fn iterative_refinement(&mut self, settings: &CoreSettings<T>) -> bool {
let (x, b) = (&mut self.x, &self.b);
let (e, dx) = (&mut self.work1, &mut self.work2);
let reltol = settings.iterative_refinement_reltol;
let abstol = settings.iterative_refinement_abstol;
let maxiter = settings.iterative_refinement_max_iter;
let stopratio = settings.iterative_refinement_stop_ratio;
let KKT = &self.KKT;
let KKTsym = KKT.sym(self.KKTuplo);
let normb = b.norm_inf();
let mut norme = _get_refine_error(e, b, &KKTsym, x);
if !norme.is_finite() {
return false;
}
for _ in 0..maxiter {
if norme <= (abstol + reltol * normb) {
break;
}
let lastnorme = norme;
self.ldlsolver.solve(KKT, dx, e);
dx.axpby(T::one(), x, T::one());
norme = _get_refine_error(e, b, &KKTsym, dx);
if !norme.is_finite() {
return false;
}
let improved_ratio = lastnorme / norme;
if improved_ratio < stopratio {
if improved_ratio > T::one() {
std::mem::swap(x, dx);
}
break;
}
std::mem::swap(x, dx);
}
true
}
}
/// Returns the static regularization shift: the constant term from
/// `settings` plus the proportional term scaled by the infinity norm of
/// `diag_kkt`.
fn _compute_regularizer<T: FloatT>(diag_kkt: &[T], settings: &CoreSettings<T>) -> T {
    let proportional_part = settings.static_regularization_proportional * diag_kkt.norm_inf();
    settings.static_regularization_constant + proportional_part
}
/// Overwrites `e` with a copy of `b`, applies `KKTsym.symv` to it with the
/// scalars `-1` and `+1`, and returns the infinity norm of the result.
///
/// NOTE(review): presumably this leaves `e = b - K * ξ`, assuming `symv`
/// follows the usual `y = alpha * A * x + beta * y` convention — confirm
/// against the `symv` definition.
fn _get_refine_error<T: FloatT>(
    e: &mut [T],
    b: &[T],
    KKTsym: &Symmetric<CscMatrix<T>>,
    ξ: &mut [T],
) -> T {
    let (alpha, beta) = (-T::one(), T::one());
    e.copy_from(b);
    KKTsym.symv(e, ξ, alpha, beta);
    e.norm_inf()
}
// Writes `values` into `KKT.nzval` at the positions given by `index`, then
// forwards the same `(index, values)` pair to the LDL solver's
// `update_values`.
fn _update_values<T: FloatT>(
ldlsolver: &mut BoxedDirectLDLSolver<T>,
KKT: &mut CscMatrix<T>,
index: &[usize],
values: &[T],
) {
// Local matrix first, then the solver.
_update_values_KKT(KKT, index, values);
ldlsolver.update_values(index, values);
}
fn _update_values_KKT<T: FloatT>(KKT: &mut CscMatrix<T>, index: &[usize], values: &[T]) {
for (idx, v) in zip(index, values) {
KKT.nzval[*idx] = *v;
}
}
// Multiplies the entries of `KKT.nzval` listed in `index` by `scale`, then
// forwards the same `(index, scale)` pair to the LDL solver's
// `scale_values`.
fn _scale_values<T: FloatT>(
ldlsolver: &mut BoxedDirectLDLSolver<T>,
KKT: &mut CscMatrix<T>,
index: &[usize],
scale: T,
) {
// Local matrix first, then the solver.
_scale_values_KKT(KKT, index, scale);
ldlsolver.scale_values(index, scale);
}
/// Multiplies `KKT.nzval[i]` by `scale` for every position `i` in `index`.
fn _scale_values_KKT<T: FloatT>(KKT: &mut CscMatrix<T>, index: &[usize], scale: T) {
    for &pos in index {
        KKT.nzval[pos] *= scale;
    }
}
/// Populates `signs`: `+1` everywhere, `-1` on the range `n..n + m`, and,
/// from position `n + m` onwards, the sign pattern reported by each sparse
/// map in turn (each map contributes `pdim()` consecutive entries).
fn _fill_signs(signs: &mut [i8], m: usize, n: usize, map: &LDLDataMap) {
    signs.fill(1);
    signs[n..(n + m)].fill(-1);

    let mut offset = n + m;
    for sparse_map in map.sparse_maps.iter() {
        let end = offset + sparse_map.pdim();
        signs[offset..end].copy_from_slice(sparse_map.Dsigns());
        offset = end;
    }
}