use super::dataflow_fixpoint::reachability_closure_into;
/// Returns the fusable-pairs mask for the row-major `n x n` matrix `adj`.
///
/// Convenience wrapper over [`fusable_pairs_with_scratch_into`] that allocates every
/// buffer (closure scratch, iteration scratch and the output) fresh for this call.
/// The returned vector is empty when that function rejects the input shape.
#[must_use]
pub fn fusable_pairs(adj: &[u32], n: u32, max_iters: u32) -> Vec<u32> {
    let (mut closure, mut next, mut mask) = (Vec::new(), Vec::new(), Vec::new());
    fusable_pairs_with_scratch_into(adj, n, max_iters, &mut closure, &mut next, &mut mask);
    mask
}
/// Writes the fusable-pairs mask for `adj` into `out`.
///
/// `out` is handed to [`fusable_pairs_with_scratch_into`], which clears and refills it in
/// place. The two internal scratch buffers are allocated per call. Use the scratch variant
/// directly to reuse those as well.
pub fn fusable_pairs_into(adj: &[u32], n: u32, max_iters: u32, out: &mut Vec<u32>) {
    let mut closure_scratch = Vec::new();
    let mut next_scratch = Vec::new();
    fusable_pairs_with_scratch_into(
        adj,
        n,
        max_iters,
        &mut closure_scratch,
        &mut next_scratch,
        out,
    );
}
/// Computes the symmetric "fusable pairs" mask for the row-major `n x n` matrix `adj`,
/// reusing caller-provided buffers.
///
/// [`reachability_closure_into`] is run on `adj` (with at least one iteration) and writes
/// into `closure`, with `next` passed along as scratch. Presumably `closure` then holds
/// the reachability (transitive) closure as another `n x n` matrix. TODO confirm against
/// that helper.
///
/// On success `out` has exactly `n * n` entries. `out[i * n + j] == 1` iff `i != j` and
/// both `closure[i * n + j]` and `closure[j * n + i]` are `0`. Every other entry is `0`.
/// The mask is therefore symmetric with a zero diagonal.
///
/// If `n == 0`, `n * n` overflows `u32`, or `adj.len() != n * n`, then `closure`, `next`
/// and `out` are all cleared and nothing else is computed.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `reachability_closure_into` leaves `closure`
/// shorter than `n * n` entries.
pub fn fusable_pairs_with_scratch_into(
    adj: &[u32],
    n: u32,
    max_iters: u32,
    closure: &mut Vec<u32>,
    next: &mut Vec<u32>,
    out: &mut Vec<u32>,
) {
    // One guard for every degenerate shape: an empty graph, `n * n` overflowing `u32`,
    // or an adjacency slice of the wrong length. All of them leave every buffer empty.
    let cells = match n.checked_mul(n) {
        Some(c) if n != 0 && c as usize == adj.len() => c as usize,
        _ => {
            closure.clear();
            next.clear();
            out.clear();
            return;
        }
    };

    // `max(1)` guarantees the closure helper is asked for at least one iteration.
    reachability_closure_into(adj, n, max_iters.max(1), closure, next);

    let n = n as usize;
    out.clear();
    out.resize(cells, 0);
    // The fusability test is symmetric in (i, j), so evaluate each unordered pair once and
    // mirror the result. The diagonal (i == j) is never marked.
    for i in 0..n {
        for j in (i + 1)..n {
            let (ij, ji) = (i * n + j, j * n + i);
            if closure[ij] == 0 && closure[ji] == 0 {
                out[ij] = 1;
                out[ji] = 1;
            }
        }
    }
}
/// Returns the number of set entries in the [`fusable_pairs`] mask for `adj`.
///
/// That mask marks both `(i, j)` and `(j, i)` for a fusable pair, so this counts ordered
/// pairs. For example, three mutually independent nodes score `6`.
#[must_use]
pub fn fusion_score(adj: &[u32], n: u32, max_iters: u32) -> u32 {
    fusable_pairs(adj, n, max_iters).into_iter().sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads entry `(row, col)` of a row-major `n x n` mask.
    fn at(mask: &[u32], n: usize, row: usize, col: usize) -> u32 {
        mask[row * n + col]
    }

    #[test]
    fn independent_passes_are_fusable() {
        #[rustfmt::skip]
        let graph = vec![
            0, 0, 0,
            0, 0, 0,
            0, 0, 0,
        ];
        let mask = fusable_pairs(&graph, 3, 3);
        for (row, col) in [(0, 1), (1, 0), (0, 2), (2, 0)] {
            assert_eq!(at(&mask, 3, row, col), 1);
        }
        // A pass is never fusable with itself.
        assert_eq!(at(&mask, 3, 0, 0), 0);
    }

    #[test]
    fn dependent_passes_are_not_fusable() {
        #[rustfmt::skip]
        let graph = vec![
            0, 1, 0,
            0, 0, 1,
            0, 0, 0,
        ];
        let mask = fusable_pairs(&graph, 3, 3);
        for (row, col) in [(0, 1), (0, 2), (1, 2)] {
            assert_eq!(at(&mask, 3, row, col), 0);
        }
    }

    #[test]
    fn diamond_top_and_bottom_not_fusable() {
        #[rustfmt::skip]
        let graph = vec![
            0, 1, 1, 0,
            0, 0, 0, 1,
            0, 0, 0, 1,
            0, 0, 0, 0,
        ];
        let mask = fusable_pairs(&graph, 4, 4);
        // The two middle branches are independent of each other...
        for (row, col) in [(1, 2), (2, 1)] {
            assert_eq!(at(&mask, 4, row, col), 1);
        }
        // ...but the top reaches the bottom.
        assert_eq!(at(&mask, 4, 0, 3), 0);
    }

    #[test]
    fn empty_graph_returns_empty() {
        assert!(fusable_pairs(&[], 0, 0).is_empty());
    }

    #[test]
    fn fusable_pairs_into_reuses_output_and_clears_invalid_shape() {
        let graph = vec![0u32; 9];
        let mut buf = Vec::with_capacity(16);
        let original_ptr = buf.as_ptr();

        fusable_pairs_into(&graph, 3, 3, &mut buf);
        assert_eq!(buf, fusable_pairs(&graph, 3, 3));
        assert_eq!(buf.as_ptr(), original_ptr);

        // Length 3 does not match n * n == 4, so the output must be cleared.
        fusable_pairs_into(&[0, 1, 0], 2, 3, &mut buf);
        assert!(buf.is_empty());
    }

    #[test]
    fn fusable_pairs_with_scratch_reuses_all_buffers() {
        let graph = vec![0u32; 9];
        let mut closure_buf = Vec::with_capacity(16);
        let mut next_buf = Vec::with_capacity(16);
        let mut out_buf = Vec::with_capacity(16);
        let ptrs = (closure_buf.as_ptr(), next_buf.as_ptr(), out_buf.as_ptr());

        fusable_pairs_with_scratch_into(
            &graph,
            3,
            3,
            &mut closure_buf,
            &mut next_buf,
            &mut out_buf,
        );

        assert_eq!(out_buf, fusable_pairs(&graph, 3, 3));
        assert_eq!(closure_buf.as_ptr(), ptrs.0);
        assert_eq!(next_buf.as_ptr(), ptrs.1);
        assert_eq!(out_buf.as_ptr(), ptrs.2);
        assert_eq!(fusion_score(&graph, 3, 3), 6);
    }
}