use std::ops::Range;
use super::numeric::SupernodeInfo;
/// Decides whether a child supernode should be folded into its parent.
///
/// `parent_nelim` / `child_nelim` are the numbers of columns eliminated at
/// each node, and `parent_cc` / `child_cc` are the column counts supplied by
/// the caller (`amalgamate` passes `nelim + pattern.len()`).
///
/// Returns `true` when at least one rule applies:
/// * structural match: the parent eliminates exactly one column and
///   `parent_cc` equals `child_cc - 1` (the subtraction saturates at zero);
/// * small nodes: both nodes eliminate fewer than `nemin` columns.
fn do_merge(
    parent_nelim: usize,
    parent_cc: usize,
    child_nelim: usize,
    child_cc: usize,
    nemin: usize,
) -> bool {
    let structural_match = parent_nelim == 1 && child_cc.saturating_sub(1) == parent_cc;
    let both_small = parent_nelim < nemin && child_nelim < nemin;
    structural_match || both_small
}
/// Merges two index slices into a single vector, skipping every value that
/// falls inside `exclude_range`.
///
/// The slices are walked from the front, always emitting the smaller of the
/// two head values; when both heads are equal the value is emitted once and
/// both slices advance. Once one slice is exhausted the rest of the other is
/// appended in order. For ascending inputs without internal duplicates the
/// result is therefore their ascending union minus the excluded range.
fn sorted_union_excluding(a: &[usize], b: &[usize], exclude_range: Range<usize>) -> Vec<usize> {
    let mut merged = Vec::with_capacity(a.len() + b.len());
    let (mut left, mut right) = (a, b);

    loop {
        let candidate = match (left, right) {
            ([], []) => break,
            ([x, tail @ ..], []) => {
                left = tail;
                *x
            }
            ([], [y, tail @ ..]) => {
                right = tail;
                *y
            }
            ([x, left_tail @ ..], [y, right_tail @ ..]) => {
                // Advance whichever side holds the minimum; both on a tie.
                if x <= y {
                    left = left_tail;
                }
                if y <= x {
                    right = right_tail;
                }
                (*x).min(*y)
            }
        };

        if !exclude_range.contains(&candidate) {
            merged.push(candidate);
        }
    }

    merged
}
/// Merges child supernodes into their parents wherever `do_merge` approves,
/// then returns the surviving supernodes compacted into a new vector.
///
/// Nodes are visited in index order. For each node, its current children are
/// tested one by one; an approved child is absorbed into the parent:
/// * the parent's column span becomes the min/max hull of both spans,
/// * the patterns are unioned with everything inside that hull removed,
/// * the child's `owned_ranges` are appended to the parent's,
/// * the child's own children are re-parented onto the parent.
///
/// Survivors keep their relative order, and their `parent` indices are
/// rewritten to positions in the returned vector. Inputs with at most one
/// supernode are returned unchanged.
///
/// NOTE(review): this appears to assume children are stored before their
/// parents (postorder), so a node is fully processed before its parent looks
/// at it — confirm against the caller.
pub(crate) fn amalgamate(mut supernodes: Vec<SupernodeInfo>, nemin: usize) -> Vec<SupernodeInfo> {
    let count = supernodes.len();
    if count <= 1 {
        return supernodes;
    }

    let mut removed = vec![false; count];

    // Eliminated-column counts are tracked separately: after a merge the
    // min/max column hull can be wider than the number of columns absorbed.
    let mut nelim: Vec<usize> = supernodes
        .iter()
        .map(|node| node.col_end - node.col_begin)
        .collect();

    // Child lists derived from the `parent` links.
    let mut children: Vec<Vec<usize>> = vec![Vec::new(); count];
    for (idx, node) in supernodes.iter().enumerate() {
        if let Some(parent) = node.parent {
            children[parent].push(idx);
        }
    }

    for parent in 0..count {
        if removed[parent] {
            continue;
        }

        // Take the list so `children[parent]` can be rebuilt during merging.
        let candidates = std::mem::take(&mut children[parent]);
        for &child in &candidates {
            if removed[child] {
                continue;
            }

            // Re-read the parent's figures every time: earlier merges in this
            // loop change both its eliminated count and its pattern.
            let parent_cc = nelim[parent] + supernodes[parent].pattern.len();
            let child_cc = nelim[child] + supernodes[child].pattern.len();
            if !do_merge(nelim[parent], parent_cc, nelim[child], child_cc, nemin) {
                continue;
            }

            let col_begin = supernodes[parent]
                .col_begin
                .min(supernodes[child].col_begin);
            let col_end = supernodes[parent].col_end.max(supernodes[child].col_end);

            let from_child = std::mem::take(&mut supernodes[child].pattern);
            let from_parent = std::mem::take(&mut supernodes[parent].pattern);
            let child_ranges = std::mem::take(&mut supernodes[child].owned_ranges);

            let target = &mut supernodes[parent];
            target.pattern = sorted_union_excluding(&from_parent, &from_child, col_begin..col_end);
            target.owned_ranges.extend(child_ranges);
            target.col_begin = col_begin;
            target.col_end = col_end;
            nelim[parent] += nelim[child];

            // The absorbed child's children now hang off the parent.
            let adopted = std::mem::take(&mut children[child]);
            for &grandchild in adopted.iter().filter(|&&g| !removed[g]) {
                supernodes[grandchild].parent = Some(parent);
            }
            children[parent].extend(adopted);
            removed[child] = true;
        }

        // Children that were not absorbed stay attached to this parent.
        children[parent].extend(candidates.iter().copied().filter(|&c| !removed[c]));
    }

    // Map every surviving old index to its position in the compacted output.
    let mut old_to_new = vec![0usize; count];
    let mut survivors = 0;
    for (idx, &gone) in removed.iter().enumerate() {
        if !gone {
            old_to_new[idx] = survivors;
            survivors += 1;
        }
    }

    let mut compacted = Vec::with_capacity(survivors);
    for (mut node, gone) in supernodes.into_iter().zip(removed) {
        if gone {
            continue;
        }
        node.parent = node.parent.map(|p| old_to_new[p]);
        compacted.push(node);
    }
    compacted
}
#[cfg(test)]
mod tests {
    //! Unit tests for the merge predicate, the pattern union helper and the
    //! amalgamation pass itself.
    use super::*;

