pub use crate::ops::graph::MAX_BFS_QUEUE;
use crate::error::{Error, Result};
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::AlgebraicLaw;
use crate::ops::graph::csr::CsrGraph;
use crate::ops::{OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};
use std::collections::VecDeque;
#[derive(Debug, Default, Clone, Copy)]
pub struct Bfs;
impl Bfs {
pub const SPEC: OpSpec =
OpSpec::composition(OP_ID, U32_INPUTS, BYTES_TO_U32_OUTPUTS, LAWS, Self::program);
#[must_use]
pub fn program() -> Program {
Self::program_with_queue_size(MAX_BFS_QUEUE)
}
#[must_use]
pub fn program_with_queue_size(max_queue_slots: u32) -> Program {
let source_idx = Expr::var("source_idx");
let start_node = Expr::var("start_node");
let node = Expr::var("node");
let tgt_node = Expr::var("tgt_node");
let queue_tail = Expr::var("queue_tail");
let queue_head = Expr::var("queue_head");
let depth = Expr::var("depth");
Program::new(
vec![
BufferDecl::read("node_labels", 0, DataType::U32),
BufferDecl::read("edge_offsets", 1, DataType::U32),
BufferDecl::read("edge_targets", 2, DataType::U32),
BufferDecl::read("source_nodes", 3, DataType::U32),
BufferDecl::output("findings", 4, DataType::U32),
BufferDecl::read_write("finding_count", 5, DataType::U32),
BufferDecl::read("params", 6, DataType::U32),
BufferDecl::read_write("visited_set", 7, DataType::U32),
BufferDecl::workgroup("queue", max_queue_slots, DataType::U32),
],
WORKGROUP_SIZE,
vec![
Node::let_bind("source_idx", Expr::gid_x()),
Node::if_then(
valid_source_invocation(source_idx.clone()),
vec![
Node::let_bind(
"start_node",
Expr::load("source_nodes", source_idx.clone()),
),
Node::if_then(
Expr::lt(start_node.clone(), node_count()),
vec![
Node::let_bind("queue_head", Expr::u32(0)),
Node::let_bind("queue_tail", Expr::u32(1)),
Node::let_bind("depth", Expr::u32(0)),
Node::store("queue", Expr::u32(0), start_node.clone()),
mark_visited(source_idx.clone(), start_node.clone()),
Node::Loop {
var: "level".to_string(),
from: Expr::u32(0),
to: Expr::load("params", Expr::u32(3)),
body: vec![
Node::let_bind("level_end", queue_tail.clone()),
Node::Loop {
var: "slot".to_string(),
from: queue_head.clone(),
to: Expr::var("level_end"),
body: vec![
Node::let_bind(
"node",
Expr::load("queue", Expr::var("slot")),
),
report_if_sink(
start_node.clone(),
node.clone(),
depth.clone(),
source_idx.clone(),
),
scan_neighbors(
source_idx.clone(),
node,
tgt_node,
queue_tail.clone(),
max_queue_slots,
),
],
},
Node::assign("queue_head", Expr::var("level_end")),
Node::assign("depth", Expr::add(depth, Expr::u32(1))),
],
},
],
),
],
),
],
)
.with_entry_op_id(OP_ID)
}
}
/// Algebraic laws advertised for `graph.bfs`.
///
/// The only law is `Bounded` over `0 ..= u32::MAX`, i.e. the whole `u32` range.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded { hi: u32::MAX, lo: 0 }];
/// Locates `node`'s bit in the visited bitset owned by `source_idx`.
///
/// Each source owns `words_per_source()` consecutive `visited_set` words starting at
/// `source_idx * words_per_source()`. Node `n` lives in word `n >> 5` of that slice,
/// at bit `n & 31`. Returns `(word_index, bit_mask)`.
fn visited_word_and_mask(source_idx: Expr, node: Expr) -> (Expr, Expr) {
    let slice_start = Expr::mul(source_idx, words_per_source());
    let word_index = Expr::add(slice_start, Expr::shr(node.clone(), Expr::u32(5)));
    let bit_mask = Expr::shl(Expr::u32(1), Expr::bitand(node, Expr::u32(31)));
    (word_index, bit_mask)
}

/// Emits IR that sets `node`'s bit in `source_idx`'s visited bitset.
///
/// This is a plain load / bitwise-or / store of one `visited_set` word, not an atomic
/// operation.
pub fn mark_visited(source_idx: Expr, node: Expr) -> Node {
    let (word_index, bit_mask) = visited_word_and_mask(source_idx, node);
    let updated = Expr::bitor(Expr::load("visited_set", word_index.clone()), bit_mask);
    Node::store("visited_set", word_index, updated)
}

/// Node count of the graph, read from `params[1]`.
pub fn node_count() -> Expr {
    let slot = Expr::u32(1);
    Expr::load("params", slot)
}

