use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::graph::program_graph::{
ProgramGraphShape, BINDING_PRIMITIVE_START, NAME_EDGE_KIND_MASK, NAME_EDGE_OFFSETS,
NAME_EDGE_TARGETS,
};
/// Stable identifier of this primitive; also used as the `generator` of the region that
/// [`csr_forward_traverse`] emits.
pub const OP_ID: &str = "vyre-primitives::graph::csr_forward_traverse";
/// Binding slot of the input frontier bitset: the first primitive-specific slot
/// (`BINDING_PRIMITIVE_START`).
pub const BINDING_FRONTIER_IN: u32 = BINDING_PRIMITIVE_START;
/// Binding slot of the output frontier bitset, directly after the input frontier.
pub const BINDING_FRONTIER_OUT: u32 = BINDING_FRONTIER_IN + 1;

/// Number of `u32` words used for a frontier bitset over `node_count` nodes.
///
/// Thin delegate to `crate::bitset::bitset_words`, so this module and the bitset module
/// agree on sizing. Presumably `ceil(node_count / 32)` — confirm against `crate::bitset`.
#[must_use]
pub const fn bitset_words(node_count: u32) -> u32 {
    crate::bitset::bitset_words(node_count)
}
#[must_use]
pub fn csr_forward_traverse(
shape: ProgramGraphShape,
frontier_in: &str,
frontier_out: &str,
allow_mask: u32,
) -> Program {
let t = Expr::InvocationId { axis: 0 };
let words = bitset_words(shape.node_count);
let body = vec![
Node::let_bind("src", t.clone()),
Node::let_bind("word_idx", Expr::shr(Expr::var("src"), Expr::u32(5))),
Node::let_bind(
"bit_mask",
Expr::shl(Expr::u32(1), Expr::bitand(Expr::var("src"), Expr::u32(31))),
),
Node::let_bind("src_word", Expr::load(frontier_in, Expr::var("word_idx"))),
Node::if_then(
Expr::ne(
Expr::bitand(Expr::var("src_word"), Expr::var("bit_mask")),
Expr::u32(0),
),
vec![
Node::let_bind(
"edge_start",
Expr::load(NAME_EDGE_OFFSETS, Expr::var("src")),
),
Node::let_bind(
"edge_end",
Expr::load(NAME_EDGE_OFFSETS, Expr::add(Expr::var("src"), Expr::u32(1))),
),
Node::loop_for(
"e",
Expr::var("edge_start"),
Expr::var("edge_end"),
vec![
Node::let_bind(
"kind_mask",
Expr::load(NAME_EDGE_KIND_MASK, Expr::var("e")),
),
Node::if_then(
Expr::ne(
Expr::bitand(Expr::var("kind_mask"), Expr::u32(allow_mask)),
Expr::u32(0),
),
vec![
Node::let_bind(
"dst",
Expr::load(NAME_EDGE_TARGETS, Expr::var("e")),
),
Node::if_then(
Expr::lt(Expr::var("dst"), Expr::u32(shape.node_count)),
vec![
Node::let_bind(
"dst_word_idx",
Expr::shr(Expr::var("dst"), Expr::u32(5)),
),
Node::let_bind(
"dst_bit",
Expr::shl(
Expr::u32(1),
Expr::bitand(Expr::var("dst"), Expr::u32(31)),
),
),
Node::let_bind(
"_prev",
Expr::atomic_or(
frontier_out,
Expr::var("dst_word_idx"),
Expr::var("dst_bit"),
),
),
],
),
],
),
],
),
],
),
];
let mut buffers = shape.read_only_buffers();
buffers.push(
BufferDecl::storage(
frontier_in,
BINDING_FRONTIER_IN,
BufferAccess::ReadOnly,
DataType::U32,
)
.with_count(words),
);
buffers.push(
BufferDecl::storage(
frontier_out,
BINDING_FRONTIER_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(words),
);
Program::wrapped(
buffers,
[1, 1, 1],
vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(vec![Node::if_then(
Expr::lt(t.clone(), Expr::u32(shape.node_count)),
body,
)]),
}],
)
}
/// Allocating convenience wrapper around [`cpu_ref_into`].
///
/// Returns a freshly allocated output frontier with exactly the contents `cpu_ref_into`
/// writes for the same arguments.
#[must_use]
pub fn cpu_ref(
    node_count: u32,
    edge_offsets: &[u32],
    edge_targets: &[u32],
    edge_kind_mask: &[u32],
    frontier_in: &[u32],
    allow_mask: u32,
) -> Vec<u32> {
    let mut frontier_out: Vec<u32> = Vec::new();
    cpu_ref_into(
        node_count,
        edge_offsets,
        edge_targets,
        edge_kind_mask,
        frontier_in,
        allow_mask,
        &mut frontier_out,
    );
    frontier_out
}
/// CPU reference for [`csr_forward_traverse`], writing the output frontier into `out`.
///
/// `out` is cleared and resized to `bitset_words(node_count)` zeroed words, reusing its
/// existing allocation. Then, for every node `src < node_count` whose bit is set in
/// `frontier_in`, each edge `e` in `edge_offsets[src]..edge_offsets[src + 1]` with
/// `edge_kind_mask[e] & allow_mask != 0` sets bit `edge_targets[e]` in `out`.
///
/// Malformed input never panics:
/// - If `edge_offsets.len() != node_count + 1`, or `edge_targets` / `edge_kind_mask` are
///   shorter than the final offset, `out` is left all-zero.
/// - Words missing from a short `frontier_in` count as zero.
/// - Edge targets `>= node_count` are ignored.
/// - Each node's edge range is truncated at the end of the shorter edge array, so
///   non-monotonic offsets that point past the arrays cannot index out of bounds. An
///   inverted range (`start >= end`) is simply empty.
pub fn cpu_ref_into(
    node_count: u32,
    edge_offsets: &[u32],
    edge_targets: &[u32],
    edge_kind_mask: &[u32],
    frontier_in: &[u32],
    allow_mask: u32,
    out: &mut Vec<u32>,
) {
    let words = bitset_words(node_count) as usize;
    out.clear();
    out.resize(words, 0);

    let expected_offsets = node_count as usize + 1;
    if edge_offsets.len() != expected_offsets {
        return;
    }
    let edge_count = edge_offsets.last().copied().unwrap_or_default() as usize;
    if edge_targets.len() < edge_count || edge_kind_mask.len() < edge_count {
        return;
    }
    // The check above only validates the *final* offset. Interior offsets can still exceed
    // the edge arrays when they are not monotonic (e.g. `[0, 5, 2, 4, 4]` with 4 edges), so
    // every per-node range is clamped to this limit before indexing.
    let edge_limit = edge_targets.len().min(edge_kind_mask.len());

    // `edge_offsets` has exactly `node_count + 1` entries, so `windows(2)` yields one
    // `[start, end]` pair per node.
    for (src, bounds) in (0..node_count).zip(edge_offsets.windows(2)) {
        let bit_mask = 1u32 << (src % 32);
        let in_frontier = matches!(
            frontier_in.get((src / 32) as usize),
            Some(&word) if word & bit_mask != 0
        );
        if !in_frontier {
            continue;
        }
        let edge_start = bounds[0] as usize;
        let edge_end = (bounds[1] as usize).min(edge_limit);
        for e in edge_start..edge_end {
            if edge_kind_mask[e] & allow_mask == 0 {
                continue;
            }
            let dst = edge_targets[e];
            if dst < node_count {
                out[(dst / 32) as usize] |= 1u32 << (dst % 32);
            }
        }
    }
}
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || csr_forward_traverse(ProgramGraphShape::new(4, 4), "fin", "fout", 0xFFFF_FFFF),
        // Input buffers as little-endian bytes. The first five are presumably the shape's
        // read-only graph buffers in binding order (the 2nd-4th match the CSR fixture
        // 0 -> {1, 2}, 1 -> {3}, 2 -> {3}) — confirm against `read_only_buffers`. They
        // are followed by `fin` = {node 0} and a zeroed `fout`.
        Some(|| {
            fn le_bytes(words: &[u32]) -> Vec<u8> {
                words.iter().flat_map(|w| w.to_le_bytes()).collect()
            }
            vec![vec![
                le_bytes(&[0, 0, 0, 0]),
                le_bytes(&[0, 2, 3, 4, 4]),
                le_bytes(&[1, 2, 3, 3]),
                le_bytes(&[1, 1, 1, 1]),
                le_bytes(&[0, 0, 0, 0]),
                le_bytes(&[0b0001]),
                le_bytes(&[0]),
            ]]
        }),
        // Expected `fout`: node 0's successors {1, 2}.
        Some(|| {
            fn le_bytes(words: &[u32]) -> Vec<u8> {
                words.iter().flat_map(|w| w.to_le_bytes()).collect()
            }
            vec![vec![le_bytes(&[0b0110])]]
        }),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared 4-node CSR fixture: 0 -> {1, 2}, 1 -> {3}, 2 -> {3}, 3 -> {}.
    const OFFSETS: [u32; 5] = [0, 2, 3, 4, 4];
    const TARGETS: [u32; 4] = [1, 2, 3, 3];
    const UNIFORM_KINDS: [u32; 4] = [1; 4];
    const ALLOW_ALL: u32 = 0xFFFF_FFFF;

    #[test]
    fn single_step_reaches_immediate_successors() {
        let got = cpu_ref(4, &OFFSETS, &TARGETS, &UNIFORM_KINDS, &[0b0001], ALLOW_ALL);
        assert_eq!(got, vec![0b0110]);
    }

    #[test]
    fn edge_mask_filters_disallowed_edges() {
        // Edge 0 (0 -> 1) has kind 0b10 and is rejected by an allow mask of 0b01.
        let kinds = [0b10, 0b01, 0b01, 0b01];
        let got = cpu_ref(4, &OFFSETS, &TARGETS, &kinds, &[0b0001], 0b01);
        assert_eq!(got, vec![0b0100]);
    }

    #[test]
    fn empty_frontier_produces_empty_output() {
        let got = cpu_ref(4, &OFFSETS, &TARGETS, &UNIFORM_KINDS, &[0], ALLOW_ALL);
        assert_eq!(got, vec![0]);
    }

    #[test]
    fn malformed_csr_returns_empty_frontier_without_panicking() {
        let got = cpu_ref(4, &[0, 1], &[1], &[1], &[0b0001], ALLOW_ALL);
        assert_eq!(got, vec![0]);
    }

    #[test]
    fn cpu_ref_into_reuses_output_buffer() {
        let mut out: Vec<u32> = Vec::with_capacity(8);
        let before = out.as_ptr();
        cpu_ref_into(
            4,
            &OFFSETS,
            &TARGETS,
            &UNIFORM_KINDS,
            &[0b0001],
            ALLOW_ALL,
            &mut out,
        );
        assert_eq!(out.as_ptr(), before);
        assert_eq!(out, vec![0b0110]);
    }
}