    /// Builds a supernode covering columns `col_begin..col_end` that owns
    /// exactly that one range and is not flagged as part of a small leaf.
    // A one-element `Vec<Range<_>>` is intended here (a single owned range),
    // so the clippy lint that suspects a mistyped `vec![a; b]` is silenced.
    #[allow(clippy::single_range_in_vec_init)]
    fn sn(
        col_begin: usize,
        col_end: usize,
        pattern: Vec<usize>,
        parent: Option<usize>,
    ) -> SupernodeInfo {
        SupernodeInfo {
            col_begin,
            col_end,
            pattern,
            parent,
            owned_ranges: vec![col_begin..col_end],
            in_small_leaf: false,
        }
    }

    /// The structural rule fires when the parent eliminates one column and
    /// `parent_cc == child_cc - 1`.
    #[test]
    fn test_do_merge_structural_match() {
        assert!(do_merge(1, 5, 2, 6, 32));
        // With nemin = 1 the "both small" rule can never fire, so these
        // assertions exercise the structural rule on its own.
        assert!(do_merge(1, 5, 2, 6, 1));
        assert!(!do_merge(1, 5, 2, 7, 1));
        assert!(!do_merge(2, 5, 2, 6, 1));
    }

    /// Both nodes below `nemin` merge regardless of structure.
    #[test]
    fn test_do_merge_nemin_both_small() {
        assert!(do_merge(4, 20, 8, 30, 32));
    }

    /// One node at or above `nemin` blocks the small-node rule.
    #[test]
    fn test_do_merge_one_large() {
        assert!(!do_merge(32, 50, 4, 20, 32));
        assert!(!do_merge(4, 20, 32, 50, 32));
    }

    /// Two large nodes without a structural match do not merge.
    #[test]
    fn test_do_merge_both_large() {
        assert!(!do_merge(40, 60, 35, 50, 32));
    }

