use crate::sparse::csc::CscPattern;
/// Grouping of the original indices `0..n` into compressed "super" indices.
///
/// Super index `s < pairs.len()` stands for the two original indices in
/// `pairs[s]`; super index `pairs.len() + k` stands for the single original
/// index `singletons[k]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SuperMap {
    /// `icmp[i]` is the super index that original index `i` belongs to.
    pub icmp: Vec<usize>,
    /// Pairs of original indices that share one super index
    /// (`build_supermap` stores each pair as `(smaller, larger)`).
    pub pairs: Vec<(usize, usize)>,
    /// Original indices that form a super index on their own.
    pub singletons: Vec<usize>,
}
impl SuperMap {
    /// Number of super indices: one per pair plus one per singleton.
    pub fn ncmp(&self) -> usize {
        self.pairs.len() + self.singletons.len()
    }
}
/// Groups the indices `0..perm.len()` into super indices by walking `perm`.
///
/// `perm[i] == usize::MAX` marks `i` as unmatched; such an index becomes a
/// singleton. Otherwise the walk follows `i -> perm[i] -> ...` until it returns
/// to its starting index, reaches an unmatched index, or hits an index already
/// claimed by an earlier walk. Consecutive indices along each walk are paired
/// (stored as `(smaller, larger)`); an odd trailing index becomes a singleton.
///
/// In the returned `icmp`, pairs are numbered `0..pairs.len()` in discovery
/// order, followed by the singletons in discovery order.
///
/// # Panics
///
/// Panics if an entry of `perm` is neither `usize::MAX` nor `< perm.len()`.
pub fn build_supermap(perm: &[usize]) -> SuperMap {
    let n = perm.len();
    let mut seen = vec![false; n];
    let mut pairs: Vec<(usize, usize)> = Vec::new();
    let mut singletons: Vec<usize> = Vec::new();
    // Reused across walks to avoid one allocation per cycle.
    let mut walk: Vec<usize> = Vec::new();

    for origin in 0..n {
        if seen[origin] {
            continue;
        }
        if perm[origin] == usize::MAX {
            seen[origin] = true;
            singletons.push(origin);
            continue;
        }

        walk.clear();
        let mut cur = origin;
        while !seen[cur] {
            seen[cur] = true;
            walk.push(cur);
            let succ = perm[cur];
            if succ == usize::MAX || succ == origin {
                break;
            }
            cur = succ;
        }

        // Pair up neighbours along the walk; an odd-length walk leaves its
        // last index alone.
        for chunk in walk.chunks(2) {
            match *chunk {
                [x, y] => pairs.push((x.min(y), x.max(y))),
                [lone] => singletons.push(lone),
                _ => unreachable!("chunks(2) yields slices of length 1 or 2"),
            }
        }
    }

    let n_pairs = pairs.len();
    let mut icmp = vec![usize::MAX; n];
    for (sid, &(a, b)) in pairs.iter().enumerate() {
        icmp[a] = sid;
        icmp[b] = sid;
    }
    for (offset, &s) in singletons.iter().enumerate() {
        icmp[s] = n_pairs + offset;
    }

    SuperMap {
        icmp,
        pairs,
        singletons,
    }
}
/// Contracts a sparsity pattern onto the super indices of `map`.
///
/// Rows and columns of `pat` are both treated as original indices. Column `sc`
/// of the result is the union of the columns of every original index grouped
/// into super index `sc`, with each row index `r` replaced by `map.icmp[r]`.
/// Entries that land on the diagonal (`sr == sc`) are dropped, duplicates are
/// merged, and every output column is sorted ascending.
///
/// # Panics
///
/// Panics if `pat.n != map.icmp.len()`, i.e. the pattern and the map do not
/// describe the same number of original indices.
pub fn compress_pattern(pat: &CscPattern, map: &SuperMap) -> CscPattern {
    // Make the size contract explicit instead of failing later with an opaque
    // index-out-of-bounds (or silently ignoring trailing columns).
    assert_eq!(
        pat.n,
        map.icmp.len(),
        "compress_pattern: pattern has {} columns but the super map covers {} indices",
        pat.n,
        map.icmp.len()
    );
    let ncmp = map.ncmp();
    if ncmp == 0 {
        return CscPattern {
            n: 0,
            col_ptr: vec![0],
            row_idx: Vec::new(),
        };
    }

    // Original columns belonging to each super column, in ascending order.
    let mut super_cols: Vec<Vec<usize>> = vec![Vec::new(); ncmp];
    for (orig, &sid) in map.icmp.iter().enumerate() {
        // Ids outside 0..ncmp cannot come from `build_supermap`; skip them.
        if sid < ncmp {
            super_cols[sid].push(orig);
        }
    }

    // `mark[sr] == tag` means super row `sr` was already emitted for the
    // current super column; bumping `tag` per column avoids clearing `mark`.
    let mut mark: Vec<u32> = vec![0; ncmp];
    let mut tag: u32 = 0;
    let mut col_ptr: Vec<usize> = Vec::with_capacity(ncmp + 1);
    col_ptr.push(0);
    let mut row_idx: Vec<usize> = Vec::new();
    let mut col_buf: Vec<usize> = Vec::new();

    for (sc, originals) in super_cols.iter().enumerate() {
        tag = tag.wrapping_add(1);
        if tag == 0 {
            // u32 wrapped: stale stamps could alias the new tag, so reset them.
            mark.fill(0);
            tag = 1;
        }
        col_buf.clear();
        for &orig_c in originals {
            let rows = &pat.row_idx[pat.col_ptr[orig_c]..pat.col_ptr[orig_c + 1]];
            for &orig_r in rows {
                let sr = map.icmp[orig_r];
                // Entries between members of the same super index collapse
                // onto the diagonal and are dropped.
                if sr != sc && mark[sr] != tag {
                    mark[sr] = tag;
                    col_buf.push(sr);
                }
            }
        }
        col_buf.sort_unstable();
        row_idx.extend_from_slice(&col_buf);
        col_ptr.push(row_idx.len());
    }

    CscPattern {
        n: ncmp,
        col_ptr,
        row_idx,
    }
}
/// Expands an ordering of super indices into an ordering of original indices.
///
/// Each super id in `super_perm` is replaced by its members in `map`: a pair id
/// yields both original indices of `map.pairs[id]` (in stored order), and a
/// singleton id yields its one original index. The relative order given by
/// `super_perm` is preserved.
///
/// # Panics
///
/// Panics if an entry of `super_perm` is `>= map.ncmp()`.
pub fn expand_permutation(super_perm: &[usize], map: &SuperMap) -> Vec<usize> {
    let first_singleton = map.pairs.len();
    let mut expanded = Vec::with_capacity(map.icmp.len());
    for &sid in super_perm {
        match map.pairs.get(sid) {
            Some(&(lo, hi)) => expanded.extend_from_slice(&[lo, hi]),
            None => expanded.push(map.singletons[sid - first_singleton]),
        }
    }
    expanded
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn supermap_all_singletons_identity_perm() {
        let perm: Vec<usize> = (0..5).collect();
        let map = build_supermap(&perm);
        assert!(map.pairs.is_empty());
        assert_eq!(map.singletons, vec![0, 1, 2, 3, 4]);
        assert_eq!(map.icmp, vec![0, 1, 2, 3, 4]);
        assert_eq!(map.ncmp(), 5);
    }
    #[test]
    fn supermap_empty_perm() {
        let perm: Vec<usize> = Vec::new();
        let map = build_supermap(&perm);
        assert!(map.pairs.is_empty());
        assert!(map.singletons.is_empty());
        assert!(map.icmp.is_empty());
        assert_eq!(map.ncmp(), 0);
    }
    #[test]
    fn supermap_two_2cycles() {
        let perm = vec![2, 3, 0, 1];
        let map = build_supermap(&perm);
        assert_eq!(map.pairs, vec![(0, 2), (1, 3)]);
        assert!(map.singletons.is_empty());
        assert_eq!(map.icmp, vec![0, 1, 0, 1]);
        assert_eq!(map.ncmp(), 2);
    }
    #[test]
    fn supermap_three_cycle_plus_singletons() {
        let perm = vec![1, 2, 0, 3, 4, 5];
        let map = build_supermap(&perm);
        assert_eq!(map.pairs, vec![(0, 1)]);
        assert_eq!(map.singletons, vec![2, 3, 4, 5]);
        assert_eq!(map.icmp, vec![0, 0, 1, 2, 3, 4]);
        assert_eq!(map.ncmp(), 5);
    }
    #[test]
    fn supermap_five_cycle_leaves_odd_tail_as_singleton() {
        // 0 -> 1 -> 2 -> 3 -> 4 -> 0: neighbours are paired, 4 is left over.
        let perm = vec![1, 2, 3, 4, 0];
        let map = build_supermap(&perm);
        assert_eq!(map.pairs, vec![(0, 1), (2, 3)]);
        assert_eq!(map.singletons, vec![4]);
        assert_eq!(map.icmp, vec![0, 0, 1, 1, 2]);
        assert_eq!(map.ncmp(), 3);
    }
    #[test]
    fn supermap_unmatched_is_singleton() {
        let perm = vec![1, 0, usize::MAX];
        let map = build_supermap(&perm);
        assert_eq!(map.pairs, vec![(0, 1)]);
        assert_eq!(map.singletons, vec![2]);
        assert_eq!(map.icmp, vec![0, 0, 1]);
        assert_eq!(map.ncmp(), 2);
    }
    #[test]
    fn supermap_dangling_chain_stops_at_visited_index() {
        // Chain 2 -> 0 -> 1 -> unmatched. The walk from 0 claims {0, 1};
        // the later walk from 2 stops at the already-claimed 0.
        let perm = vec![1, usize::MAX, 0];
        let map = build_supermap(&perm);
        assert_eq!(map.pairs, vec![(0, 1)]);
        assert_eq!(map.singletons, vec![2]);
        assert_eq!(map.icmp, vec![0, 0, 1]);
    }
    #[test]
    fn expand_is_identity_when_super_perm_is_iota() {
        let perm = vec![2, 3, 0, 1];
        let map = build_supermap(&perm);
        let super_iota: Vec<usize> = (0..map.ncmp()).collect();
        let expanded = expand_permutation(&super_iota, &map);
        assert_eq!(expanded, vec![0, 2, 1, 3]);
        let mut sorted = expanded.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2, 3]);
    }
    #[test]
    fn expand_follows_super_perm_order() {
        let perm = vec![1, 2, 0, 3, 4, 5];
        let map = build_supermap(&perm);
        // pairs = [(0, 1)], singletons = [2, 3, 4, 5]
        let expanded = expand_permutation(&[4, 0, 2, 1, 3], &map);
        assert_eq!(expanded, vec![5, 0, 1, 3, 2, 4]);
    }
    #[test]
    fn expand_is_bijection_with_3cycle() {
        let perm = vec![1, 2, 0, 3, 4, 5];
        let map = build_supermap(&perm);
        let super_iota: Vec<usize> = (0..map.ncmp()).collect();
        let expanded = expand_permutation(&super_iota, &map);
        assert_eq!(expanded.len(), 6);
        let mut sorted = expanded.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1, 2, 3, 4, 5]);
    }
    /// Builds a symmetric CSC pattern of size `n` with a full diagonal plus both
    /// orientations of every edge in `edges`; columns are sorted and deduplicated.
    fn build_full_pattern(n: usize, edges: &[(usize, usize)]) -> CscPattern {
        let mut cols: Vec<Vec<usize>> = vec![Vec::new(); n];
        for (j, col) in cols.iter_mut().enumerate() {
            col.push(j);
        }
        for &(i, j) in edges {
            cols[i].push(j);
            cols[j].push(i);
        }
        let mut col_ptr = Vec::with_capacity(n + 1);
        let mut row_idx = Vec::new();
        col_ptr.push(0);
        for col in &mut cols {
            col.sort_unstable();
            col.dedup();
            row_idx.extend_from_slice(col);
            col_ptr.push(row_idx.len());
        }
        CscPattern {
            n,
            col_ptr,
            row_idx,
        }
    }
    #[test]
    fn compress_contracts_paired_columns_and_drops_selfloops() {
        let pat = build_full_pattern(4, &[(0, 2), (1, 3), (0, 1)]);
        let map = build_supermap(&[2, 3, 0, 1]);
        let cpat = compress_pattern(&pat, &map);
        assert_eq!(cpat.n, 2);
        assert_eq!(cpat.col_ptr, vec![0, 1, 2]);
        assert_eq!(cpat.row_idx, vec![1, 0]);
    }
    #[test]
    fn compress_dedups_parallel_edges() {
        let pat = build_full_pattern(4, &[(0, 1), (2, 3)]);
        let map = build_supermap(&[2, 3, 0, 1]);
        let cpat = compress_pattern(&pat, &map);
        assert_eq!(cpat.n, 2);
        assert_eq!(cpat.col_ptr, vec![0, 1, 2]);
        assert_eq!(cpat.row_idx, vec![1, 0]);
    }
    #[test]
    fn compress_empty_pattern() {
        let pat = build_full_pattern(0, &[]);
        let perm: Vec<usize> = Vec::new();
        let map = build_supermap(&perm);
        let cpat = compress_pattern(&pat, &map);
        assert_eq!(cpat.n, 0);
        assert_eq!(cpat.col_ptr, vec![0]);
        assert!(cpat.row_idx.is_empty());
    }
    #[test]
    fn compress_preserves_symmetry() {
        let pat = build_full_pattern(6, &[(0, 1), (2, 3), (0, 4), (2, 5), (1, 5)]);
        let map = build_supermap(&[1, 0, 3, 2, 4, 5]);
        let cpat = compress_pattern(&pat, &map);
        // Super ids: {0,1} -> 0, {2,3} -> 1, {4} -> 2, {5} -> 3.
        assert_eq!(cpat.n, 4);
        assert_eq!(cpat.col_ptr, vec![0, 2, 3, 4, 6]);
        assert_eq!(cpat.row_idx, vec![2, 3, 3, 0, 0, 1]);
        let mut edges = std::collections::HashSet::new();
        for c in 0..cpat.n {
            for k in cpat.col_ptr[c]..cpat.col_ptr[c + 1] {
                edges.insert((cpat.row_idx[k], c));
            }
        }
        for &(r, c) in &edges {
            assert!(
                edges.contains(&(c, r)),
                "edge ({}, {}) present but ({}, {}) not — asymmetry",
                r,
                c,
                c,
                r
            );
        }
        for c in 0..cpat.n {
            let s = cpat.col_ptr[c];
            let e = cpat.col_ptr[c + 1];
            let col = &cpat.row_idx[s..e];
            for w in col.windows(2) {
                assert!(w[0] < w[1], "col {} not strictly sorted: {:?}", c, col);
            }
        }
    }
}