#![allow(non_snake_case)]
use super::datamap::*;
use crate::algebra::*;
use crate::solver::core::cones::CompositeCone;
use crate::solver::core::cones::*;
use num_traits::Zero;
pub(crate) fn allocate_kkt_Hsblocks<T, Z>(cones: &CompositeCone<T>) -> Vec<Z>
where
T: FloatT,
Z: Zero + Clone,
{
let mut nnz = 0;
if let Some(rng_last) = cones.rng_blocks.last() {
nnz = (*rng_last).end;
}
vec![Z::zero(); nnz]
}
pub fn assemble_kkt_matrix<T: FloatT>(
P: &CscMatrix<T>,
A: &CscMatrix<T>,
cones: &CompositeCone<T>,
shape: MatrixTriangle,
) -> (CscMatrix<T>, LDLDataMap) {
let (m, n) = (A.nrows(), P.nrows());
let n_socs = cones.type_count(SupportedConeTag::SecondOrderCone);
let p = 2 * n_socs;
let mut maps = LDLDataMap::new(P, A, cones);
let nnz_diagP = P.count_diagonal_entries();
let nnz_Hsblocks = maps.Hsblocks.len();
let nnz_SOC_vecs = 2 * maps.SOC_u.iter().fold(0, |acc, block| acc + block.len());
let nnz_SOC_ext = maps.SOC_D.len();
let nnzKKT = P.nnz() + n - nnz_diagP + A.nnz() + nnz_Hsblocks + nnz_SOC_vecs + nnz_SOC_ext;
let mut K = CscMatrix::<T>::spalloc(m + n + p, m + n + p, nnzKKT);
_kkt_assemble_colcounts(&mut K, P, A, cones, (m, n, p), shape);
_kkt_assemble_fill(&mut K, &mut maps, P, A, cones, (m, n, p), shape);
(K, maps)
}
fn _kkt_assemble_colcounts<T: FloatT>(
K: &mut CscMatrix<T>,
P: &CscMatrix<T>,
A: &CscMatrix<T>,
cones: &CompositeCone<T>,
mnp: (usize, usize, usize),
shape: MatrixTriangle,
) {
let (m, n, p) = (mnp.0, mnp.1, mnp.2);
K.colptr.fill(0);
match shape {
MatrixTriangle::Triu => {
K.colcount_block(P, 0, MatrixShape::N);
K.colcount_missing_diag(P, 0);
K.colcount_block(A, n, MatrixShape::T);
}
MatrixTriangle::Tril => {
K.colcount_missing_diag(P, 0);
K.colcount_block(P, 0, MatrixShape::T);
K.colcount_block(A, 0, MatrixShape::N);
}
}
for (i, cone) in cones.iter().enumerate() {
let firstcol = cones.rng_cones[i].start + n;
let blockdim = cone.numel();
if cone.Hs_is_diagonal() {
K.colcount_diag(firstcol, blockdim);
} else {
K.colcount_dense_triangle(firstcol, blockdim, shape);
}
}
let mut socidx = 0;
for (i, cone) in cones.iter().enumerate() {
if let SupportedCone::SecondOrderCone(SOC) = cone {
let nvars = SOC.numel();
let headidx = cones.rng_cones[i].start;
let col = m + n + 2 * socidx;
match shape {
MatrixTriangle::Triu => {
K.colcount_colvec(nvars, headidx + n, col); K.colcount_colvec(nvars, headidx + n, col + 1); }
MatrixTriangle::Tril => {
K.colcount_rowvec(nvars, col, headidx + n); K.colcount_rowvec(nvars, col + 1, headidx + n); }
}
socidx += 1;
}
}
K.colcount_diag(n + m, p);
}
fn _kkt_assemble_fill<T: FloatT>(
K: &mut CscMatrix<T>,
maps: &mut LDLDataMap,
P: &CscMatrix<T>,
A: &CscMatrix<T>,
cones: &CompositeCone<T>,
mnp: (usize, usize, usize),
shape: MatrixTriangle,
) {
let (m, n, p) = (mnp.0, mnp.1, mnp.2);
K.colcount_to_colptr();
match shape {
MatrixTriangle::Triu => {
K.fill_block(P, &mut maps.P, 0, 0, MatrixShape::N);
K.fill_missing_diag(P, 0); K.fill_block(A, &mut maps.A, 0, n, MatrixShape::T);
}
MatrixTriangle::Tril => {
K.fill_missing_diag(P, 0); K.fill_block(P, &mut maps.P, 0, 0, MatrixShape::T);
K.fill_block(A, &mut maps.A, n, 0, MatrixShape::N);
}
}
for (i, (cone, rng_cone)) in cones.iter().zip(cones.rng_cones.iter()).enumerate() {
let firstcol = rng_cone.start + n;
let blockdim = cone.numel();
let block = &mut maps.Hsblocks[cones.rng_blocks[i].clone()];
if cone.Hs_is_diagonal() {
K.fill_diag(block, firstcol, blockdim);
} else {
K.fill_dense_triangle(block, firstcol, blockdim, shape);
}
}
let mut socidx = 0;
for (i, cone) in cones.iter().enumerate() {
if let SupportedCone::SecondOrderCone(_) = cone {
let headidx = cones.rng_cones[i].start;
let col = m + n + 2 * socidx;
match shape {
MatrixTriangle::Triu => {
K.fill_colvec(&mut maps.SOC_v[socidx], headidx + n, col); K.fill_colvec(&mut maps.SOC_u[socidx], headidx + n, col + 1);
}
MatrixTriangle::Tril => {
K.fill_rowvec(&mut maps.SOC_v[socidx], col, headidx + n); K.fill_rowvec(&mut maps.SOC_u[socidx], col + 1, headidx + n);
}
}
socidx += 1;
}
}
K.fill_diag(&mut maps.SOC_D, n + m, p);
K.backshift_colptrs();
match shape {
MatrixTriangle::Triu => {
maps.diag_full.copy_from_slice(&K.colptr[1..]);
maps.diag_full.iter_mut().for_each(|x| *x -= 1);
maps.diagP.copy_from_slice(&K.colptr[1..=n]);
maps.diagP.iter_mut().for_each(|x| *x -= 1);
}
MatrixTriangle::Tril => {
maps.diag_full
.copy_from_slice(&K.colptr[0..K.colptr.len() - 1]);
maps.diagP.copy_from_slice(&K.colptr[0..n]);
}
}
}