#![allow(non_snake_case)]
use super::ldlsolvers::qdldl::*;
use super::*;
use crate::solver::core::kktsolvers::KKTSolver;
use crate::solver::core::{cones::*, CoreSettings};
/// Owned, type-erased LDL factorization backend. The `Send` bound allows the
/// boxed solver to be moved to another thread.
type BoxedDirectLDLSolver<T> = Box<dyn DirectLDLSolver<T> + Send>;
/// KKT system solver that assembles a KKT matrix from `P`, `A` and the cone
/// set, and factors/solves it through a pluggable direct LDL backend.
pub struct DirectLDLKKTSolver<T> {
// length of the `z` part of the solution (entries `n..n+m` of `x`)
m: usize,
// length of the `x` part of the solution (entries `0..n` of `x`)
n: usize,
// number of extra rows/columns appended to the system: two per second-order cone
p: usize,
// full solution vector, length `n + m + p`
x: Vec<T>,
// full right-hand side, length `n + m + p`; the trailing `p` entries are zeroed by `setrhs`
b: Vec<T>,
// scratch buffer (length `n + m + p`): the saved KKT diagonal during
// regularization, and the residual `e` during iterative refinement
work1: Vec<T>,
// scratch buffer (length `n + m + p`): the shifted diagonal during
// regularization, and the refinement candidate `x + dx`
work2: Vec<T>,
// positions of the logical KKT blocks (Hs, SOC u/v/D, diagonal) within `KKT.nzval`
map: LDLDataMap,
// expected pivot sign per row: +1 for the x block, -1 for the z block,
// alternating -1/+1 for the extra second-order cone rows (see `_fill_signs`)
dsigns: Vec<i8>,
// buffer for the cones' Hs blocks; stored negated once `update` has run
Hsblocks: Vec<T>,
// our own copy of the KKT matrix; outside `regularize_and_refactor` its
// diagonal holds the unregularized values
KKT: CscMatrix<T>,
// the factorization backend selected by `settings.direct_solve_method`
ldlsolver: BoxedDirectLDLSolver<T>,
// the static regularization `eps` most recently applied (zero until set)
diagonal_regularizer: T,
}
impl<T> DirectLDLKKTSolver<T>
where
    T: FloatT,
{
    /// Creates a solver for the KKT system built from `P`, `A` and `cones`.
    ///
    /// The system has dimension `n + m + p`. Here `p` reserves two extra
    /// rows/columns for every second-order cone in `cones`. The LDL backend,
    /// and the matrix triangle it requires, are selected from
    /// `settings.direct_solve_method`.
    ///
    /// # Panics
    /// Panics if `settings.direct_solve_method` is not a recognized solver
    /// name. The `"custom"` option is currently unimplemented.
    pub fn new(
        P: &CscMatrix<T>,
        A: &CscMatrix<T>,
        cones: &CompositeCone<T>,
        m: usize,
        n: usize,
        settings: &CoreSettings<T>,
    ) -> Self {
        // two expansion rows/columns per second-order cone
        let p = 2 * cones.type_count(SupportedConeTag::SecondOrderCone);
        let dim = n + m + p;

        // pivot signs: +1 on the x block, -1 on the z block,
        // alternating -1/+1 on the second-order cone expansion rows
        let mut dsigns = vec![1_i8; dim];
        _fill_signs(&mut dsigns, m, n, p);

        let Hsblocks = allocate_kkt_Hsblocks::<T, T>(cones);

        // the backend dictates which triangle of KKT must be assembled
        let (kktshape, ldl_ctor) = _get_ldlsolver_config(settings);
        let (KKT, map) = assemble_kkt_matrix(P, A, cones, kktshape);
        let ldlsolver = ldl_ctor(&KKT, &dsigns, settings);

        Self {
            m,
            n,
            p,
            x: vec![T::zero(); dim],
            b: vec![T::zero(); dim],
            work1: vec![T::zero(); dim],
            work2: vec![T::zero(); dim],
            map,
            dsigns,
            Hsblocks,
            KKT,
            ldlsolver,
            diagonal_regularizer: T::zero(),
        }
    }
}
impl<T> KKTSolver<T> for DirectLDLKKTSolver<T>
where
    T: FloatT,
{
    /// Refreshes the cone-dependent entries of the KKT matrix, then refactors.
    ///
    /// The `Hs` blocks reported by `cones` are negated and written into the
    /// KKT matrix. For each second-order cone (indexed in iteration order),
    /// its `u` and `v` vectors are written and then scaled by `-η²`. The
    /// cone's two extra diagonal entries are set to `-η²` and `+η²`. Every
    /// write goes to both `self.KKT` and the LDL backend. The return value is
    /// that of `regularize_and_refactor`.
    fn update(&mut self, cones: &CompositeCone<T>, settings: &CoreSettings<T>) -> bool {
        let map = &self.map;

        // write -Hs into the KKT matrix
        cones.get_Hs(&mut self.Hsblocks);
        self.Hsblocks.negate();
        _update_values(&mut self.ldlsolver, &mut self.KKT, &map.Hsblocks, &self.Hsblocks);

        // expansion columns for second-order cones; enumeration over the
        // filtered iterator yields the per-SOC index into the map tables
        let soc_cones = cones.iter().filter_map(|cone| match cone {
            SupportedCone::SecondOrderCone(soc) => Some(soc),
            _ => None,
        });

        for (cidx, soc) in soc_cones.enumerate() {
            let η2 = T::powi(soc.η, 2);
            let ldl = &mut self.ldlsolver;
            let kkt = &mut self.KKT;

            _update_values(ldl, kkt, &map.SOC_u[cidx], &soc.u);
            _update_values(ldl, kkt, &map.SOC_v[cidx], &soc.v);
            _scale_values(ldl, kkt, &map.SOC_u[cidx], -η2);
            _scale_values(ldl, kkt, &map.SOC_v[cidx], -η2);

            // the two extra diagonal entries carry -η² and +η²
            _update_values(ldl, kkt, &[map.SOC_D[2 * cidx]], &[-η2]);
            _update_values(ldl, kkt, &[map.SOC_D[2 * cidx + 1]], &[η2]);
        }

        self.regularize_and_refactor(settings)
    }

    /// Loads `rhsx` and `rhsz` into the right-hand side `b` and zeroes the
    /// trailing `p` expansion entries.
    fn setrhs(&mut self, rhsx: &[T], rhsz: &[T]) {
        let n = self.n;
        let nm = n + self.m;
        let end = nm + self.p;

        self.b[0..n].copy_from(rhsx);
        self.b[n..nm].copy_from(rhsz);
        self.b[nm..end].fill(T::zero());
    }

    /// Solves the factored system for the current right-hand side.
    ///
    /// Iterative refinement runs when `settings.iterative_refinement_enable`
    /// is set; otherwise the raw solution is only checked for finiteness. On
    /// success, the x and z parts of the solution are copied into `lhsx` and
    /// `lhsz` (whichever are provided). Returns whether the solve succeeded.
    fn solve(
        &mut self,
        lhsx: Option<&mut [T]>,
        lhsz: Option<&mut [T]>,
        settings: &CoreSettings<T>,
    ) -> bool {
        self.ldlsolver.solve(&mut self.x, &self.b);

        let ok = match settings.iterative_refinement_enable {
            true => self.iterative_refinement(settings),
            false => self.x.is_finite(),
        };

        if ok {
            self.getlhs(lhsx, lhsz);
        }
        ok
    }
}
impl<T> DirectLDLKKTSolver<T>
where
    T: FloatT,
{
    /// Copies the x part (`0..n`) and z part (`n..n+m`) of the internal
    /// solution vector into the caller's slices. An output that is `None` is
    /// skipped.
    fn getlhs(&self, lhsx: Option<&mut [T]>, lhsz: Option<&mut [T]>) {
        let x = &self.x;
        let (m, n) = (self.m, self.n);
        if let Some(v) = lhsx {
            v.copy_from(&x[0..n]);
        }
        if let Some(v) = lhsz {
            v.copy_from(&x[n..(n + m)]);
        }
    }

    /// Refactors the KKT matrix, optionally applying a static diagonal
    /// regularization first.
    ///
    /// When `settings.static_regularization_enable` is set, each diagonal
    /// entry (located through `map.diag_full`) is shifted before factoring.
    /// The shift is `+eps` where `dsigns` is `1` and `-eps` otherwise, with
    /// `eps` taken from `_compute_regularizer`. The shifted diagonal is
    /// written into both `self.KKT` and the LDL backend. After factoring,
    /// only `self.KKT` gets its original diagonal back; the residuals in
    /// `iterative_refinement` are then computed against the unshifted matrix.
    ///
    /// Returns the success flag reported by the backend's `refactor`.
    fn regularize_and_refactor(&mut self, settings: &CoreSettings<T>) -> bool {
        let map = &self.map;
        let KKT = &mut self.KKT;
        let dsigns = &self.dsigns;
        let diag_kkt = &mut self.work1;
        let diag_shifted = &mut self.work2;

        if settings.static_regularization_enable {
            // keep a copy of the true diagonal so it can be restored below
            for (d, idx) in diag_kkt.iter_mut().zip(map.diag_full.iter()) {
                *d = KKT.nzval[*idx];
            }

            let eps = _compute_regularizer(diag_kkt, settings);

            // shift each pivot away from zero in the direction of its expected sign
            diag_shifted.copy_from(diag_kkt);
            diag_shifted
                .iter_mut()
                .zip(dsigns.iter())
                .for_each(|(shift, &sign)| {
                    if sign == 1 {
                        *shift += eps;
                    } else {
                        *shift -= eps;
                    }
                });

            _update_values(&mut self.ldlsolver, KKT, &map.diag_full, diag_shifted);
            self.diagonal_regularizer = eps;
        }

        let is_success = self.ldlsolver.refactor(KKT);

        if settings.static_regularization_enable {
            // restore only our copy; the backend keeps the regularized values
            _update_values_KKT(KKT, &map.diag_full, diag_kkt);
        }

        is_success
    }

    /// Improves `self.x` by iterative refinement against the unregularized
    /// KKT matrix.
    ///
    /// Each step solves `K dx = e` for the current residual `e = b - K x` and
    /// forms the candidate `x + dx`. The loop stops when any of these holds:
    /// - the residual (infinity norm) is within
    ///   `abstol + reltol * ||b||_inf`;
    /// - the residual shrinks by less than `stopratio`;
    /// - `maxiter` steps have been taken.
    ///
    /// A candidate with a smaller residual is kept even when its improvement
    /// falls below `stopratio`. `work1` and `work2` are used as scratch space
    /// and overwritten.
    ///
    /// Returns `false` as soon as a residual is non-finite. This includes the
    /// initial residual, even when `maxiter == 0`. A non-finite candidate is
    /// never swapped into `self.x`. Otherwise returns `true`.
    fn iterative_refinement(&mut self, settings: &CoreSettings<T>) -> bool {
        let (x, b) = (&mut self.x, &self.b);
        let (e, dx) = (&mut self.work1, &mut self.work2);

        let reltol = settings.iterative_refinement_reltol;
        let abstol = settings.iterative_refinement_abstol;
        let maxiter = settings.iterative_refinement_max_iter;
        let stopratio = settings.iterative_refinement_stop_ratio;

        let K = &self.KKT;
        let normb = b.norm_inf();

        // Residual of the unrefined solution. This is checked before the loop
        // so that a bad initial solve is reported even when maxiter == 0.
        let mut norme = _get_refine_error(e, b, K, x);
        if !norme.is_finite() {
            return false;
        }

        for _ in 0..maxiter {
            if norme <= (abstol + reltol * normb) {
                // within tolerance
                break;
            }
            let lastnorme = norme;

            self.ldlsolver.solve(dx, e);

            // dx now holds the candidate x + dx; check it before touching x
            dx.axpby(T::one(), x, T::one());
            norme = _get_refine_error(e, b, K, dx);

            // Check the candidate here rather than at the top of the next
            // iteration. Otherwise a NaN candidate on the final iteration
            // would be swapped into x and reported as success.
            if !norme.is_finite() {
                return false;
            }

            let improved_ratio = lastnorme / norme;
            if improved_ratio < stopratio {
                // insufficient progress; still keep the candidate if it is better
                if improved_ratio > T::one() {
                    std::mem::swap(x, dx);
                }
                break;
            }
            std::mem::swap(x, dx);
        }

        true
    }
}
/// Static diagonal shift: `constant + proportional * max_i |diag_kkt[i]|`.
fn _compute_regularizer<T: FloatT>(diag_kkt: &[T], settings: &CoreSettings<T>) -> T {
    let proportional_part = settings.static_regularization_proportional * diag_kkt.norm_inf();
    settings.static_regularization_constant + proportional_part
}
/// Writes the residual `e = b - K ξ` into `e` and returns its infinity norm.
///
/// Only the upper triangle of the symmetric `K` is read (`MatrixTriangle::Triu`).
fn _get_refine_error<T: FloatT>(e: &mut [T], b: &[T], K: &CscMatrix<T>, ξ: &mut [T]) -> T {
    // symv computes e <- alpha * K * ξ + beta * e with e preloaded with b
    let (alpha, beta) = (-T::one(), T::one());
    e.copy_from(b);
    K.symv(e, MatrixTriangle::Triu, ξ, alpha, beta);
    e.norm_inf()
}
/// Function-pointer type that builds a boxed LDL backend from the KKT
/// matrix, the pivot signs and the solver settings.
type LDLConstructor<T> = fn(&CscMatrix<T>, &[i8], &CoreSettings<T>) -> BoxedDirectLDLSolver<T>;

