use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::graph::persistent_bfs_step::persistent_bfs_step_child_prefixed;
use crate::graph::program_graph::{ProgramGraphShape, BINDING_PRIMITIVE_START};
pub const OP_ID: &str = "vyre-primitives::graph::persistent_bfs";
pub const BINDING_FRONTIER_IN: u32 = BINDING_PRIMITIVE_START;
pub const BINDING_FRONTIER_OUT: u32 = BINDING_PRIMITIVE_START + 1;
pub const BINDING_CHANGED: u32 = BINDING_PRIMITIVE_START + 2;
#[must_use]
pub const fn bitset_words(node_count: u32) -> u32 {
crate::bitset::bitset_words(node_count)
}
/// Builds a single-region [`Program`] that grows `frontier_out` from the seed
/// bitset `frontier_in` by repeated forward steps over the program graph.
///
/// The emitted body runs in this order:
/// 1. Seed. Every invocation with `gid.x < bitset_words(shape.node_count)` copies
///    word `gid.x` of `frontier_in` into `frontier_out`. Invocation 0 writes
///    `0` to `changed[0]`. A barrier follows.
/// 2. Unrolled steps. `min(max_iters, 4)` nodes from
///    `persistent_bfs_step_child_prefixed`, prefixed `unroll_0`, `unroll_1`, ...
/// 3. Tail loop, only when `max_iters > 4`. A `loop_for` over the remaining
///    iterations. Each invocation with `gid.x < shape.node_count` runs
///    `csr_forward_or_changed_child_prefixed` on `frontier_out`, reporting
///    into a per-iteration `local_changed`. Any invocation that sees
///    `local_changed == 1` atomically ORs `1` into `changed[0]`.
///
/// `edge_kind_mask` is forwarded unchanged to both step helpers.
///
/// Buffers, in order:
/// - the shape's read-only buffers;
/// - `frontier_in` ([`BINDING_FRONTIER_IN`], read-only);
/// - `frontier_out` ([`BINDING_FRONTIER_OUT`], read-write);
/// - `changed` ([`BINDING_CHANGED`], one read-write word);
/// - a 256-word workgroup buffer `wg_scratch`.
///
/// Both frontier buffers hold `max(bitset_words(node_count), 1)` words. The
/// workgroup size is `[1, 1, 1]`.
#[must_use]
pub fn persistent_bfs(
    shape: ProgramGraphShape,
    frontier_in: &str,
    frontier_out: &str,
    edge_kind_mask: u32,
    max_iters: u32,
) -> Program {
    // Iterations emitted as straight-line step nodes before switching to a loop.
    const UNROLL_LIMIT: u32 = 4;

    let words = bitset_words(shape.node_count);
    // Never declare a zero-length frontier buffer, even for an empty graph.
    let frontier_len = words.max(1);
    let gid = Expr::gid_x();
    let seed_idx = || Expr::var("seed_word_idx");

    // Phase 1: seed `frontier_out` from `frontier_in` and reset the flag.
    let seed_copy = Node::if_then(
        Expr::lt(seed_idx(), Expr::u32(words)),
        vec![Node::store(
            frontier_out,
            seed_idx(),
            Expr::load(frontier_in, seed_idx()),
        )],
    );
    let reset_flag = Node::if_then(
        Expr::eq(gid.clone(), Expr::u32(0)),
        vec![Node::store("changed", Expr::u32(0), Expr::u32(0))],
    );
    let mut body: Vec<Node> = vec![
        Node::let_bind("seed_word_idx", gid.clone()),
        seed_copy,
        reset_flag,
        Node::barrier(),
    ];

    // Phase 2: a bounded number of unrolled steps, each with a unique prefix.
    let unrolled = max_iters.min(UNROLL_LIMIT);
    body.extend((0..unrolled).map(|iter| {
        persistent_bfs_step_child_prefixed(
            OP_ID,
            shape,
            frontier_out,
            "changed",
            "wg_scratch",
            edge_kind_mask,
            &format!("unroll_{iter}"),
        )
    }));

    // Phase 3: any iterations beyond the unroll limit run inside a loop.
    let tail = max_iters.saturating_sub(unrolled);
    if tail > 0 {
        let forward_step =
            crate::graph::csr_forward_or_changed::csr_forward_or_changed_child_prefixed(
                OP_ID,
                shape,
                frontier_out,
                "local_changed",
                edge_kind_mask,
                "remaining_csr",
            );
        let publish_change = Node::if_then(
            Expr::eq(Expr::var("local_changed"), Expr::u32(1)),
            vec![Node::let_bind(
                "_",
                Expr::atomic_or("changed", Expr::u32(0), Expr::u32(1)),
            )],
        );
        // NOTE(review): the tail-loop body emits no explicit barrier between
        // iterations, unlike the seed phase. Confirm this is intended.
        body.push(Node::loop_for(
            "iter",
            Expr::u32(0),
            Expr::u32(tail),
            vec![
                Node::let_bind("local_changed", Expr::u32(0)),
                Node::if_then(
                    Expr::lt(gid.clone(), Expr::u32(shape.node_count)),
                    vec![forward_step],
                ),
                publish_change,
            ],
        ));
    }

    // Phase 4: shared graph buffers, then this primitive's own bindings.
    let mut buffers = shape.read_only_buffers();
    buffers.push(
        BufferDecl::storage(
            frontier_in,
            BINDING_FRONTIER_IN,
            BufferAccess::ReadOnly,
            DataType::U32,
        )
        .with_count(frontier_len),
    );
    buffers.push(
        BufferDecl::storage(
            frontier_out,
            BINDING_FRONTIER_OUT,
            BufferAccess::ReadWrite,
            DataType::U32,
        )
        .with_count(frontier_len),
    );
    buffers.push(
        BufferDecl::storage(
            "changed",
            BINDING_CHANGED,
            BufferAccess::ReadWrite,
            DataType::U32,
        )
        .with_count(1),
    );
    buffers.push(BufferDecl::workgroup("wg_scratch", 256, DataType::U32));

    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(body),
    };
    Program::wrapped(buffers, [1, 1, 1], vec![region])
}
/// CPU reference for [`persistent_bfs`].
///
/// Seeds a frontier bitset from `frontier_in`. For at most `max_iters` rounds
/// it then ORs in the bitset returned by
/// `crate::graph::csr_forward_traverse::cpu_ref` (presumably the one-step
/// forward expansion of the current frontier over edges admitted by
/// `allow_mask`). It stops early as soon as a round adds no new bits.
///
/// Before the first round, `frontier_in` is normalised to exactly
/// `bitset_words(node_count)` words:
/// - missing trailing words are treated as zero;
/// - extra words are dropped.
///
/// This mirrors the GPU program, which only copies the first
/// `bitset_words(node_count)` words of the input frontier. It also means a
/// short seed no longer panics with an out-of-bounds index.
///
/// Returns `(frontier, changed)`. `frontier` has exactly
/// `bitset_words(node_count)` words. `changed` is `1` if any round added at
/// least one bit, otherwise `0`.
#[must_use]
pub fn cpu_ref(
    node_count: u32,
    edge_offsets: &[u32],
    edge_targets: &[u32],
    edge_kind_mask: &[u32],
    frontier_in: &[u32],
    allow_mask: u32,
    max_iters: u32,
) -> (Vec<u32>, u32) {
    let words = bitset_words(node_count) as usize;
    // Normalise the length up front. The merge below walks exactly `words`
    // entries, and the GPU seed copy never reads past word `words - 1`.
    let mut out = frontier_in.to_vec();
    out.resize(words, 0);
    let mut changed = 0u32;
    for _ in 0..max_iters {
        let step = crate::graph::csr_forward_traverse::cpu_ref(
            node_count,
            edge_offsets,
            edge_targets,
            edge_kind_mask,
            &out,
            allow_mask,
        );
        let mut step_changed = false;
        for (word, &reached) in out.iter_mut().zip(step.iter()) {
            let merged = *word | reached;
            step_changed |= merged != *word;
            *word = merged;
        }
        // Fixed point reached: further rounds cannot add bits.
        if !step_changed {
            break;
        }
        changed = 1;
    }
    (out, changed)
}
// Registers this primitive with the harness when the `inventory-registry`
// feature is enabled. The entry holds the op id, a program builder, an input
// fixture and the expected output bytes.
#[cfg(feature = "inventory-registry")]
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
// Builder: `ProgramGraphShape::new(4, 4)` (presumably 4 nodes / 4 edges;
// confirm against `ProgramGraphShape::new`), an all-ones edge-kind mask and
// `max_iters = 4`.
|| persistent_bfs(ProgramGraphShape::new(4, 4), "fin", "fout", 0xFFFF_FFFF, 4),
// Inputs: one case of eight buffers, each encoded as little-endian `u32`
// words. NOTE(review): presumably in binding order:
// - the five read-only shape buffers (entries 2-4 match the CSR offsets,
//   targets and edge kinds of the 0->{1,2}, 1->3, 2->3 graph used in the
//   unit tests);
// - `fin` (seed = node 0 only);
// - `fout` and `changed`, both zeroed.
// Confirm against `ProgramGraphShape::read_only_buffers`.
Some(|| {
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![
to_bytes(&[0, 0, 0, 0]), to_bytes(&[0, 2, 3, 4, 4]), to_bytes(&[1, 2, 3, 3]), to_bytes(&[1, 1, 1, 1]), to_bytes(&[0, 0, 0, 0]), to_bytes(&[0b0001]), to_bytes(&[0]), to_bytes(&[0]), ]]
}),
// Expected outputs: presumably `fout` = 0b1111 (all four nodes reached) and
// `changed` = 1. This matches what `cpu_ref` returns for the same graph and
// seed in the unit tests.
Some(|| {
let to_bytes = |w: &[u32]| w.iter().flat_map(|v| v.to_le_bytes()).collect::<Vec<u8>>();
vec![vec![
to_bytes(&[0b1111]), to_bytes(&[1]), ]]
}),
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All-ones allow mask: admits every edge kind.
    const ANY_KIND: u32 = 0xFFFF_FFFF;
    /// CSR offsets of a 4-node DAG with edges 0->1, 0->2, 1->3, 2->3.
    const DIAMOND_OFFSETS: [u32; 5] = [0, 2, 3, 4, 4];
    /// CSR edge targets matching `DIAMOND_OFFSETS`.
    const DIAMOND_TARGETS: [u32; 4] = [1, 2, 3, 3];
    /// Every diamond edge tagged with kind bit 0.
    const DIAMOND_KINDS: [u32; 4] = [1, 1, 1, 1];

    /// Runs the CPU reference on the diamond graph from a single-word seed.
    fn diamond(seed: u32, allow_mask: u32, kinds: &[u32], max_iters: u32) -> (Vec<u32>, u32) {
        cpu_ref(
            4,
            &DIAMOND_OFFSETS,
            &DIAMOND_TARGETS,
            kinds,
            &[seed],
            allow_mask,
            max_iters,
        )
    }

    #[test]
    fn persistent_bfs_reaches_closure() {
        assert_eq!(diamond(0b0001, ANY_KIND, &DIAMOND_KINDS, 4), (vec![0b1111], 1));
    }

    #[test]
    fn empty_frontier_stays_empty() {
        assert_eq!(diamond(0, ANY_KIND, &DIAMOND_KINDS, 4), (vec![0], 0));
    }

    #[test]
    fn edge_mask_limits_reachability() {
        // Edge 0->1 carries kind 0b10 and is rejected by the 0b01 allow mask,
        // so node 1 is never reached.
        let kinds: [u32; 4] = [0b10, 0b01, 0b01, 0b01];
        assert_eq!(diamond(0b0001, 0b01, &kinds, 4), (vec![0b1101], 1));
    }

    #[test]
    fn max_iters_caps_expansion() {
        // Chain 0 -> 1 -> 2 -> 3: two rounds reach node 2 but not node 3.
        let got = cpu_ref(4, &[0, 1, 2, 3, 3], &[1, 2, 3], &[1, 1, 1], &[0b0001], ANY_KIND, 2);
        assert_eq!(got, (vec![0b0111], 1));
    }

    #[test]
    fn zero_max_iters_is_noop() {
        assert_eq!(diamond(0b0001, ANY_KIND, &DIAMOND_KINDS, 0), (vec![0b0001], 0));
    }

    #[test]
    fn program_builds_and_validates() {
        let program = persistent_bfs(ProgramGraphShape::new(8, 8), "fin", "fout", 0xFF, 4);
        assert_eq!(program.workgroup_size, [1, 1, 1]);
        assert_eq!(program.buffers().len(), 9);
    }
}