/// Bit widths of the three fields packed into a 32-bit node id, laid out
/// most- to least-significant as [proc (12) | block (10) | fact (10)].
pub const PROC_BITS: u32 = 12;
pub const BLOCK_BITS: u32 = 10;
pub const FACT_BITS: u32 = 10;
// Compile-time guard: the three fields must tile a u32 exactly.
const _SANITY: () = assert!(PROC_BITS + BLOCK_BITS + FACT_BITS == 32);
/// Largest representable id per component (all-ones in its field).
pub const MAX_PROC_ID: u32 = (1 << PROC_BITS) - 1;
pub const MAX_BLOCK_ID: u32 = (1 << BLOCK_BITS) - 1;
pub const MAX_FACT_ID: u32 = (1 << FACT_BITS) - 1;
/// Facts handled per workgroup; kept equal to the fact-id space
/// (`MAX_FACT_ID + 1`, asserted in the test module).
pub const FACTS_PER_WORKGROUP: usize = 1024;
// Shift amounts and masks derived from the field layout above.
const BLOCK_SHIFT: u32 = FACT_BITS;
const PROC_SHIFT: u32 = FACT_BITS + BLOCK_BITS;
const FACT_MASK: u32 = MAX_FACT_ID;
const BLOCK_MASK: u32 = MAX_BLOCK_ID;
const PROC_MASK: u32 = MAX_PROC_ID;

/// Packs `(proc_id, block_id, fact_id)` into a single u32 node id.
///
/// Returns `None` when any component exceeds its field's maximum.
#[must_use]
pub fn encode_node(proc_id: u32, block_id: u32, fact_id: u32) -> Option<u32> {
    if !fits(proc_id, block_id, fact_id) {
        return None;
    }
    let packed = (proc_id << PROC_SHIFT) | (block_id << BLOCK_SHIFT) | fact_id;
    Some(packed)
}

/// Unpacks a node id into its `(proc_id, block_id, fact_id)` components.
#[must_use]
pub fn decode_node(node_id: u32) -> (u32, u32, u32) {
    (
        (node_id >> PROC_SHIFT) & PROC_MASK,
        (node_id >> BLOCK_SHIFT) & BLOCK_MASK,
        node_id & FACT_MASK,
    )
}

