use crate::sparse::csc::CscPattern;
use crate::symbolic::supernode::Supernode;
/// Thresholds that decide which leaf supernodes are "small" and how many of
/// them may share one group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmallLeafParams {
    /// Largest `nrow` a leaf supernode may have and still qualify.
    pub nrow_max: usize,
    /// Largest `ncol` a leaf supernode may have and still qualify.
    pub ncol_max: usize,
    /// Upper bound on the summed per-leaf arena sizes inside one group.
    ///
    /// A group is closed before adding a leaf that would push the sum past
    /// this value; a single leaf larger than the budget still gets a group
    /// of its own.
    pub arena_budget: usize,
}
impl Default for SmallLeafParams {
fn default() -> Self {
Self {
nrow_max: 16,
ncol_max: 8,
arena_budget: 4096,
}
}
}
/// A run of consecutive small leaf supernodes that share one arena.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SmallLeafGroup {
    /// Indices (into the supernode slice) of the group's members, in the
    /// order they were encountered.
    pub members: Vec<usize>,
    /// Row list of each member, parallel to `members`.
    pub member_rows: Vec<Vec<usize>>,
    /// Sum over all members of `rows.len() * rows.len()`.
    pub arena_size: usize,
    /// Prefix sums of the per-member sizes: starts with `0`, has
    /// `members.len() + 1` entries, and ends with `arena_size`.
    pub offsets: Vec<usize>,
}
/// Packs runs of consecutive small leaf supernodes into groups.
///
/// A supernode qualifies when it has no children, `0 < nrow <= params.nrow_max`
/// and `ncol <= params.ncol_max`. Qualifying supernodes are appended to the
/// currently open group; the group is closed when
///
/// * a non-qualifying supernode is encountered, or
/// * adding the next leaf would push the group's `arena_size` past
///   `params.arena_budget`.
///
/// A leaf whose own size exceeds the budget is still placed in a (singleton
/// or fresh) group, because the budget is only checked against an already
/// open group.
///
/// # Returns
///
/// `(groups, snode_group)` where `snode_group[i]` is `Some(g)` when supernode
/// `i` is a member of `groups[g]`, and `None` otherwise. `snode_group` has one
/// entry per input supernode.
///
/// # Panics
///
/// Indexing into `permuted_pattern` is unchecked beyond normal slice bounds
/// checks, so an inconsistent pattern/supernode pair may panic.
pub fn find_small_leaf_groups(
    supernodes: &[Supernode],
    permuted_pattern: &CscPattern,
    params: &SmallLeafParams,
) -> (Vec<SmallLeafGroup>, Vec<Option<usize>>) {
    let mut groups: Vec<SmallLeafGroup> = Vec::new();
    let mut snode_group: Vec<Option<usize>> = vec![None; supernodes.len()];
    // Scratch buffers reused across every call to `compute_leaf_rows`.
    let mut seen: Vec<bool> = vec![false; permuted_pattern.n];
    let mut trailing: Vec<usize> = Vec::new();
    // The group currently being filled, if any.
    let mut current: Option<SmallLeafGroup> = None;

    // Finalises the open group (if any): records its index for every member
    // and moves it into `groups`.
    let flush = |current: &mut Option<SmallLeafGroup>,
                 groups: &mut Vec<SmallLeafGroup>,
                 snode_group: &mut [Option<usize>]| {
        if let Some(g) = current.take() {
            let gid = groups.len();
            for &m in &g.members {
                snode_group[m] = Some(gid);
            }
            groups.push(g);
        }
    };

    for (idx, snode) in supernodes.iter().enumerate() {
        let qualifies = snode.children.is_empty()
            && snode.ncol <= params.ncol_max
            && snode.nrow <= params.nrow_max
            && snode.nrow > 0;
        if !qualifies {
            // Groups only contain consecutive supernodes, so anything that
            // does not qualify terminates the open group.
            flush(&mut current, &mut groups, &mut snode_group);
            continue;
        }

        let rows = compute_leaf_rows(snode, permuted_pattern, &mut seen, &mut trailing);
        let leaf_size = rows.len() * rows.len();

        // Only an already open group can overflow the budget; a fresh group
        // always accepts its first leaf.
        let must_close = match &current {
            Some(g) => g.arena_size + leaf_size > params.arena_budget,
            None => false,
        };
        if must_close {
            flush(&mut current, &mut groups, &mut snode_group);
        }

        let g = current.get_or_insert_with(|| SmallLeafGroup {
            members: Vec::new(),
            member_rows: Vec::new(),
            arena_size: 0,
            // Leading sentinel so `offsets` is a proper prefix-sum array.
            offsets: vec![0],
        });
        g.members.push(idx);
        g.member_rows.push(rows);
        g.arena_size += leaf_size;
        g.offsets.push(g.arena_size);
    }

    flush(&mut current, &mut groups, &mut snode_group);
    (groups, snode_group)
}
/// Builds the row list of a leaf supernode.
///
/// The result starts with the supernode's own column indices
/// (`first_col..first_col + ncol`) and continues with every other row index
/// found in those columns of `pattern`, in ascending order and without
/// duplicates.
///
/// `seen` and `trailing` are caller-provided scratch buffers. `trailing` is
/// cleared on entry and holds the sorted extra rows on return. Every `seen`
/// entry this function sets is reset to `false` before returning, so a
/// buffer that is all-`false` on entry is all-`false` on exit.
fn compute_leaf_rows(
    snode: &Supernode,
    pattern: &CscPattern,
    seen: &mut [bool],
    trailing: &mut Vec<usize>,
) -> Vec<usize> {
    let col_begin = snode.first_col;
    let width = snode.ncol;
    let col_end = col_begin + width;

    // Mark the supernode's own columns so they are not collected again as
    // trailing rows.
    seen.iter_mut()
        .skip(col_begin)
        .take(width)
        .for_each(|flag| *flag = true);

    trailing.clear();
    for col in col_begin..col_end {
        let (lo, hi) = (pattern.col_ptr[col], pattern.col_ptr[col + 1]);
        for row in (lo..hi).map(|k| pattern.row_idx[k]) {
            // `replace` hands back the previous flag: record the row only the
            // first time it is encountered.
            if !std::mem::replace(&mut seen[row], true) {
                trailing.push(row);
            }
        }
    }
    trailing.sort_unstable();

    let rows: Vec<usize> = (col_begin..col_end)
        .chain(trailing.iter().copied())
        .collect();

    // Undo every mark made above so the scratch buffer can be reused.
    trailing.iter().for_each(|&row| seen[row] = false);
    seen.iter_mut()
        .skip(col_begin)
        .take(width)
        .for_each(|flag| *flag = false);

    rows
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a childless supernode with the given geometry.
    fn mk_leaf(first_col: usize, ncol: usize, nrow: usize) -> Supernode {
        Supernode {
            first_col,
            ncol,
            nrow,
            row_indices: Vec::new(),
            children: Vec::new(),
        }
    }

    // Builds a supernode with one (dummy) child so it never counts as a leaf.
    fn mk_nonleaf(first_col: usize, ncol: usize, nrow: usize) -> Supernode {
        Supernode {
            first_col,
            ncol,
            nrow,
            row_indices: Vec::new(),
            children: vec![0],
        }
    }

    // Pattern of an `n x n` diagonal matrix: column `j` holds only row `j`.
    fn diag_pattern(n: usize) -> CscPattern {
        let col_ptr: Vec<usize> = (0..=n).collect();
        let row_idx: Vec<usize> = (0..n).collect();
        CscPattern {
            n,
            col_ptr,
            row_idx,
        }
    }

    #[test]
    fn empty_input_yields_empty_output() {
        let pat = diag_pattern(0);
        let (g, m) = find_small_leaf_groups(&[], &pat, &SmallLeafParams::default());
        assert!(g.is_empty());
        assert!(m.is_empty());
    }

    #[test]
    fn all_small_leaves_pack_into_one_group() {
        let snodes = vec![
            mk_leaf(0, 2, 2),
            mk_leaf(2, 2, 2),
            mk_leaf(4, 2, 2),
            mk_leaf(6, 1, 1),
            mk_leaf(7, 2, 2),
        ];
        let pat = diag_pattern(9);
        let (groups, snode_group) =
            find_small_leaf_groups(&snodes, &pat, &SmallLeafParams::default());
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].members, vec![0, 1, 2, 3, 4]);
        let expected_arena = 4 + 4 + 4 + 1 + 4;
        assert_eq!(groups[0].arena_size, expected_arena);
        assert_eq!(groups[0].offsets.len(), 6);
        assert_eq!(groups[0].offsets.first(), Some(&0));
        assert_eq!(groups[0].offsets.last(), Some(&expected_arena));
        // Every supernode belongs to the single group.
        assert_eq!(snode_group, vec![Some(0); 5]);
        assert_eq!(groups[0].member_rows[0], vec![0, 1]);
        assert_eq!(groups[0].member_rows[3], vec![6]);
    }

    #[test]
    fn non_leaf_breaks_group() {
        let snodes = vec![
            mk_leaf(0, 2, 2),
            mk_leaf(2, 2, 2),
            mk_nonleaf(4, 3, 3),
            mk_leaf(7, 2, 2),
            mk_leaf(9, 2, 2),
        ];
        let pat = diag_pattern(11);
        let (groups, snode_group) =
            find_small_leaf_groups(&snodes, &pat, &SmallLeafParams::default());
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].members, vec![0, 1]);
        assert_eq!(groups[1].members, vec![3, 4]);
        assert_eq!(snode_group[2], None);
    }

    #[test]
    fn oversize_leaf_breaks_group() {
        // The middle leaf has 9 columns, above the default `ncol_max` of 8.
        let snodes = vec![mk_leaf(0, 2, 2), mk_leaf(2, 9, 9), mk_leaf(11, 2, 2)];
        let pat = diag_pattern(13);
        let (groups, snode_group) =
            find_small_leaf_groups(&snodes, &pat, &SmallLeafParams::default());
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].members, vec![0]);
        assert_eq!(groups[1].members, vec![2]);
        assert_eq!(snode_group[1], None);
    }

    #[test]
    fn budget_forces_split() {
        // Each leaf needs 4 entries, so a budget of 5 fits exactly one leaf.
        let params = SmallLeafParams {
            nrow_max: 16,
            ncol_max: 8,
            arena_budget: 5,
        };
        let snodes = vec![mk_leaf(0, 2, 2), mk_leaf(2, 2, 2), mk_leaf(4, 2, 2)];
        let pat = diag_pattern(6);
        let (groups, snode_group) = find_small_leaf_groups(&snodes, &pat, &params);
        assert_eq!(groups.len(), 3);
        assert_eq!(snode_group[0], Some(0));
        assert_eq!(snode_group[1], Some(1));
        assert_eq!(snode_group[2], Some(2));
    }

    #[test]
    fn zero_nrow_is_skipped() {
        let snodes = vec![mk_leaf(0, 0, 0), mk_leaf(0, 2, 2)];
        let pat = diag_pattern(2);
        let (groups, snode_group) =
            find_small_leaf_groups(&snodes, &pat, &SmallLeafParams::default());
        assert_eq!(snode_group[0], None);
        assert_eq!(snode_group[1], Some(0));
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].members, vec![1]);
    }

    #[test]
    fn offsets_are_prefix_sums() {
        let snodes = vec![mk_leaf(0, 2, 2), mk_leaf(2, 3, 3), mk_leaf(5, 4, 4)];
        let pat = diag_pattern(9);
        let (groups, _) = find_small_leaf_groups(&snodes, &pat, &SmallLeafParams::default());
        assert_eq!(groups.len(), 1);
        let g = &groups[0];
        assert_eq!(g.offsets, vec![0, 4, 4 + 9, 4 + 9 + 16]);
        assert_eq!(g.arena_size, *g.offsets.last().unwrap());
        for w in g.offsets.windows(2) {
            assert!(w[1] > w[0]);
        }
    }

    #[test]
    fn compute_leaf_rows_with_offdiagonal_nonzeros() {
        // Column 2 holds rows {2, 7}; column 3 holds rows {3, 5, 7}.
        let col_ptr = vec![0, 0, 0, 2, 5, 5, 6, 6, 7];
        let row_idx = vec![2, 7, 3, 5, 7, 5, 7];
        let pat = CscPattern {
            n: 8,
            col_ptr,
            row_idx,
        };
        let leaf = mk_leaf(2, 2, 2);
        let mut seen = vec![false; 8];
        let mut trailing = Vec::new();
        let rows = compute_leaf_rows(&leaf, &pat, &mut seen, &mut trailing);
        assert_eq!(rows, vec![2, 3, 5, 7]);
        assert!(seen.iter().all(|&b| !b), "seen invariant restored");
    }
}