use faer::sparse::{SparseColMat, SparseRowMat};
use std::sync::Arc;
/// Symbolic (pattern-only) description of a sparse Hessian stored in
/// compressed-column form. Only entries with `row <= column` are recorded,
/// i.e. the upper triangle including the diagonal.
struct SparseHessianSymbolic {
/// Number of columns of the pattern; `col_ptrs` holds `dim + 1` entries.
dim: usize,
/// Total number of structural entries (equals `row_indices.len()`).
nnz: usize,
/// Column pointers: column `c` owns `row_indices[col_ptrs[c]..col_ptrs[c + 1]]`.
col_ptrs: Vec<usize>,
/// Row index of every structural entry, ascending within each column.
row_indices: Vec<usize>,
/// Smallest row index stored in each column, or `usize::MAX` for a column
/// without entries. Only relied upon when `contiguous` is `true`.
first_row: Vec<usize>,
/// `true` when every non-empty column stores a gap-free run of consecutive
/// rows, so an entry's slot can be computed as `row - first_row[col]`
/// instead of being searched for.
contiguous: bool,
}
impl SparseHessianSymbolic {
    /// Builds the upper-triangular compressed-column pattern implied by a set of
    /// row-major matrices that share the same rows.
    ///
    /// For every row index `i`, the column indices stored in row `i` of *all*
    /// matrices in `csrs` are merged; each pair `(a, b)` with `a <= b` taken from
    /// that merged set becomes a structural entry at row `a`, column `b`.
    ///
    /// An empty `csrs` slice produces an empty pattern with `dim` columns.
    ///
    /// # Panics
    ///
    /// * if the matrices in `csrs` do not all have the same number of rows;
    /// * if any stored column index is `>= dim`.
    fn build(csrs: &[&SparseRowMat<usize, f64>], dim: usize) -> Self {
        use std::collections::BTreeSet;

        // No matrices means no rows to scan, hence an empty pattern.
        let n_rows = csrs.first().map_or(0, |csr| csr.nrows());
        // Row `i` is read from every matrix, so they must agree on the row
        // count; otherwise rows would be dropped or indexed out of bounds.
        for (k, csr) in csrs.iter().enumerate() {
            assert_eq!(
                csr.nrows(),
                n_rows,
                "SparseHessianSymbolic::build: matrix {k} has {} rows, expected {n_rows}",
                csr.nrows()
            );
        }

        // One ordered set of row indices per Hessian column.
        let mut rows_by_col = vec![BTreeSet::<usize>::new(); dim];
        // Scratch buffer reused across rows to avoid reallocating.
        let mut cols: Vec<usize> = Vec::with_capacity(32);

        for i in 0..n_rows {
            cols.clear();
            for csr in csrs {
                let sym = csr.symbolic();
                let rp = sym.row_ptr();
                let ci = sym.col_idx();
                cols.extend_from_slice(&ci[rp[i]..rp[i + 1]]);
            }
            cols.sort_unstable();
            cols.dedup();

            // Validate once per row instead of once per generated pair. `cols`
            // is sorted, so the reported index is the smallest offender.
            if let Some(&bad) = cols.iter().find(|&&col| col >= dim) {
                panic!(
                    "SparseHessianSymbolic::build: column index {bad} out of Hessian dimension {dim}"
                );
            }

            // Outer product of the row's columns, upper triangle only
            // (`ca <= cb` because `cols` is sorted).
            for (ai, &ca) in cols.iter().enumerate() {
                for &cb in &cols[ai..] {
                    rows_by_col[cb].insert(ca);
                }
            }
        }

        // Flatten the per-column sets into compressed-column arrays.
        let nnz: usize = rows_by_col.iter().map(BTreeSet::len).sum();
        let mut col_ptrs = Vec::with_capacity(dim + 1);
        let mut row_indices = Vec::with_capacity(nnz);
        col_ptrs.push(0);
        for rows in rows_by_col {
            row_indices.extend(rows);
            col_ptrs.push(row_indices.len());
        }

        // Record the first row of every non-empty column and detect whether all
        // columns are gap-free runs, which enables O(1) slot lookup.
        let mut first_row = vec![usize::MAX; dim];
        let mut contiguous = true;
        for (c, bounds) in col_ptrs.windows(2).enumerate() {
            let rows = &row_indices[bounds[0]..bounds[1]];
            if let Some(&first) = rows.first() {
                first_row[c] = first;
                if contiguous {
                    contiguous = rows
                        .iter()
                        .enumerate()
                        .all(|(off, &ri)| ri == first + off);
                }
            }
        }

        SparseHessianSymbolic {
            dim,
            nnz,
            col_ptrs,
            row_indices,
            first_row,
            contiguous,
        }
    }
}
/// Numeric accumulator for a sparse Hessian: a shared symbolic pattern plus
/// one `f64` slot per structural entry of that pattern.
pub struct SparseHessianAccumulator {
/// Sparsity pattern, reference-counted so clones share it instead of copying.
sym: Arc<SparseHessianSymbolic>,
/// Accumulated values, laid out in the same order as `sym.row_indices`.
/// Created with exactly `sym.nnz` elements; `add_upper` indexes it by
/// pattern position, so the length must stay in sync with the pattern.
pub(crate) values: Vec<f64>,
}
impl Clone for SparseHessianAccumulator {
fn clone(&self) -> Self {
SparseHessianAccumulator {
sym: Arc::clone(&self.sym),
values: self.values.clone(),
}
}
}
impl SparseHessianAccumulator {
    /// Creates a zero-filled accumulator whose pattern is derived from a single
    /// row-major matrix. Shorthand for [`Self::from_multi_csr`] with one matrix.
    pub fn from_single_csr(csr: &SparseRowMat<usize, f64>, dim: usize) -> Self {
        Self::from_multi_csr(&[csr], dim)
    }