/// True when every component fits inside its bit field.
#[must_use]
pub fn fits(proc_id: u32, block_id: u32, fact_id: u32) -> bool {
    let proc_ok = proc_id <= MAX_PROC_ID;
    let block_ok = block_id <= MAX_BLOCK_ID;
    let fact_ok = fact_id <= MAX_FACT_ID;
    proc_ok && block_ok && fact_ok
}
/// CPU reference construction of the dataflow supergraph in CSR form.
///
/// Returns `(row_ptr, col_idx)` over *dense* node indices
/// `p * blocks_per_proc * facts_per_proc + b * facts_per_proc + f`.
///
/// Edges are produced as follows:
/// - each in-range `(p, src_b, dst_b)` in `intra_edges` propagates every
///   fact `f` from `(p, src_b, f)` to `(p, dst_b, f)` unless the source
///   slot appears in `flow_kill`, and additionally adds one edge from the
///   fact-0 node `(p, src_b, 0)` to `(p, dst_b, gf)` for every fact `gf`
///   listed for `src_b` in `flow_gen` (fact 0 is the distinguished source
///   of generated facts);
/// - each in-range `(sp, sb, dp, db)` in `inter_edges` propagates every
///   fact unchanged; kill/gen are NOT consulted for inter edges
///   (NOTE(review): looks intentional — confirm against the GPU kernel
///   this is a reference for).
///
/// Out-of-range entries in any input slice are silently skipped.
/// Degenerate or overflowing dimensions (any size zero, maximum ids that
/// do not fit the packed encoding, or node/edge counts exceeding `u32`)
/// yield the sentinel `(vec![0], Vec::new())`.
#[must_use]
pub fn build_cpu_reference(
num_procs: u32,
blocks_per_proc: u32,
facts_per_proc: u32,
intra_edges: &[(u32, u32, u32)], inter_edges: &[(u32, u32, u32, u32)], flow_gen: &[(u32, u32, u32)], flow_kill: &[(u32, u32, u32)], ) -> (Vec<u32>, Vec<u32>) {
// Empty dimensions: nothing to build.
if num_procs == 0 || blocks_per_proc == 0 || facts_per_proc == 0 {
return (vec![0], Vec::new());
}
// The largest coordinate in each dimension must fit the bit encoding.
if !fits(
num_procs.saturating_sub(1),
blocks_per_proc.saturating_sub(1),
facts_per_proc.saturating_sub(1),
) {
return (vec![0], Vec::new());
}
// Guard the node-count products against usize overflow.
let Some(slots_per_proc) = (blocks_per_proc as usize).checked_mul(facts_per_proc as usize)
else {
return (vec![0], Vec::new());
};
let Some(total_nodes) = (num_procs as usize).checked_mul(slots_per_proc) else {
return (vec![0], Vec::new());
};
// Dense node ids are stored in u32 CSR arrays, so cap the node count.
if total_nodes > u32::MAX as usize {
return (vec![0], Vec::new());
}
let mut edges_flat: Vec<(u32, u32)> = Vec::new();
let block_count = (num_procs as usize) * (blocks_per_proc as usize);
// Dense index of node (p, b, f); fits in u32 by the total_nodes check above.
let idx = |p: u32, b: u32, f: u32| -> u32 {
((p as usize) * slots_per_proc + (b as usize) * facts_per_proc as usize + f as usize) as u32
};
// Dense index of block (p, b), used to bucket gen facts per block.
let block_idx =
|p: u32, b: u32| -> usize { (p as usize) * blocks_per_proc as usize + b as usize };
// Whether a (proc, block, fact) triple lies inside the requested grid.
let in_space =
|p: u32, b: u32, f: u32| p < num_procs && b < blocks_per_proc && f < facts_per_proc;
// Mark killed slots; kills suppress intra-procedural propagation below.
let mut killed = vec![false; total_nodes];
for &(p, b, f) in flow_kill {
if in_space(p, b, f) {
killed[idx(p, b, f) as usize] = true;
}
}
// Group gen facts by block via a counting sort: after the prefix sum,
// gen_facts[gen_offsets[k]..gen_offsets[k + 1]] holds block k's facts.
let mut gen_offsets = vec![0usize; block_count + 1];
for &(p, b, f) in flow_gen {
if in_space(p, b, f) {
gen_offsets[block_idx(p, b) + 1] += 1;
}
}
for i in 1..gen_offsets.len() {
gen_offsets[i] += gen_offsets[i - 1];
}
let mut gen_cursor = gen_offsets[..block_count].to_vec();
let mut gen_facts = vec![0u32; gen_offsets[block_count]];
for &(p, b, f) in flow_gen {
if in_space(p, b, f) {
let key = block_idx(p, b);
let slot = gen_cursor[key];
gen_facts[slot] = f;
gen_cursor[key] += 1;
}
}
// Intra-procedural edges: identity transfer for each surviving fact,
// then one edge per gen fact originating at the source block's fact 0.
for &(p, src_b, dst_b) in intra_edges {
if p >= num_procs || src_b >= blocks_per_proc || dst_b >= blocks_per_proc {
continue;
}
for f in 0..facts_per_proc {
if killed[idx(p, src_b, f) as usize] {
continue;
}
edges_flat.push((idx(p, src_b, f), idx(p, dst_b, f)));
}
let gen_key = block_idx(p, src_b);
for &gf in &gen_facts[gen_offsets[gen_key]..gen_offsets[gen_key + 1]] {
edges_flat.push((idx(p, src_b, 0), idx(p, dst_b, gf)));
}
}
// Inter-procedural edges: identity transfer for every fact; kill/gen
// are not applied here (see the review note in the doc comment).
for &(sp, sb, dp, db) in inter_edges {
if sp >= num_procs || dp >= num_procs || sb >= blocks_per_proc || db >= blocks_per_proc {
continue;
}
for f in 0..facts_per_proc {
edges_flat.push((idx(sp, sb, f), idx(dp, db, f)));
}
}
// col_idx offsets are u32, so the edge count must also fit in u32.
if edges_flat.len() > u32::MAX as usize {
return (vec![0], Vec::new());
}
// Counting sort of edges by source row into CSR. The saturating adds
// cannot actually saturate (edge count checked above); they only keep
// the arithmetic panic-free in release builds.
let mut row_ptr = vec![0u32; total_nodes + 1];
for &(src, _) in &edges_flat {
let row = src as usize;
row_ptr[row + 1] = row_ptr[row + 1].saturating_add(1);
}
for row in 1..row_ptr.len() {
row_ptr[row] = row_ptr[row].saturating_add(row_ptr[row - 1]);
}
// Per-row write cursor, starting at each row's first CSR slot.
let mut cursor = row_ptr[..total_nodes]
.iter()
.map(|&offset| offset as usize)
.collect::<Vec<_>>();
let mut col_idx = vec![0u32; edges_flat.len()];
for (src, dst) in edges_flat {
let row = src as usize;
let slot = cursor[row];
col_idx[slot] = dst;
cursor[row] += 1;
}
(row_ptr, col_idx)
}
/// Converts a dense index (row-major over proc, block, fact) into a packed
/// node id.
///
/// Returns `None` for zero-sized or `u32`-overflowing grids, or when the
/// resulting coordinates do not fit the bit encoding.
#[must_use]
pub fn dense_to_encoded(dense: u32, blocks_per_proc: u32, facts_per_proc: u32) -> Option<u32> {
    let span = blocks_per_proc.checked_mul(facts_per_proc)?;
    if span == 0 {
        return None;
    }
    // Peel off the proc coordinate, then split the remainder into block/fact.
    let p = dense / span;
    let rem = dense % span;
    let (b, f) = (rem / facts_per_proc, rem % facts_per_proc);
    encode_node(p, b, f)
}
/// Converts a packed node id back to the dense index
/// `p * blocks_per_proc * facts_per_proc + b * facts_per_proc + f`.
///
/// Returns `None` when the decoded block/fact coordinate lies outside the
/// given grid, or when any step of the index arithmetic overflows `u32`.
/// Without the range check, a node id carrying `b >= blocks_per_proc` or
/// `f >= facts_per_proc` would silently alias the dense index of a
/// *different* node in this grid (e.g. (0, 5, 0) in a 4x4 grid maps to
/// dense 20, which is node (1, 1, 0)).
#[must_use]
pub fn encoded_to_dense(node_id: u32, blocks_per_proc: u32, facts_per_proc: u32) -> Option<u32> {
    let (p, b, f) = decode_node(node_id);
    // Reject coordinates that do not exist in a blocks_per_proc x
    // facts_per_proc grid (this also covers zero-sized dimensions).
    if b >= blocks_per_proc || f >= facts_per_proc {
        return None;
    }
    let proc_span = blocks_per_proc.checked_mul(facts_per_proc)?;
    let proc_offset = p.checked_mul(proc_span)?;
    let block_offset = b.checked_mul(facts_per_proc)?;
    proc_offset.checked_add(block_offset)?.checked_add(f)
}
#[cfg(test)]
mod tests {
use super::*;
// Packing the maximum id in every field must saturate all 32 bits.
#[test]
fn encode_decode_roundtrips_at_max_values() {
let n = encode_node(MAX_PROC_ID, MAX_BLOCK_ID, MAX_FACT_ID).unwrap();
assert_eq!(n, u32::MAX);
assert_eq!(decode_node(n), (MAX_PROC_ID, MAX_BLOCK_ID, MAX_FACT_ID));
}
// The all-zero triple must map to node id 0.
#[test]
fn encode_decode_roundtrips_at_zero() {
let n = encode_node(0, 0, 0).unwrap();
assert_eq!(n, 0);
assert_eq!(decode_node(n), (0, 0, 0));
}
// Roundtrip at per-field boundaries and mixed mid-range values.
#[test]
fn encode_decode_roundtrips_at_component_boundaries() {
for (p, b, f) in [
(0, 0, 1),
(0, 1, 0),
(1, 0, 0),
(0, 0, MAX_FACT_ID),
(0, MAX_BLOCK_ID, 0),
(MAX_PROC_ID, 0, 0),
(1, 2, 3),
(42, 17, 99),
(MAX_PROC_ID / 2, MAX_BLOCK_ID / 2, MAX_FACT_ID / 2),
] {
let n = encode_node(p, b, f).unwrap();
assert_eq!(
decode_node(n),
(p, b, f),
"roundtrip failed for {p}/{b}/{f}"
);
}
}
// fits() accepts each field's maximum and rejects one-past-the-end;
// encode_node must mirror that rejection.
#[test]
fn fits_catches_over_range_components() {
assert!(fits(MAX_PROC_ID, MAX_BLOCK_ID, MAX_FACT_ID));
assert!(!fits(MAX_PROC_ID + 1, 0, 0));
assert!(!fits(0, MAX_BLOCK_ID + 1, 0));
assert!(!fits(0, 0, MAX_FACT_ID + 1));
assert_eq!(encode_node(MAX_PROC_ID + 1, 0, 0), None);
}
// An edge-free 1x1x1 graph: one node, empty adjacency.
#[test]
fn csr_of_empty_graph_has_only_sentinel_row_ptr() {
let (row_ptr, col_idx) = build_cpu_reference(1, 1, 1, &[], &[], &[], &[]);
assert_eq!(row_ptr, vec![0, 0]);
assert!(col_idx.is_empty());
}
// Dense index helper mirroring build_cpu_reference's node numbering.
fn di(p: u32, b: u32, f: u32, blocks: u32, facts: u32) -> u32 {
p * blocks * facts + b * facts + f
}
// One intra edge fans out to one identity edge per fact (4 facts -> 4 edges).
#[test]
fn csr_single_intra_edge_produces_per_fact_duplicate_edges() {
let (row_ptr, col_idx) = build_cpu_reference(1, 2, 4, &[(0, 0, 1)], &[], &[], &[]);
assert_eq!(row_ptr.len(), 9);
assert_eq!(col_idx.len(), 4);
for f in 0..4 {
let src = di(0, 0, f, 2, 4) as usize;
let edge_start = row_ptr[src] as usize;
assert_eq!(col_idx[edge_start], di(0, 1, f, 2, 4));
}
}
// Killing fact 2 at the source block drops exactly that fact's edge.
#[test]
fn csr_kill_suppresses_edge_for_that_fact() {
let (row_ptr, col_idx) = build_cpu_reference(
1,
2,
4,
&[(0, 0, 1)],
&[],
&[],
&[(0, 0, 2)], );
let n_edges: u32 = row_ptr.windows(2).map(|w| w[1] - w[0]).sum();
assert_eq!(n_edges, 3);
let killed_src = di(0, 0, 2, 2, 4) as usize;
assert_eq!(row_ptr[killed_src + 1] - row_ptr[killed_src], 0);
let _ = col_idx;
}
// An inter edge carries each fact identically from proc 0 to proc 1.
#[test]
fn csr_inter_edges_connect_procs() {
let (row_ptr, col_idx) = build_cpu_reference(
2,
2,
2,
&[],
&[(0, 1, 1, 0)], &[],
&[],
);
assert_eq!(row_ptr.len(), 9);
assert_eq!(col_idx.len(), 2);
let src0 = di(0, 1, 0, 2, 2) as usize;
let src1 = di(0, 1, 1, 2, 2) as usize;
assert_eq!(
&col_idx[row_ptr[src0] as usize..row_ptr[src0 + 1] as usize],
&[di(1, 0, 0, 2, 2)]
);
assert_eq!(
&col_idx[row_ptr[src1] as usize..row_ptr[src1 + 1] as usize],
&[di(1, 0, 1, 2, 2)]
);
}
// dense -> encoded -> dense is the identity for in-range coordinates.
#[test]
fn dense_encoded_roundtrips() {
for &(p, b, f, blocks, facts) in &[
(0_u32, 0_u32, 0_u32, 2_u32, 2_u32),
(1, 1, 1, 2, 2),
(42, 17, 99, 64, 128),
(MAX_PROC_ID, 3, 7, 16, 16),
] {
let d = di(p, b, f, blocks, facts);
let enc = dense_to_encoded(d, blocks, facts).unwrap();
assert_eq!(decode_node(enc), (p, b, f));
let back = encoded_to_dense(enc, blocks, facts).unwrap();
assert_eq!(back, d, "roundtrip mismatch {p}/{b}/{f}");
}
}
// A gen fact adds one extra edge from the source block's fact-0 node
// (4 identity edges + 1 gen edge = 5 total).
#[test]
fn csr_gen_introduces_new_fact_flow_from_zero_fact() {
let (row_ptr, col_idx) = build_cpu_reference(1, 2, 4, &[(0, 0, 1)], &[], &[(0, 0, 2)], &[]);
assert_eq!(col_idx.len(), 5);
let zero_src = di(0, 0, 0, 2, 4) as usize;
let fact2_dst = di(0, 1, 2, 2, 4);
let zero_neighbours = &col_idx[row_ptr[zero_src] as usize..row_ptr[zero_src + 1] as usize];
assert!(zero_neighbours.contains(&fact2_dst));
}
// Dimensions whose max id overflows the bit encoding yield the sentinel.
#[test]
fn csr_rejects_dimensions_overflowing_encoding() {
let (row_ptr, col_idx) = build_cpu_reference(MAX_PROC_ID + 2, 1, 1, &[], &[], &[], &[]);
assert_eq!(row_ptr, vec![0]);
assert!(col_idx.is_empty());
}
// CSR invariant: row_ptr has one entry per node plus the final sentinel.
#[test]
fn row_ptr_length_is_nodes_plus_one() {
let procs = 3;
let blocks = 4;
let facts = 8;
let (row_ptr, _) = build_cpu_reference(procs, blocks, facts, &[], &[], &[], &[]);
assert_eq!(
row_ptr.len(),
(procs as usize * blocks as usize * facts as usize) + 1
);
}
// Keep the dispatch constant in lockstep with the fact-id bit width.
#[test]
fn facts_per_workgroup_matches_max_fact_id_plus_one() {
assert_eq!(FACTS_PER_WORKGROUP as u32, MAX_FACT_ID + 1);
}
}