#![allow(non_snake_case)]
#![allow(clippy::too_many_arguments)]
use crate::solver::chordal::ConeMapEntry;
use crate::solver::chordal::SparsityPattern;
use crate::solver::core::cones::*;
use crate::{
algebra::*,
solver::{
chordal::{ChordalInfo, SuperNodeTree, VertexSet},
SupportedConeT::*,
},
};
use std::cmp::{max, min};
use std::ops::Range;
use std::{collections::HashMap, iter::zip};
/// One entry of a clique block, stored as `(i, j, is_overlap)`.
///
/// `i` and `j` are the two vertex indices of the entry (pushed with `i <= j` by
/// `get_block_indices`).  `is_overlap` is `true` for entries built from two separator
/// vertices and `false` for all other entries.
type BlockOverlapTriplet = (usize, usize, bool);
impl<T> ChordalInfo<T>
where
    T: FloatT,
{
    /// Builds the augmented problem data `(P_new, q_new, A_new, b_new, cones_new)`
    /// for the compact form of the decomposition.
    ///
    /// `A_new`, `b_new` and `cones_new` come from `find_compact_A_b_and_cones`.
    /// `P` and `q` are then padded to match the number of columns that were
    /// appended to `A`: `P` gets a trailing zero block and `q` trailing zeros.
    ///
    /// # Panics
    /// Panics if `CscMatrix::blockdiag` returns an error.
    #[allow(clippy::type_complexity)]
    pub(crate) fn decomp_augment_compact(
        &mut self,
        P: &CscMatrix<T>,
        q: &[T],
        A: &CscMatrix<T>,
        b: &[T],
    ) -> (
        CscMatrix<T>,
        Vec<T>,
        CscMatrix<T>,
        Vec<T>,
        Vec<SupportedConeT<T>>,
    ) {
        let (A_new, b_new, cones_new) = self.find_compact_A_b_and_cones(A, b);

        // Number of columns appended to A by the augmentation.
        let nadd = A_new.n - A.n;

        // The appended columns carry no quadratic or linear cost.
        let P_new = CscMatrix::blockdiag(&[P, &CscMatrix::zeros((nadd, nadd))]).unwrap();
        let mut q_new = vec![T::zero(); q.len() + nadd];
        q_new[..q.len()].copy_from(q);

        (P_new, q_new, A_new, b_new, cones_new)
    }

    /// Assembles the augmented `A`, the augmented `b` and the new cone list, and
    /// stores the new-cone -> original-cone mapping in `self.cone_maps`.
    ///
    /// The augmented `A` is built in triplet form:
    /// * the first `A.nnz()` triplets reuse the columns and values of `A`, with
    ///   their row indices rewritten by the `add_entries_*` helpers;
    /// * the remaining `2 * n_overlaps` triplets come in pairs valued `+1, -1`,
    ///   each pair placed in its own new column (starting at column `A.n`).
    fn find_compact_A_b_and_cones(
        &mut self,
        A: &CscMatrix<T>,
        b: &[T],
    ) -> (CscMatrix<T>, Vec<T>, Vec<SupportedConeT<T>>) {
        // The cones as they were before decomposition.
        let cones = &self.init_cones;

        // Shape of the augmented matrix and the number of overlap pairs.
        let (Aa_m, Aa_n, n_overlaps) = self.find_A_dimension(A);

        // Triplet storage for the augmented A.  Row indices start out as
        // usize::MAX placeholders and are filled in cone by cone below.
        let A_nnz = A.nnz();
        let Aa_nnz = A_nnz + 2 * n_overlaps;
        let mut Aa_I = vec![usize::MAX; Aa_nnz];
        let mut Aa_J = extra_columns(Aa_nnz, A_nnz, A.n);
        let mut Aa_V = alternating_sequence::<T>(Aa_nnz, A_nnz);
        findnz(&mut Aa_J, &mut Aa_V, A);

        // Sparse storage for the augmented b; the values are reused as-is and
        // only the row indices are rewritten.
        let bs = SparseVector::new(b);
        let mut ba_I = vec![usize::MAX; bs.nzval.len()];

        // Output cones and the map from each output cone back to its origin.
        let n_decomposed = self.final_cone_count();
        let mut cones_new = Vec::with_capacity(n_decomposed);
        let mut cone_maps = Vec::with_capacity(n_decomposed);

        // Patterns are consumed in order, each paired with its position in
        // `self.spatterns`.  A cone is expanded through its pattern only when
        // the next pending pattern records that cone's index as `orig_index`.
        let mut patterns_iter = self.spatterns.iter().enumerate().peekable();

        let row_ranges = cones.rng_cones_iter();
        // Next free row of the augmented system.
        let mut row_ptr = 0;
        // Next free triplet slot for the (+1, -1) overlap pairs.
        let mut overlap_ptr = A_nnz;

        for (coneidx, (cone, row_range)) in zip(cones, row_ranges).enumerate() {
            let next_pattern_matches = matches!(
                patterns_iter.peek(),
                Some((_, pattern)) if pattern.orig_index == coneidx
            );

            (row_ptr, overlap_ptr) = if next_pattern_matches {
                assert!(matches!(cone, SupportedConeT::PSDTriangleConeT(_)));
                let (spattern_index, spattern) = patterns_iter.next().unwrap();
                add_entries_with_sparsity_pattern(
                    &mut Aa_I,
                    &mut ba_I,
                    &mut cones_new,
                    &mut cone_maps,
                    A,
                    &bs,
                    row_range,
                    spattern,
                    spattern_index,
                    row_ptr,
                    overlap_ptr,
                )
            } else {
                add_entries_with_cone(
                    &mut Aa_I,
                    &mut ba_I,
                    &mut cones_new,
                    &mut cone_maps,
                    A,
                    &bs,
                    row_range,
                    cone,
                    row_ptr,
                    overlap_ptr,
                )
            };
        }

        self.cone_maps = Some(cone_maps);

        let A_new = CscMatrix::new_from_triplets(Aa_m, Aa_n, Aa_I, Aa_J, Aa_V);
        let b_new: Vec<T> = SparseVector {
            nzind: ba_I,
            nzval: bs.nzval,
            n: Aa_m,
        }
        .into();

        (A_new, b_new, cones_new)
    }

    /// Returns `(rows, cols, num_overlaps)` for the augmented `A`: the row count
    /// and overlap count come from `get_decomposed_dim_and_overlaps`, and one
    /// column is added to `A.n` per overlap.
    fn find_A_dimension(&self, A: &CscMatrix<T>) -> (usize, usize, usize) {
        let (rows, num_overlaps) = self.get_decomposed_dim_and_overlaps();
        (rows, A.n + num_overlaps, num_overlaps)
    }
}
/// Copies one cone into the augmented system unchanged, shifting its rows so
/// that the cone starts at `row_ptr` instead of `row_range.start`.
///
/// Row indices are written into `Aa_I` / `ba_I` at the same storage positions the
/// entries occupy in `A.rowval` / `b.nzind`.  The cone is appended to `cones_new`
/// and a `ConeMapEntry` with no tree/clique information is appended to `cone_maps`.
///
/// Returns the updated `(row_ptr, overlap_ptr)`; `overlap_ptr` is passed through
/// untouched.
///
/// # Panics
/// Panics if a shifted row index does not fit in `usize`.
fn add_entries_with_cone<T>(
    Aa_I: &mut [usize],
    ba_I: &mut [usize],
    cones_new: &mut Vec<SupportedConeT<T>>,
    cone_maps: &mut Vec<ConeMapEntry>,
    A: &CscMatrix<T>,
    b: &SparseVector<T>,
    row_range: Range<usize>,
    cone: &SupportedConeT<T>,
    row_ptr: usize,
    overlap_ptr: usize,
) -> (usize, usize)
where
    T: FloatT,
{
    // Signed distance from an original row index to its new position.
    let offset = (row_ptr as isize) - (row_range.start as isize);
    let shift = |row: usize| row.checked_add_signed(offset).unwrap();

    // Entries of b that fall inside this cone.
    for k in get_rows_vec(b, row_range.clone()).into_iter().flatten() {
        ba_I[k] = shift(b.nzind[k]);
    }

    // Entries of A that fall inside this cone, one column at a time.
    for col in 0..A.n {
        for k in get_rows_mat(A, col, row_range.clone())
            .into_iter()
            .flatten()
        {
            Aa_I[k] = shift(A.rowval[k]);
        }
    }

    cones_new.push(cone.clone());

    // This cone's original index is one past the previous entry's original
    // index, or 0 if this is the first entry.
    let orig_index = cone_maps.last().map_or(0, |entry| entry.orig_index + 1);
    cone_maps.push(ConeMapEntry {
        orig_index,
        tree_and_clique: None,
    });

    (row_ptr + cone.nvars(), overlap_ptr)
}
/// Expands one cone into one `PSDTriangleConeT` per clique of `spattern.sntree`,
/// rewriting the row indices of the affected entries of `A` and `b` and adding
/// the overlap triplets that tie a clique to its parent.
///
/// Cliques are visited from index `n_cliques - 1` down to `0`.  The clique with
/// index `n_cliques - 1` is handled as having no parent.  For every clique a new
/// cone is pushed onto `cones_new` and a `ConeMapEntry` recording
/// `(spattern_index, clique index)` is pushed onto `cone_maps`.
///
/// Returns the updated `(row_ptr, overlap_ptr)`.
fn add_entries_with_sparsity_pattern<T>(
    A_I: &mut [usize],
    b_I: &mut [usize],
    cones_new: &mut Vec<SupportedConeT<T>>,
    cone_maps: &mut Vec<ConeMapEntry>,
    A: &CscMatrix<T>,
    b: &SparseVector<T>,
    row_range: Range<usize>,
    spattern: &SparsityPattern,
    spattern_index: usize,
    row_ptr: usize,
    overlap_ptr: usize,
) -> (usize, usize)
where
    T: FloatT,
{
    let sntree = &spattern.sntree;
    let ordering = &spattern.ordering;
    let n_cols = A.size().1;

    // Row block that each clique will occupy, starting from the incoming row_ptr.
    let clique_to_rows = clique_rows_map(row_ptr, sntree);

    // Entries of b inside this cone; only ever consulted together with column 0.
    let b_rows = get_rows_vec(b, row_range.clone()).unwrap_or(0..0);

    let mut next_row = row_ptr;
    let mut next_overlap = overlap_ptr;

    for i in (0..sntree.n_cliques).rev() {
        // Separator and supernode vertices, mapped through the ordering and
        // sorted ascending.
        let mut separator: Vec<usize> = sntree
            .get_separators(i)
            .iter()
            .map(|&v| ordering[v])
            .collect();
        separator.sort_unstable();

        let mut snode: Vec<usize> = sntree.get_snode(i).iter().map(|&v| ordering[v]).collect();
        snode.sort_unstable();

        let block_indices = get_block_indices(&snode, &separator, ordering.len());

        // Row block and (sorted) vertex list of the parent clique; both empty
        // for the clique that has no parent.
        let (parent_rows, parent_clique): (Range<usize>, Vec<usize>) = if i + 1 == sntree.n_cliques
        {
            (0..0, Vec::new())
        } else {
            let parent_index = sntree.get_clique_parent(i);
            let rows = clique_to_rows.get(&parent_index).unwrap().clone();
            let mut vertices: Vec<usize> = get_clique_by_index(sntree, parent_index)
                .iter()
                .map(|&v| ordering[v])
                .collect();
            vertices.sort_unstable();
            (rows, vertices)
        };

        for col in 0..n_cols {
            let row_range_col = get_rows_mat(A, col, row_range.clone()).unwrap_or(0..0);
            let row_range_b = if col == 0 { b_rows.clone() } else { 0..0 };

            next_overlap = add_clique_entries(
                A_I,
                b_I,
                &A.rowval,
                &b.nzind,
                &block_indices,
                &parent_clique,
                parent_rows.clone(),
                col,
                next_row,
                next_overlap,
                row_range.clone(),
                row_range_col,
                row_range_b,
            );
        }

        // One new cone per clique, sized by the clique's block dimension.
        let cone_dim = sntree.get_nblk(i);
        cones_new.push(PSDTriangleConeT(cone_dim));
        cone_maps.push(ConeMapEntry {
            orig_index: spattern.orig_index,
            tree_and_clique: Some((spattern_index, i)),
        });

        next_row += triangular_number(cone_dim);
    }

    (next_row, next_overlap)
}
/// Processes every entry of one clique block for a single column `col`.
///
/// The `counter`-th element of `block_indices` is assigned new row
/// `row_ptr + counter`.
/// * Overlap entries: handled only when `col == 0`.  Two row indices are written
///   at `overlap_ptr` and `overlap_ptr + 1` (this clique's row, then the matching
///   row inside the parent's block) and `overlap_ptr` advances by two.
/// * Other entries: the matching entry of column `col` of `A` (if present) gets
///   the new row index; when `col == 0` the same is done for `b`.
///
/// Returns the updated `overlap_ptr`.
fn add_clique_entries(
    A_I: &mut [usize],
    b_I: &mut [usize],
    A_rowval: &[usize],
    b_nzind: &[usize],
    block_indices: &[BlockOverlapTriplet],
    parent_clique: &[usize],
    parent_rows: Range<usize>,
    col: usize,
    row_ptr: usize,
    overlap_ptr: usize,
    row_range: Range<usize>,
    row_range_col: Range<usize>,
    row_range_b: Range<usize>,
) -> usize {
    let is_first_col = col == 0;
    let mut next_overlap = overlap_ptr;

    for (new_row_val, &(i, j, is_overlap)) in (row_ptr..).zip(block_indices) {
        match (is_overlap, is_first_col) {
            (true, true) => {
                // Pair this clique's row with the parent's row for the same (i, j).
                A_I[next_overlap] = new_row_val;
                A_I[next_overlap + 1] =
                    parent_rows.start + parent_block_indices(parent_clique, i, j);
                next_overlap += 2;
            }
            // Overlap pairs are emitted once, with the first column only.
            (true, false) => {}
            (false, first_col) => {
                let k = coord_to_upper_triangular_index((i, j));
                modify_clique_rows(
                    A_I,
                    k,
                    A_rowval,
                    new_row_val,
                    row_range.clone(),
                    row_range_col.clone(),
                );
                if first_col {
                    modify_clique_rows(
                        b_I,
                        k,
                        b_nzind,
                        new_row_val,
                        row_range.clone(),
                        row_range_b.clone(),
                    );
                }
            }
        }
    }

    next_overlap
}
/// Writes `new_row_val` into `v` at the storage position of the entry whose row
/// is `row_range.start + k`, searching only inside `row_range_col`.  Does nothing
/// if no such entry exists.
fn modify_clique_rows(
    v: &mut [usize],
    k: usize,
    rowval: &[usize],
    new_row_val: usize,
    row_range: Range<usize>,
    row_range_col: Range<usize>,
) {
    if let Some(position) = get_row_index(k, rowval, row_range, row_range_col) {
        v[position] = new_row_val;
    }
}
/// Finds the position in `rowval` of the entry whose row is `row_range.start + k`,
/// looking only at positions inside `row_range_col`.
///
/// Returns `None` if `row_range_col` is empty or the row is not present.  The
/// lookup is a binary search, so `rowval` is expected to be sorted ascending
/// within `row_range_col`.
fn get_row_index(
    k: usize,
    rowval: &[usize],
    row_range: Range<usize>,
    row_range_col: Range<usize>,
) -> Option<usize> {
    if row_range_col.is_empty() {
        return None;
    }

    let target_row = row_range.start + k;
    let lower = row_range_col.start;
    // The search window is capped at `target_row + 1` positions past `lower`.
    let upper = min(row_range_col.end, lower + target_row + 1);

    // First position in [lower, upper) whose row is >= target_row.
    let candidate = lower + rowval[lower..upper].partition_point(|&row| row < target_row);

    (candidate < upper && rowval[candidate] == target_row).then_some(candidate)
}
/// Maps the vertex pair `(i, j)` to an upper-triangular index inside the parent
/// clique's block: each vertex is replaced by its rank (the number of smaller
/// entries) in the sorted `parent_clique`.
fn parent_block_indices(parent_clique: &[usize], i: usize, j: usize) -> usize {
    let rank_in_parent = |vertex: usize| parent_clique.partition_point(|&x| x < vertex);
    let local_row = rank_in_parent(i);
    let local_col = rank_in_parent(j);
    coord_to_upper_triangular_index((local_row, local_col))
}
/// Lists the upper-triangular entries `(i, j, is_overlap)` of a clique made of
/// the vertices in `snode` and `separator`, ordered by column `j` and then row `i`.
///
/// Entries whose two vertices both come from `separator` are flagged as overlaps;
/// supernode/supernode and supernode/separator entries are not.  `nv` is the
/// stride of the sort key `j * nv + i`, so every vertex index is expected to be
/// smaller than `nv`.
fn get_block_indices(snode: &[usize], separator: &[usize], nv: usize) -> Vec<BlockOverlapTriplet> {
    let clique_size = separator.len() + snode.len();
    let mut block_indices =
        Vec::<BlockOverlapTriplet>::with_capacity(triangular_number(clique_size));

    // Pairs drawn from a single set: separator first (overlaps), then supernode.
    for (vertices, is_overlap) in [(separator, true), (snode, false)] {
        for &j in vertices {
            block_indices.extend(
                vertices
                    .iter()
                    .filter(|&&i| i <= j)
                    .map(|&i| (i, j, is_overlap)),
            );
        }
    }

    // Mixed supernode/separator pairs, stored with the smaller index first.
    for &i in snode {
        block_indices.extend(separator.iter().map(|&j| (min(i, j), max(i, j), false)));
    }

    block_indices.sort_by_cached_key(|&(row, col, _)| col * nv + row);
    block_indices
}
/// Assigns a contiguous row block to every clique, walking clique indices from
/// `n_cliques - 1` down to `0` and starting at `row_start`.
///
/// The block for index `i` holds `triangular_number(sntree.get_nblk(i))` rows and
/// is stored under the key `sntree.snode_post[i]`.
fn clique_rows_map(row_start: usize, sntree: &SuperNodeTree) -> HashMap<usize, Range<usize>> {
    let mut rows_by_clique = HashMap::with_capacity(sntree.n_cliques);
    let mut block_start = row_start;

    for i in (0..sntree.n_cliques).rev() {
        let block_end = block_start + triangular_number(sntree.get_nblk(i));
        rows_by_clique.insert(sntree.snode_post[i], block_start..block_end);
        block_start = block_end;
    }

    rows_by_clique
}
/// Returns the range of positions in the ascending-sorted slice `rows` whose
/// values lie inside `row_range`.
///
/// Returns `None` when `rows` or `row_range` is empty, or when every value lies
/// entirely below or entirely above `row_range`.  If the values straddle the
/// range without any of them falling inside it, the result is `Some` of an
/// empty range.
fn get_rows_subset(rows: &[usize], row_range: Range<usize>) -> Option<Range<usize>> {
    let (Some(&smallest), Some(&largest)) = (rows.first(), rows.last()) else {
        return None;
    };

    let no_overlap =
        row_range.is_empty() || largest < row_range.start || smallest >= row_range.end;
    if no_overlap {
        return None;
    }

    let first = rows.partition_point(|&row| row < row_range.start);
    let past_last = rows.partition_point(|&row| row < row_range.end);
    Some(first..past_last)
}
/// Returns the range of storage positions in the sparse vector `b` whose indices
/// (`b.nzind`) lie inside `row_range`, as computed by `get_rows_subset`.
fn get_rows_vec<T>(b: &SparseVector<T>, row_range: Range<usize>) -> Option<Range<usize>>
where
    T: FloatT,
{
    let nonzero_rows: &[usize] = &b.nzind;
    get_rows_subset(nonzero_rows, row_range)
}
/// Returns the range of storage positions (indices into `A.rowval`) of the
/// entries in column `col` of `A` whose rows lie inside `row_range`.
///
/// The result of `get_rows_subset` on the column's rows is shifted by the
/// column's starting offset so that it indexes `A.rowval` directly.
fn get_rows_mat<T>(A: &CscMatrix<T>, col: usize, row_range: Range<usize>) -> Option<Range<usize>>
where
    T: FloatT,
{
    let col_start = A.colptr[col];
    let col_end = A.colptr[col + 1];

    get_rows_subset(&A.rowval[col_start..col_end], row_range)
        .map(|within_col| (col_start + within_col.start)..(col_start + within_col.end))
}
/// Returns a vector of `total_length` ones in which every second element from
/// index `n_start + 1` onwards is `-1`, i.e. the tail starting at `n_start` reads
/// `+1, -1, +1, -1, ...`.
fn alternating_sequence<T>(total_length: usize, n_start: usize) -> Vec<T>
where
    T: FloatT,
{
    let mut signs = vec![T::one(); total_length];
    signs
        .iter_mut()
        .skip(n_start + 1)
        .step_by(2)
        .for_each(|sign| *sign = -T::one());
    signs
}
/// Returns a vector of `total_length` column indices that is zero up to `n_start`
/// and, from `n_start` onwards, assigns consecutive pairs of elements the column
/// indices `start_val`, `start_val + 1`, ...
///
/// A trailing element that has no partner keeps the value `0`.  If `n_start` is
/// at or past `total_length` (including `total_length == 0`), the result is all
/// zeros; no input causes an arithmetic underflow.
fn extra_columns(total_length: usize, n_start: usize, start_val: usize) -> Vec<usize> {
    let mut columns = vec![0; total_length];

    // `get_mut` yields `None` instead of panicking when `n_start` is out of
    // range, and `chunks_exact_mut` skips an unpaired trailing element.
    if let Some(tail) = columns.get_mut(n_start..) {
        for (pair, column) in tail.chunks_exact_mut(2).zip(start_val..) {
            pair.fill(column);
        }
    }

    columns
}
/// Writes the column index and value of every stored entry of `S`, in storage
/// order, into the leading elements of `J` and `V`.
///
/// # Panics
/// Panics if `J` or `V` is shorter than the number of stored entries visited.
fn findnz<T>(J: &mut [usize], V: &mut [T], S: &CscMatrix<T>)
where
    T: FloatT,
{
    let mut dst = 0;
    for col in 0..S.n {
        let entries = S.colptr[col]..S.colptr[col + 1];
        for src in entries {
            J[dst] = col;
            V[dst] = S.nzval[src];
            dst += 1;
        }
    }
}
/// Returns the vertices of clique `i`: the supernode vertices `sntree.snode[i]`
/// followed by the separator vertices `sntree.separators[i]`.
fn get_clique_by_index(sntree: &SuperNodeTree, i: usize) -> VertexSet {
    let members = sntree.snode[i].iter().chain(sntree.separators[i].iter());

    let mut clique = VertexSet::new();
    clique.extend(members);
    clique
}
#[test]
fn test_alternating_sequence() {
    // Two leading entries stay +1, followed by two (+1, -1) pairs.
    let (a_nnz, overlap_count) = (2, 2);
    let signs = alternating_sequence::<f64>(a_nnz + 2 * overlap_count, a_nnz);
    let expected = vec![1., 1., 1., -1., 1., -1.];
    assert_eq!(signs, expected);
}