/// Selects the LDL backend named by `settings.direct_solve_method`.
///
/// Returns the matrix triangle the backend expects the KKT matrix in,
/// together with a constructor for the backend.
///
/// # Panics
/// Panics for unrecognized method names. `"custom"` hits `unimplemented!`.
fn _get_ldlsolver_config<T: FloatT>(
    settings: &CoreSettings<T>,
) -> (MatrixTriangle, LDLConstructor<T>)
where
    T: FloatT,
{
    match settings.direct_solve_method.as_str() {
        "qdldl" => {
            let shape = QDLDLDirectLDLSolver::<T>::required_matrix_shape();
            let ctor: LDLConstructor<T> =
                |M, D, S| Box::new(QDLDLDirectLDLSolver::<T>::new(M, D, S));
            (shape, ctor)
        }
        "custom" => unimplemented!(),
        _ => panic!("Unrecognized LDL solver type"),
    }
}
/// Overwrites the entries at `index` with `values`, in both our copy of the
/// KKT matrix and the LDL backend.
fn _update_values<T: FloatT>(
    ldlsolver: &mut BoxedDirectLDLSolver<T>,
    KKT: &mut CscMatrix<T>,
    index: &[usize],
    values: &[T],
) {
    // local copy first, then forward the identical update to the backend
    _update_values_KKT(KKT, index, values);
    ldlsolver.update_values(index, values);
}
/// Sets `KKT.nzval[index[k]] = values[k]` for each paired position. Extra
/// elements of the longer slice are ignored.
fn _update_values_KKT<T: FloatT>(KKT: &mut CscMatrix<T>, index: &[usize], values: &[T]) {
    let nzval = &mut KKT.nzval;
    for (&pos, &val) in index.iter().zip(values) {
        nzval[pos] = val;
    }
}
/// Multiplies the entries at `index` by `scale`, in both our copy of the KKT
/// matrix and the LDL backend.
fn _scale_values<T: FloatT>(
    ldlsolver: &mut BoxedDirectLDLSolver<T>,
    KKT: &mut CscMatrix<T>,
    index: &[usize],
    scale: T,
) {
    // local copy first, then forward the identical scaling to the backend
    _scale_values_KKT(KKT, index, scale);
    ldlsolver.scale_values(index, scale);
}
/// Multiplies `KKT.nzval[i]` by `scale` for every `i` in `index`.
fn _scale_values_KKT<T: FloatT>(KKT: &mut CscMatrix<T>, index: &[usize], scale: T) {
    index.iter().for_each(|&pos| KKT.nzval[pos] *= scale);
}
/// Fills the expected pivot signs of the KKT system.
///
/// All entries are first set to `+1`. Then `n..n+m` (the z block) becomes
/// `-1`. In the `p` expansion entries starting at `n+m`, every other entry
/// (starting with the first) becomes `-1`. Entries past `n+m+p` stay `+1`.
fn _fill_signs(signs: &mut [i8], m: usize, n: usize, p: usize) {
    signs.fill(1);

    let z_block = n..(n + m);
    for s in &mut signs[z_block] {
        *s = -1;
    }

    let expansion = (n + m)..(n + m + p);
    for s in signs[expansion].iter_mut().step_by(2) {
        *s = -1;
    }
}