    /// Disjoint inputs interleave into one ascending list.
    #[test]
    fn test_sorted_union_disjoint() {
        let result = sorted_union_excluding(&[1, 3, 5], &[2, 4, 6], 0..0);
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6]);
    }

    /// Values present in both inputs appear once.
    #[test]
    fn test_sorted_union_overlapping() {
        let result = sorted_union_excluding(&[1, 3, 5], &[3, 5, 7], 0..0);
        assert_eq!(result, vec![1, 3, 5, 7]);
    }

    /// Values inside the excluded range are dropped from either input.
    #[test]
    fn test_sorted_union_with_exclusion() {
        let result = sorted_union_excluding(&[5, 8, 10], &[3, 5, 7], 3..6);
        assert_eq!(result, vec![7, 8, 10]);
    }

    /// Four 40-column children under a 40-column root: nothing is small
    /// enough to merge with nemin = 32.
    #[test]
    fn test_no_merges_large_supernodes() {
        let supernodes = vec![
            sn(0, 40, vec![200, 201], Some(4)),
            sn(40, 80, vec![200, 202], Some(4)),
            sn(80, 120, vec![200, 203], Some(4)),
            sn(120, 160, vec![200, 204], Some(4)),
            sn(160, 200, vec![], None),
        ];
        let result = amalgamate(supernodes, 32);
        assert_eq!(
            result.len(),
            5,
            "no merges expected — all supernodes are large"
        );
    }

    /// A small child/parent pair collapses into a single root supernode.
    #[test]
    fn test_nemin_merge_simple_pair() {
        let supernodes = vec![sn(0, 4, vec![4, 5, 10], Some(1)), sn(4, 8, vec![10], None)];
        let result = amalgamate(supernodes, 32);
        assert_eq!(result.len(), 1, "pair should merge into one");
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 8);
        assert_eq!(result[0].pattern, vec![10]);
        assert!(result[0].parent.is_none());
    }

    /// Parent has nelim = 1 and cc = 5, child has cc = 6: structural match.
    #[test]
    fn test_structural_match_merge() {
        let supernodes = vec![
            sn(0, 3, vec![3, 10, 20], Some(1)),
            sn(3, 4, vec![10, 20, 30, 40], None),
        ];
        // nemin = 1 disables the small-node rule, so only the structural
        // rule can be responsible for this merge.
        let result = amalgamate(supernodes, 1);
        assert_eq!(result.len(), 1, "structural match should merge");
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 4);
        assert_eq!(result[0].pattern, vec![10, 20, 30, 40]);
    }

    /// A linear chain of small supernodes folds all the way up to the root.
    #[test]
    fn test_chain_merge() {
        let supernodes = vec![
            sn(0, 2, vec![2, 3, 4, 5, 6, 7, 8, 9], Some(1)),
            sn(2, 4, vec![4, 5, 6, 7, 8, 9], Some(2)),
            sn(4, 6, vec![6, 7, 8, 9], Some(3)),
            sn(6, 8, vec![8, 9], Some(4)),
            sn(8, 10, vec![], None),
        ];
        let result = amalgamate(supernodes, 32);
        assert_eq!(
            result.len(),
            1,
            "chain of 5 small supernodes should merge to 1"
        );
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 10);
        assert!(result[0].pattern.is_empty());
    }

    /// Several small siblings are all absorbed by their shared parent.
    #[test]
    fn test_bushy_tree_merge() {
        let supernodes = vec![
            sn(0, 2, vec![8, 9], Some(4)),
            sn(2, 4, vec![8, 9], Some(4)),
            sn(4, 6, vec![8, 9], Some(4)),
            sn(6, 8, vec![8, 9], Some(4)),
            sn(8, 10, vec![], None),
        ];
        let result = amalgamate(supernodes, 32);
        assert_eq!(
            result.len(),
            1,
            "all 4 small children should merge into parent"
        );
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 10);
    }

    /// Small siblings merge while a large sibling and the large root stay.
    #[test]
    fn test_partial_merge_mixed_sizes() {
        let supernodes = vec![
            sn(0, 4, vec![100, 140], Some(3)),
            sn(4, 8, vec![100, 140], Some(3)),
            sn(8, 48, vec![100, 140], Some(3)),
            sn(100, 104, vec![140], Some(4)),
            sn(140, 190, vec![], None),
        ];
        let result = amalgamate(supernodes, 32);
        assert_eq!(
            result.len(),
            3,
            "only 2 small children should merge, large child stays"
        );
        assert_eq!(result[0].col_end - result[0].col_begin, 40);
        // Survivors are old nodes 2, 3 and 4; their parent links must be
        // rewritten to the compacted indices 0, 1 and 2.
        assert_eq!(result[0].parent, Some(1));
        assert_eq!(result[1].parent, Some(2));
        assert!(result[2].parent.is_none());
    }

    /// Merging across two levels ends with a single supernode.
    #[test]
    fn test_parent_reparenting() {
        let supernodes = vec![
            sn(0, 2, vec![4, 5, 8, 9], Some(2)),
            sn(2, 4, vec![4, 5, 8, 9], Some(2)),
            sn(4, 6, vec![8, 9], Some(3)),
            sn(8, 10, vec![], None),
        ];
        let result = amalgamate(supernodes, 32);
        assert_eq!(result.len(), 1, "all small nodes should eventually merge");
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 10);
    }

    /// The merged pattern is the union of both patterns minus merged columns.
    #[test]
    fn test_pattern_union_on_merge() {
        let supernodes = vec![
            sn(0, 2, vec![2, 5, 10, 20], Some(1)),
            sn(2, 4, vec![5, 15, 20], None),
        ];
        let result = amalgamate(supernodes, 32);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].pattern, vec![5, 10, 15, 20]);
    }

    /// Every surviving node must still come before its parent.
    #[test]
    fn test_postorder_preserved() {
        let supernodes = vec![
            sn(0, 2, vec![4, 8], Some(2)),
            sn(2, 4, vec![4, 8], Some(2)),
            sn(4, 6, vec![8], Some(3)),
            sn(8, 12, vec![], None),
        ];
        let result = amalgamate(supernodes, 4);
        for (i, sn) in result.iter().enumerate() {
            if let Some(p) = sn.parent {
                assert!(
                    p > i,
                    "postorder violation: supernode {} has parent {} (should be > {})",
                    i,
                    p,
                    i
                );
            }
        }
    }

    /// A single supernode is returned untouched.
    #[test]
    fn test_single_supernode_passthrough() {
        let supernodes = vec![sn(0, 10, vec![], None)];
        let result = amalgamate(supernodes, 32);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 10);
    }

    /// A chain of 100 single-column supernodes shrinks substantially while
    /// keeping postorder, non-empty column spans and patterns that lie
    /// outside each node's own column span.
    #[test]
    fn test_simplicial_many_single_column_supernodes() {
        let n = 100;
        let supernodes: Vec<SupernodeInfo> = (0..n)
            .map(|i| {
                let pattern: Vec<usize> = ((i + 1)..n.min(i + 6)).collect();
                let parent = if i + 1 < n { Some(i + 1) } else { None };
                sn(i, i + 1, pattern, parent)
            })
            .collect();
        let result = amalgamate(supernodes, 32);
        assert!(
            result.len() < 20,
            "expected significant reduction from 100, got {}",
            result.len()
        );
        for (i, sn) in result.iter().enumerate() {
            if let Some(p) = sn.parent {
                assert!(p > i, "postorder violation at {}: parent={}", i, p);
            }
        }
        for (i, sn) in result.iter().enumerate() {
            assert!(
                sn.col_begin < sn.col_end,
                "empty supernode at {}: [{}, {})",
                i,
                sn.col_begin,
                sn.col_end
            );
        }
        for sn in &result {
            for &r in &sn.pattern {
                assert!(
                    r < sn.col_begin || r >= sn.col_end,
                    "pattern entry {} is within [{}, {})",
                    r,
                    sn.col_begin,
                    sn.col_end
                );
            }
        }
    }

    /// Twenty small children under one root: some merge, postorder holds.
    #[test]
    fn test_star_tree_many_children() {
        let mut supernodes: Vec<SupernodeInfo> = (0..20)
            .map(|i| sn(i * 2, i * 2 + 2, vec![40, 41, 42, 43], Some(20)))
            .collect();
        supernodes.push(sn(40, 44, vec![], None));
        let result = amalgamate(supernodes, 32);
        assert!(
            result.len() < 21,
            "at least some children should merge, got {} supernodes",
            result.len()
        );
        for (i, sn) in result.iter().enumerate() {
            if let Some(p) = sn.parent {
                assert!(p > i, "postorder violation at {}: parent={}", i, p);
            }
        }
    }

    /// With nemin = 1 no node counts as small, and this pair has no
    /// structural match, so nothing merges.
    #[test]
    fn test_nemin_1_disables_amalgamation() {
        let supernodes = vec![sn(0, 4, vec![4, 5, 6], Some(1)), sn(4, 8, vec![5, 6], None)];
        let result = amalgamate(supernodes, 1);
        assert_eq!(
            result.len(),
            2,
            "nemin=1 should disable amalgamation: got {} supernodes",
            result.len()
        );
        assert_eq!(result[0].col_begin, 0);
        assert_eq!(result[0].col_end, 4);
        assert_eq!(result[1].col_begin, 4);
        assert_eq!(result[1].col_end, 8);
    }

    /// Raising nemin from 32 to 64 lets two 40-column supernodes merge.
    #[test]
    fn test_nemin_64_more_aggressive() {
        let supernodes = vec![
            sn(0, 40, vec![40, 50, 60], Some(1)),
            sn(40, 80, vec![50, 60], None),
        ];
        let result_32 = amalgamate(supernodes.clone(), 32);
        assert_eq!(
            result_32.len(),
            2,
            "nemin=32 should NOT merge supernodes with nelim=40: got {} supernodes",
            result_32.len()
        );
        let result_64 = amalgamate(supernodes, 64);
        assert_eq!(
            result_64.len(),
            1,
            "nemin=64 should merge supernodes with nelim=40: got {} supernodes",
            result_64.len()
        );
        assert_eq!(result_64[0].col_begin, 0);
        assert_eq!(result_64[0].col_end, 80);
    }
}