    /// Creates a zero-filled accumulator whose pattern is the one built by
    /// `SparseHessianSymbolic::build` from `csrs`, with `dim` columns.
    ///
    /// # Panics
    ///
    /// Panics if any stored column index is `>= dim`, or if `csrs` violates
    /// another precondition of `SparseHessianSymbolic::build`.
    pub fn from_multi_csr(csrs: &[&SparseRowMat<usize, f64>], dim: usize) -> Self {
        let sym = Arc::new(SparseHessianSymbolic::build(csrs, dim));
        let nnz = sym.nnz;
        SparseHessianAccumulator {
            sym,
            values: vec![0.0; nnz],
        }
    }

    /// Adds `val` to the upper-triangle entry at row `r`, column `c`.
    ///
    /// # Panics
    ///
    /// * if `r > c`;
    /// * if `c` is not a column of the pattern;
    /// * if `(r, c)` is not a structural entry of the pattern;
    /// * if `values` no longer has one slot per structural entry.
    #[inline(always)]
    pub fn add_upper(&mut self, r: usize, c: usize, val: f64) {
        assert!(r <= c, "add_upper requires r <= c, got ({r}, {c})");
        let s = &*self.sym;
        let start = s.col_ptrs[c];
        let end = s.col_ptrs[c + 1];

        let idx = if s.contiguous {
            // Column `c` stores rows first_row[c], first_row[c] + 1, ... so the
            // slot is a plain offset. Empty columns carry `usize::MAX` as their
            // first row, which makes `checked_sub` fail for any real `r`.
            match r.checked_sub(s.first_row[c]) {
                Some(offset) if offset < end - start => start + offset,
                _ => panic!("add_upper contiguous OOB"),
            }
        } else {
            // Rows within a column are strictly increasing, so a binary search
            // locates the slot without scanning the whole column.
            match s.row_indices[start..end].binary_search(&r) {
                Ok(offset) => start + offset,
                Err(_) => {
                    panic!("SparseHessianAccumulator::add_upper: ({r}, {c}) not in pattern")
                }
            }
        };

        // Checked indexing on purpose: `values` is `pub(crate)`, so its length
        // is not guaranteed by this type alone to still match the pattern.
        self.values[idx] += val;
    }

    /// Adds `other` element-wise onto the accumulated values.
    ///
    /// # Panics
    ///
    /// Panics if `other.len()` differs from the number of stored values.
    #[inline]
    pub fn add_values(&mut self, other: &[f64]) {
        assert_eq!(self.values.len(), other.len());
        for (a, &b) in self.values.iter_mut().zip(other.iter()) {
            *a += b;
        }
    }

    /// Returns an accumulator that shares this pattern but has all values reset
    /// to zero.
    pub fn empty_clone(&self) -> Self {
        SparseHessianAccumulator {
            sym: Arc::clone(&self.sym),
            values: vec![0.0; self.values.len()],
        }
    }

    /// Consumes the accumulator and returns it as a `dim x dim` faer
    /// compressed-column matrix holding the accumulated upper-triangle values.
    ///
    /// The pattern arrays are moved out when this is the last owner of the
    /// shared pattern and cloned otherwise.
    pub fn into_sparse_col_mat(self) -> SparseColMat<usize, f64> {
        use faer::sparse::SymbolicSparseColMat;

        let (col_ptrs, row_indices, dim) = match Arc::try_unwrap(self.sym) {
            Ok(owned) => (owned.col_ptrs, owned.row_indices, owned.dim),
            Err(shared) => (
                shared.col_ptrs.clone(),
                shared.row_indices.clone(),
                shared.dim,
            ),
        };

        // SAFETY: the arrays come unmodified from `SparseHessianSymbolic::build`,
        // which yields `dim + 1` non-decreasing column pointers running from 0
        // to `row_indices.len()`, and row indices that are `< dim` and strictly
        // increasing within each column.
        // NOTE(review): confirm these cover every invariant required by the
        // faer version in use.
        let symbolic =
            unsafe { SymbolicSparseColMat::new_unchecked(dim, dim, col_ptrs, None, row_indices) };
        SparseColMat::new(symbolic, self.values)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use faer::sparse::Triplet;

    /// A single dense 1x3 row must expand to the full upper triangle of a 3x3
    /// pattern, stored column by column.
    #[test]
    fn sparse_hessian_pattern_is_column_major_csc() {
        let jacobian_csc = SparseColMat::try_new_from_triplets(
            1,
            3,
            &[
                Triplet::new(0, 0, 1.0),
                Triplet::new(0, 1, 1.0),
                Triplet::new(0, 2, 1.0),
            ],
        )
        .expect("sparse column matrix");
        let jacobian_csr = jacobian_csc.to_row_major().expect("csr conversion");

        let acc = SparseHessianAccumulator::from_single_csr(&jacobian_csr, 3);
        let pattern = &acc.sym;

        // Column 0 holds row 0; column 1 rows 0..=1; column 2 rows 0..=2.
        assert_eq!(pattern.col_ptrs, [0, 1, 3, 6]);
        assert_eq!(pattern.row_indices, [0, 0, 1, 0, 1, 2]);
    }
}