use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::bitset::bitset_words;
/// Frontier density, expressed as a percentage of `node_count`, at or above which
/// [`should_use_dense`] reports that the dense step should be used.
pub const DENSE_THRESHOLD_PCT: u32 = 25;
/// Operation identifier: handed to `invalid_output_program` when [`adaptive_dense_step`]
/// rejects its input, and used as the `generator` of the region it emits.
pub const OP_ID: &str = "vyre-primitives::graph::adaptive_traverse_dense";
/// NOTE(review): presumably the conventional buffer name for the input frontier bitset;
/// it is not referenced in the visible part of this file — confirm against callers.
pub const NAME_FRONTIER_IN: &str = "adap_frontier_in";
/// NOTE(review): presumably the conventional buffer name for the output frontier bitset;
/// it is not referenced in the visible part of this file — confirm against callers.
pub const NAME_FRONTIER_OUT: &str = "adap_frontier_out";
/// NOTE(review): presumably the conventional buffer name for the dense adjacency rows;
/// it is not referenced in the visible part of this file — confirm against callers.
pub const NAME_ADJ_ROWS_DENSE: &str = "adap_adj_rows_dense";
/// Decides whether the frontier is dense enough to prefer the dense traversal step.
///
/// Counts the set bits across every word of `frontier_in` and returns `true` when that
/// count is at least [`DENSE_THRESHOLD_PCT`] percent of `node_count`. The check is
/// evaluated as `popcount * 100 >= DENSE_THRESHOLD_PCT * node_count`, so no division
/// (and therefore no rounding) is involved.
///
/// Every word of the slice is counted, including any bits at positions at or beyond
/// `node_count`.
///
/// Returns `false` when `node_count` is zero.
#[must_use]
pub fn should_use_dense(frontier_in: &[u32], node_count: u32) -> bool {
    if node_count == 0 {
        return false;
    }
    // Accumulate in u64: 2^27 fully-set words already hold 2^32 bits, which would
    // overflow a u32 sum (panic in debug builds, silent wrap-around in release).
    let popcount: u64 = frontier_in
        .iter()
        .map(|word| u64::from(word.count_ones()))
        .sum();
    popcount * 100 >= u64::from(DENSE_THRESHOLD_PCT) * u64::from(node_count)
}
/// Builds the IR [`Program`] for one dense (adjacency-bitset) frontier-expansion step.
///
/// With `words = bitset_words(node_count)`, every invocation `d` (invocation id on
/// axis 0) that satisfies `d < node_count` ORs together
/// `adj_rows_dense[d * words + w] & frontier_in[w]` for `w` in `0..words`. When the
/// accumulated value is non-zero, it atomically ORs bit `d & 31` into word `d >> 5` of
/// `frontier_out`.
///
/// # Parameters
/// - `frontier_in`: name of the read-only `u32` buffer holding the input frontier;
///   declared with `words` elements.
/// - `frontier_out`: name of the read-write `u32` buffer receiving the next frontier;
///   declared with `words` elements. The emitted body only ever ORs bits into it.
/// - `adj_rows_dense`: name of the read-only `u32` buffer holding one `words`-long
///   bitset row per node; declared with `node_count * words` elements.
/// - `node_count`: number of nodes in the graph; must be non-zero.
///
/// The three buffers are declared with indices `0`, `1` and `2` respectively
/// (presumably their binding slots — confirm against `BufferDecl::storage`).
///
/// # Returns
/// A program with workgroup size `[1, 1, 1]` whose body is a single region generated by
/// [`OP_ID`]. If `node_count` is zero, or if `node_count * words` does not fit in a
/// `u32`, the result of `crate::invalid_output_program` carrying a `Fix: ...` message
/// is returned instead.
#[must_use]
pub fn adaptive_dense_step(
    frontier_in: &str,
    frontier_out: &str,
    adj_rows_dense: &str,
    node_count: u32,
) -> Program {
    if node_count == 0 {
        return crate::invalid_output_program(
            OP_ID,
            frontier_out,
            DataType::U32,
            "Fix: adaptive_dense_step requires node_count > 0, got 0.".to_string(),
        );
    }
    let words = bitset_words(node_count);
    // The adjacency buffer's element count must fit in a u32. `checked_mul` on the u32
    // operands fails exactly when the product exceeds u32::MAX, so a single check covers
    // the whole size validation without a narrowing cast.
    let Some(adj_count_u32) = node_count.checked_mul(words) else {
        // Widened only to report the real size; u32 x u32 can never overflow a u64.
        let adj_count = u64::from(node_count) * u64::from(words);
        return crate::invalid_output_program(
            OP_ID,
            frontier_out,
            DataType::U32,
            format!("Fix: adaptive_dense_step buffer size {adj_count} exceeds u32::MAX ({node_count} nodes x {words} words). Partition the graph or use csr_forward_traverse."),
        );
    };

    // `d` is the destination node handled by the current invocation.
    let d = Expr::InvocationId { axis: 0 };

    // hit = hit | (adj_rows_dense[row_start + w] & frontier_in[w])
    let accumulate_hit = Node::assign(
        "hit",
        Expr::bitor(
            Expr::var("hit"),
            Expr::bitand(
                Expr::load(
                    adj_rows_dense,
                    Expr::add(Expr::var("row_start"), Expr::var("w")),
                ),
                Expr::load(frontier_in, Expr::var("w")),
            ),
        ),
    );

    // Sets bit (d & 31) of word (d >> 5) in the output frontier. The atomic's result
    // is bound to "_" because it is not used afterwards.
    let mark_reached = vec![
        Node::let_bind("word_idx", Expr::shr(d.clone(), Expr::u32(5))),
        Node::let_bind(
            "bit_mask",
            Expr::shl(Expr::u32(1), Expr::bitand(d.clone(), Expr::u32(31))),
        ),
        Node::let_bind(
            "_",
            Expr::atomic_or(frontier_out, Expr::var("word_idx"), Expr::var("bit_mask")),
        ),
    ];

    let body: Vec<Node> = vec![
        // Row `d` starts at word `d * words` of the adjacency buffer.
        Node::let_bind("row_start", Expr::mul(d.clone(), Expr::u32(words))),
        Node::let_bind("hit", Expr::u32(0)),
        Node::loop_for("w", Expr::u32(0), Expr::u32(words), vec![accumulate_hit]),
        Node::if_then(Expr::ne(Expr::var("hit"), Expr::u32(0)), mark_reached),
    ];

    let buffers = vec![
        BufferDecl::storage(frontier_in, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(words),
        BufferDecl::storage(frontier_out, 1, BufferAccess::ReadWrite, DataType::U32)
            .with_count(words),
        BufferDecl::storage(adj_rows_dense, 2, BufferAccess::ReadOnly, DataType::U32)
            .with_count(adj_count_u32),
    ];

    // Invocations with an id at or beyond `node_count` skip the body entirely.
    let guarded = Node::if_then(Expr::lt(d, Expr::u32(node_count)), body);

    Program::wrapped(
        buffers,
        [1, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(vec![guarded]),
        }],
    )
}
/// CPU reference implementation of one dense frontier-expansion step.
///
/// Row `d` of `adj_rows_dense` is the run of `bitset_words(node_count)` words starting
/// at `d * words`. Node `d` is set in the result when that row shares at least one set
/// bit with `frontier_in`. Words missing from either slice are read as zero, so short
/// buffers never panic.
///
/// Returns a bitset of `bitset_words(node_count)` words (bit `d % 32` of word `d / 32`
/// represents node `d`).
#[must_use]
pub fn cpu_dense_step(frontier_in: &[u32], adj_rows_dense: &[u32], node_count: u32) -> Vec<u32> {
    let words_per_row = bitset_words(node_count) as usize;
    // Out-of-range reads yield zero instead of panicking.
    let word_at = |buf: &[u32], idx: usize| buf.get(idx).map_or(0_u32, |&word| word);

    let mut next = vec![0_u32; words_per_row];
    for node in 0..node_count as usize {
        let base = node * words_per_row;
        // OR of (row word AND frontier word) across the whole row.
        let overlap = (0..words_per_row).fold(0_u32, |acc, w| {
            acc | (word_at(adj_rows_dense, base + w) & word_at(frontier_in, w))
        });
        if overlap != 0 {
            next[node / 32] |= 1 << (node % 32);
        }
    }
    next
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs the node ids in `bits` into a bitset sized for `node_count` nodes.
    fn pack_nodes(bits: &[u32], node_count: u32) -> Vec<u32> {
        let mut packed = vec![0_u32; bitset_words(node_count) as usize];
        for &node in bits {
            packed[(node / 32) as usize] |= 1 << (node % 32);
        }
        packed
    }

    /// Builds the dense adjacency rows for `edges` given as `(src, dst)` pairs:
    /// row `dst` gets bit `src` set.
    fn build_dense_adj(edges: &[(u32, u32)], node_count: u32) -> Vec<u32> {
        let row_len = bitset_words(node_count) as usize;
        let mut matrix = vec![0_u32; (node_count as usize) * row_len];
        for &(src, dst) in edges {
            let slot = (dst as usize) * row_len + (src / 32) as usize;
            matrix[slot] |= 1 << (src % 32);
        }
        matrix
    }

    // ---- should_use_dense -------------------------------------------------

    #[test]
    fn should_use_dense_empty_frontier_is_false() {
        let dense = should_use_dense(&[0_u32], 32);
        assert!(!dense);
    }

    #[test]
    fn should_use_dense_zero_nodes_returns_false() {
        let dense = should_use_dense(&[], 0);
        assert!(!dense);
    }

    #[test]
    fn should_use_dense_full_frontier_is_true() {
        let all_set = [0xFFFF_FFFF_u32; 4];
        assert!(should_use_dense(&all_set, 128));
    }

    #[test]
    fn should_use_dense_quarter_frontier_at_threshold() {
        // 8 of 32 bits set: exactly 25 %.
        let eight_bits = [0xFF_u32];
        assert!(should_use_dense(&eight_bits, 32));
    }

    #[test]
    fn should_use_dense_just_under_threshold_is_false() {
        // 7 of 32 bits set: just below 25 %.
        let seven_bits = [0x7F_u32];
        assert!(!should_use_dense(&seven_bits, 32));
    }

    // ---- cpu_dense_step ---------------------------------------------------

    #[test]
    fn cpu_dense_step_empty_frontier_produces_empty() {
        let adjacency = build_dense_adj(&[(0, 1), (1, 2)], 16);
        let no_nodes = pack_nodes(&[], 16);
        let next = cpu_dense_step(&no_nodes, &adjacency, 16);
        assert_eq!(next, vec![0; bitset_words(16) as usize]);
    }

    #[test]
    fn cpu_dense_step_single_edge() {
        let frontier = pack_nodes(&[0], 16);
        let adjacency = build_dense_adj(&[(0, 1)], 16);
        let next = cpu_dense_step(&frontier, &adjacency, 16);
        assert_eq!(next, pack_nodes(&[1], 16));
    }

    #[test]
    fn cpu_dense_step_fanout() {
        let frontier = pack_nodes(&[0], 16);
        let adjacency = build_dense_adj(&[(0, 1), (0, 2), (0, 5)], 16);
        let next = cpu_dense_step(&frontier, &adjacency, 16);
        assert_eq!(next, pack_nodes(&[1, 2, 5], 16));
    }

    #[test]
    fn cpu_dense_step_fanin() {
        let frontier = pack_nodes(&[1, 2], 16);
        let adjacency = build_dense_adj(&[(1, 3), (2, 3), (4, 3)], 16);
        let next = cpu_dense_step(&frontier, &adjacency, 16);
        assert_eq!(next, pack_nodes(&[3], 16));
    }

    #[test]
    fn cpu_dense_step_cross_word_boundary() {
        let frontier = pack_nodes(&[5], 70);
        let adjacency = build_dense_adj(&[(5, 65)], 70);
        let next = cpu_dense_step(&frontier, &adjacency, 70);
        assert_eq!(next, pack_nodes(&[65], 70));
    }

    #[test]
    fn cpu_dense_step_short_buffers_treat_missing_words_as_zero() {
        let next = cpu_dense_step(&[1], &[], 16);
        assert!(next.iter().all(|&word| word == 0));
    }

    #[test]
    fn cpu_dense_step_is_one_hop_only() {
        let frontier = pack_nodes(&[0], 16);
        let adjacency = build_dense_adj(&[(0, 1), (1, 2), (2, 3)], 16);
        let next = cpu_dense_step(&frontier, &adjacency, 16);
        assert_eq!(next, pack_nodes(&[1], 16));
    }

    // ---- adaptive_dense_step ----------------------------------------------

    #[test]
    fn emitted_program_has_expected_shape() {
        let program = adaptive_dense_step("fin", "fout", "adj", 64);
        assert_eq!(program.workgroup_size, [1, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["fin", "fout", "adj"]);

        let words = bitset_words(64);
        let count_of = |name: &str| {
            program
                .buffers
                .iter()
                .find(|b| b.name() == name)
                .unwrap()
                .count
        };
        assert_eq!(count_of("fin"), words);
        assert_eq!(count_of("fout"), words);
        assert_eq!(count_of("adj"), 64 * words);
    }

    // ---- selector on realistic densities ----------------------------------

    #[test]
    fn selector_roundtrip_common_density_profiles() {
        // One node out of 512 is far below the threshold.
        let sparse = pack_nodes(&[5], 512);
        assert!(!should_use_dense(&sparse, 512));

        // Half of the nodes (0..256) is well above it.
        let mut half = vec![0_u32; bitset_words(512) as usize];
        for node in 0..256_u32 {
            half[node as usize / 32] |= 1 << (node % 32);
        }
        assert!(should_use_dense(&half, 512));
    }
}