/// IR predicate: true when `node`'s bit in `source_idx`'s visited bitset is clear.
pub fn not_visited(source_idx: Expr, node: Expr) -> Expr {
    let (word_index, bit_mask) = visited_word_and_mask(source_idx, node);
    let selected = Expr::bitand(Expr::load("visited_set", word_index), bit_mask);
    Expr::eq(selected, Expr::u32(0))
}
/// Op identifier. It is registered through [`Bfs::SPEC`] and attached to the built
/// program with `with_entry_op_id`.
pub const OP_ID: &str = "graph.bfs";
/// Emits IR that records a finding when `node` is a sink other than `start_node`.
///
/// The label is bits 16..=23 of `node_labels[node]`; values 2 and 3 count as sinks.
/// A record index is claimed with `atomic_add` on `finding_count[0]` (presumably
/// returning the previous value). The record is written only when that index is below
/// `params[2]`, so `finding_count` can exceed the number of records stored. Each record
/// is 4 words at `findings[4 * finding_idx ..]`:
/// `[start_node, node, depth, source_idx]`.
pub fn report_if_sink(start_node: Expr, node: Expr, depth: Expr, source_idx: Expr) -> Node {
    let binary = |op: crate::ir::BinOp, left: Expr, right: Expr| -> Expr {
        Expr::BinOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    };

    let label = Expr::bitand(
        Expr::shr(Expr::load("node_labels", node.clone()), Expr::u32(16)),
        Expr::u32(0xff),
    );
    let is_sink_label = binary(
        crate::ir::BinOp::Or,
        Expr::eq(label.clone(), Expr::u32(2)),
        Expr::eq(label, Expr::u32(3)),
    );
    let is_not_start = binary(crate::ir::BinOp::Ne, node.clone(), start_node.clone());
    let should_report = binary(crate::ir::BinOp::And, is_sink_label, is_not_start);

    // Word index of `field` inside record `finding_idx`. Field 0 is emitted as the
    // bare `finding_idx * 4` so the tree matches a plain record base.
    let slot = |field: u32| -> Expr {
        let record_base = Expr::mul(Expr::var("finding_idx"), Expr::u32(4));
        if field == 0 {
            record_base
        } else {
            Expr::add(record_base, Expr::u32(field))
        }
    };
    let write_record = vec![
        Node::store("findings", slot(0), start_node),
        Node::store("findings", slot(1), node),
        Node::store("findings", slot(2), depth),
        Node::store("findings", slot(3), source_idx),
    ];

    Node::if_then(
        should_report,
        vec![
            Node::let_bind(
                "finding_idx",
                Expr::atomic_add("finding_count", Expr::u32(0), Expr::u32(1)),
            ),
            Node::if_then(
                Expr::lt(Expr::var("finding_idx"), Expr::load("params", Expr::u32(2))),
                write_record,
            ),
        ],
    )
}
/// Emits IR that walks `node`'s CSR adjacency list and enqueues unvisited neighbours.
///
/// For each `edge` in `edge_offsets[node] .. edge_offsets[node + 1]`, `tgt_node` is
/// bound to `edge_targets[edge]`. Targets not below [`node_count`] are ignored. A target
/// whose bit is clear in `source_idx`'s visited set is marked visited. It is then stored
/// at `queue[queue_tail]` with `queue_tail` incremented, provided
/// `queue_tail < max_queue_slots`. Otherwise `Node::Return` is emitted (presumably
/// ending the invocation; confirm IR semantics).
///
/// The range check and the visited check stay as two nested `if`s, so `not_visited`
/// is only evaluated for in-range targets.
pub fn scan_neighbors(
    source_idx: Expr,
    node: Expr,
    tgt_node: Expr,
    queue_tail: Expr,
    max_queue_slots: u32,
) -> Node {
    let has_room = Expr::lt(queue_tail.clone(), Expr::u32(max_queue_slots));
    let push_target = vec![
        Node::store("queue", queue_tail.clone(), tgt_node.clone()),
        Node::assign("queue_tail", Expr::add(queue_tail, Expr::u32(1))),
    ];
    let visit_target = vec![
        mark_visited(source_idx.clone(), tgt_node.clone()),
        Node::if_then_else(has_room, push_target, vec![Node::Return]),
    ];
    let in_range_target = vec![Node::if_then(
        not_visited(source_idx, tgt_node.clone()),
        visit_target,
    )];

    let edge_begin = Expr::load("edge_offsets", node.clone());
    let edge_end = Expr::load("edge_offsets", Expr::add(node, Expr::u32(1)));
    Node::Loop {
        var: "edge".to_string(),
        from: edge_begin,
        to: edge_end,
        body: vec![
            Node::let_bind("tgt_node", Expr::load("edge_targets", Expr::var("edge"))),
            Node::if_then(Expr::lt(tgt_node, node_count()), in_range_target),
        ],
    }
}
/// IR predicate deciding whether invocation `source_idx` may run a search.
///
/// All of the following must hold:
/// * `source_idx < params[0]`;
/// * `words_per_source()` is non-zero;
/// * `source_idx <= u32::MAX / words_per_source()`, so the slice start
///   `source_idx * words_per_source()` used by the visited-set helpers fits in a `u32`.
///
/// NOTE(review): only the slice start is guarded. The later `start + (node >> 5)` can
/// still wrap for the last representable source, and nothing here checks that
/// `node_count` fits in `32 * words_per_source()` bits.
pub fn valid_source_invocation(source_idx: Expr) -> Expr {
    let below_source_count = Expr::lt(source_idx.clone(), Expr::load("params", Expr::u32(0)));
    let has_visited_words = Expr::ne(words_per_source(), Expr::u32(0));
    let slice_start_fits = Expr::le(
        source_idx,
        Expr::div(Expr::u32(u32::MAX), words_per_source()),
    );
    Expr::and(
        below_source_count,
        Expr::and(has_visited_words, slice_start_fits),
    )
}
pub fn validate_edge_offsets(csr: &CsrGraph) -> Result<()> {
let expected_offsets = csr.node_count().checked_add(1).ok_or_else(|| Error::Csr {
message:
"CsrInvalid: node_count + 1 overflows usize. Fix: split the graph before BFS dispatch."
.to_string(),
})?;
if csr.offsets.len() != expected_offsets {
return Err(Error::Csr {
message: format!(
"CsrInvalid: edge_offsets length {} does not equal num_nodes + 1 ({expected_offsets}). Fix: rebuild CSR offsets before BFS dispatch.",
csr.offsets.len()
),
});
}
csr.validate()
}
pub fn validate_frontier_capacity(csr: &CsrGraph, sources: &[u32], max_depth: u32) -> Result<()> {
validate_edge_offsets(csr)?;
let max_queue = usize::try_from(MAX_BFS_QUEUE).map_err(|error| Error::Csr {
message: format!(
"Overflow: context MAX_BFS_QUEUE value {MAX_BFS_QUEUE} cannot fit usize: {error}. Fix: lower MAX_BFS_QUEUE for this target platform."
),
})?;
for &source in sources {
let Ok(source_index) = usize::try_from(source) else {
continue;
};
if source_index >= csr.node_count() {
continue;
}
validate_source_frontier(csr, source, max_depth, max_queue)?;
}
Ok(())
}
pub fn validate_source_frontier(
csr: &CsrGraph,
source: u32,
max_depth: u32,
max_queue: usize,
) -> Result<()> {
let node_count = csr.node_count();
let mut visited = vec![false; node_count];
let source_index = usize::try_from(source).map_err(|error| Error::Csr {
message: format!(
"InvalidEdge: source {source} cannot fit usize: {error}. Fix: use source node ids representable on this platform."
),
})?;
let mut queue = VecDeque::from([(source, 0u32)]);
visited[source_index] = true;
while let Some((node, depth)) = queue.pop_front() {
if depth >= max_depth {
continue;
}
let node_index = usize::try_from(node).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: queued node {node} cannot fit usize: {error}. Fix: rebuild CSR with platform-sized node ids."
),
})?;
let next_node_index = node_index.checked_add(1).ok_or_else(|| Error::Csr {
message:
"CsrInvalid: node_index + 1 overflows usize. Fix: rebuild CSR with fewer nodes."
.to_string(),
})?;
let start = usize::try_from(csr.offsets[node_index]).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: edge offset cannot fit usize: {error}. Fix: rebuild CSR with fewer edges."
),
})?;
let end = usize::try_from(csr.offsets[next_node_index]).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: edge offset cannot fit usize: {error}. Fix: rebuild CSR with fewer edges."
),
})?;
for &target in &csr.targets[start..end] {
let target_index = usize::try_from(target).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: target {target} cannot fit usize: {error}. Fix: rebuild CSR with platform-sized node ids."
),
})?;
if !visited[target_index] {
visited[target_index] = true;
queue.push_back((target, depth.saturating_add(1)));
if queue.len() > max_queue {
return Err(Error::Csr {
message: format!(
"BfsFrontierTooLarge: frontier {} exceeds {MAX_BFS_QUEUE}. Fix: Split the graph into smaller subgraphs or implement multi-pass BFS",
queue.len()
),
});
}
}
}
}
Ok(())
}
/// Number of `visited_set` words reserved for each source's bitset, read from
/// `params[4]`.
pub fn words_per_source() -> Expr {
    let param_slot = Expr::u32(4);
    Expr::load("params", param_slot)
}
/// Workgroup dimensions passed to `Program::new` for the BFS program: 64 invocations
/// along x. Each invocation handles one source (`source_idx = gid_x()`).
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];