use crate::error::{Error, Result};
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
use crate::ops::graph::csr::CsrGraph;
pub use crate::ops::graph::MAX_BFS_QUEUE;
use crate::ops::{AlgebraicLaw, OpSpec, BYTES_TO_U32_OUTPUTS, U32_INPUTS};
/// Identifier under which the DFS program is registered and entered.
const OP_ID: &str = "graph.dfs";
/// Algebraic laws declared for this op: a single `Bounded { lo: 0, hi: u32::MAX }`.
const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded { lo: 0, hi: u32::MAX }];
/// Invocations per workgroup (x, y, z); each invocation handles one source index.
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];
/// Zero-sized handle exposing the `graph.dfs` op spec and its program builders.
#[derive(Clone, Copy, Debug, Default)]
pub struct Dfs;
// Program builders for the `graph.dfs` op. Each invocation (`source_idx` is the
// global x id) runs its own depth-bounded DFS from `source_nodes[source_idx]`
// and appends 4-word sink records to `findings`.
//
// `params` words read by this kernel:
//   [0] exclusive upper bound on `source_idx`
//   [1] node count: bounds `start_node`, neighbour ids and the step loop
//   [2] capacity of `findings` in records (compared with `finding_idx`)
//   [3] depth limit: neighbours are pushed only while `depth < params[3]`
//   [4] `visited_set` words reserved per source (bitset stride)
//
// NOTE(review): `stack` and `stack_depth` are declared with
// `BufferDecl::workgroup`, while `WORKGROUP_SIZE` puts 64 invocations in one
// workgroup and every invocation starts writing at slot 0. If workgroup
// buffers are shared by all invocations of a workgroup (as the name suggests),
// DFS runs for different sources in the same workgroup overwrite each other's
// stacks. Confirm against the IR's buffer semantics; a per-invocation slot
// offset or a workgroup size of 1 would be needed to isolate them.
// NOTE(review): `visited_set` is never cleared here; presumably the host
// zeroes it before dispatch — confirm.
impl Dfs {
/// Op specification for `graph.dfs`; its program builder is [`Dfs::program`].
pub const SPEC: OpSpec =
OpSpec::composition(OP_ID, U32_INPUTS, BYTES_TO_U32_OUTPUTS, LAWS, Self::program);
/// Builds the DFS program with [`MAX_BFS_QUEUE`] slots in each stack buffer.
#[must_use]
pub fn program() -> Program {
Self::program_with_stack_size(MAX_BFS_QUEUE)
}
/// Builds the DFS program with `max_stack_slots` entries in both `stack`
/// (node ids) and `stack_depth` (the depth recorded for each entry).
///
/// When a neighbour would be pushed onto a full stack, the invocation returns.
/// That source's traversal is truncated rather than overflowing the buffer.
///
/// NOTE(review): slot 0 is written unconditionally for the start node, so
/// `max_stack_slots` presumably must be at least 1 — confirm with callers.
#[must_use]
pub fn program_with_stack_size(max_stack_slots: u32) -> Program {
Program::new(
vec![
BufferDecl::read("node_labels", 0, DataType::U32),
BufferDecl::read("edge_offsets", 1, DataType::U32),
BufferDecl::read("edge_targets", 2, DataType::U32),
BufferDecl::read("source_nodes", 3, DataType::U32),
BufferDecl::output("findings", 4, DataType::U32),
BufferDecl::read_write("finding_count", 5, DataType::U32),
BufferDecl::read("params", 6, DataType::U32),
BufferDecl::read_write("visited_set", 7, DataType::U32),
BufferDecl::workgroup("stack", max_stack_slots, DataType::U32),
BufferDecl::workgroup("stack_depth", max_stack_slots, DataType::U32),
],
WORKGROUP_SIZE,
vec![
Node::let_bind("source_idx", Expr::gid_x()),
// Skip invocations past params[0], or whose visited-bitset base would wrap.
Node::if_then(
valid_source_invocation(Expr::var("source_idx")),
vec![
Node::let_bind(
"start_node",
Expr::load("source_nodes", Expr::var("source_idx")),
),
// Out-of-range start nodes produce no traversal at all.
Node::if_then(
Expr::lt(Expr::var("start_node"), node_count()),
vec![
// Seed the stack with the start node at depth 0.
Node::let_bind("stack_size", Expr::u32(1)),
Node::store("stack", Expr::u32(0), Expr::var("start_node")),
Node::store("stack_depth", Expr::u32(0), Expr::u32(0)),
mark_visited(Expr::var("source_idx"), Expr::var("start_node")),
// Nodes are marked visited when pushed, so each node is pushed at
// most once; node_count pops are therefore enough to drain the stack.
Node::loop_for(
"step",
Expr::u32(0),
node_count(),
vec![
// Empty stack: traversal finished for this source.
Node::if_then(
Expr::eq(Expr::var("stack_size"), Expr::u32(0)),
vec![Node::return_()],
),
// Pop the top entry.
Node::assign(
"stack_size",
Expr::sub(Expr::var("stack_size"), Expr::u32(1)),
),
Node::let_bind(
"node",
Expr::load("stack", Expr::var("stack_size")),
),
Node::let_bind(
"depth",
Expr::load("stack_depth", Expr::var("stack_size")),
),
report_if_sink(
Expr::var("start_node"),
Expr::var("node"),
Expr::var("depth"),
Expr::var("source_idx"),
),
// Expand neighbours only while below the depth limit params[3].
Node::if_then(
Expr::lt(
Expr::var("depth"),
Expr::load("params", Expr::u32(3)),
),
vec![push_neighbors(
Expr::var("source_idx"),
Expr::var("node"),
Expr::var("depth"),
Expr::var("stack_size"),
max_stack_slots,
)],
),
],
),
],
),
],
),
],
)
.with_entry_op_id(OP_ID)
}
}
/// Kernel expression for the graph's node count, stored in `params[1]`.
fn node_count() -> Expr {
    const NODE_COUNT_PARAM: u32 = 1;
    Expr::load("params", Expr::u32(NODE_COUNT_PARAM))
}
/// Kernel expression for the per-source stride of `visited_set`, in 32-bit
/// words, stored in `params[4]`.
fn words_per_source() -> Expr {
    const WORDS_PER_SOURCE_PARAM: u32 = 4;
    Expr::load("params", Expr::u32(WORDS_PER_SOURCE_PARAM))
}
/// Builds the guard deciding whether invocation `source_idx` runs a DFS.
///
/// All of the following must hold:
/// - `source_idx < params[0]`;
/// - `params[4] != 0` (the source has visited-bitset words);
/// - `source_idx <= u32::MAX / params[4]`, so `source_idx * params[4]` (the base
///   of this source's visited bitset) cannot wrap.
///
/// NOTE(review): only the multiplication is guarded. The `+ (node >> 5)` added
/// by the visited-bit helpers can still wrap for the last admissible
/// `source_idx`; presumably unreachable at realistic buffer sizes.
fn valid_source_invocation(source_idx: Expr) -> Expr {
    let within_batch = Expr::lt(source_idx.clone(), Expr::load("params", Expr::u32(0)));
    let has_visited_words = Expr::ne(words_per_source(), Expr::u32(0));
    let base_fits_u32 = Expr::le(
        source_idx,
        Expr::div(Expr::u32(u32::MAX), words_per_source()),
    );
    Expr::and(within_batch, Expr::and(has_visited_words, base_fits_u32))
}
/// Locates `node`'s bit inside source `source_idx`'s visited bitset.
///
/// Returns `(word, bit)`:
/// - `word = source_idx * params[4] + (node >> 5)` is the `visited_set` index;
/// - `bit = 1 << (node & 31)` is the mask within that word.
fn visited_bit_location(source_idx: Expr, node: Expr) -> (Expr, Expr) {
    let word = Expr::add(
        Expr::mul(source_idx, words_per_source()),
        Expr::shr(node.clone(), Expr::u32(5)),
    );
    let bit = Expr::shl(Expr::u32(1), Expr::bitand(node, Expr::u32(31)));
    (word, bit)
}

/// Emits a read-modify-write that sets `node`'s bit in the visited bitset of
/// `source_idx`.
fn mark_visited(source_idx: Expr, node: Expr) -> Node {
    let (word, bit) = visited_bit_location(source_idx, node);
    let updated = Expr::bitor(Expr::load("visited_set", word.clone()), bit);
    Node::store("visited_set", word, updated)
}

/// Builds a condition that is true when `node`'s bit is still clear in the
/// visited bitset of `source_idx`.
fn not_visited(source_idx: Expr, node: Expr) -> Expr {
    let (word, bit) = visited_bit_location(source_idx, node);
    let masked = Expr::bitand(Expr::load("visited_set", word), bit);
    Expr::eq(masked, Expr::u32(0))
}
fn report_if_sink(start_node: Expr, node: Expr, depth: Expr, source_idx: Expr) -> Node {
let label = Expr::bitand(
Expr::shr(Expr::load("node_labels", node.clone()), Expr::u32(16)),
Expr::u32(0xff),
);
Node::if_then(
Expr::and(
Expr::or(
Expr::eq(label.clone(), Expr::u32(2)),
Expr::eq(label, Expr::u32(3)),
),
Expr::ne(node.clone(), start_node.clone()),
),
vec![
Node::let_bind(
"finding_idx",
Expr::atomic_add("finding_count", Expr::u32(0), Expr::u32(1)),
),
Node::if_then(
Expr::lt(Expr::var("finding_idx"), Expr::load("params", Expr::u32(2))),
vec![
Node::store(
"findings",
Expr::mul(Expr::var("finding_idx"), Expr::u32(4)),
start_node,
),
Node::store(
"findings",
Expr::add(
Expr::mul(Expr::var("finding_idx"), Expr::u32(4)),
Expr::u32(1),
),
node,
),
Node::store(
"findings",
Expr::add(
Expr::mul(Expr::var("finding_idx"), Expr::u32(4)),
Expr::u32(2),
),
depth,
),
Node::store(
"findings",
Expr::add(
Expr::mul(Expr::var("finding_idx"), Expr::u32(4)),
Expr::u32(3),
),
source_idx,
),
],
),
],
)
}
/// Emits a loop over `node`'s CSR edge range that pushes every unvisited
/// neighbour onto the DFS stack with depth `depth + 1`.
///
/// Neighbour ids `>= params[1]` are ignored. A neighbour is marked visited
/// before the capacity check. If `stack_size` has already reached
/// `max_stack_slots`, the whole invocation returns. The push updates the
/// kernel variable named `stack_size`.
fn push_neighbors(
    source_idx: Expr,
    node: Expr,
    depth: Expr,
    stack_size: Expr,
    max_stack_slots: u32,
) -> Node {
    let target = || Expr::var("tgt_node");

    // Innermost: store the neighbour and its depth, then grow the stack.
    let push_target = vec![
        Node::store("stack", stack_size.clone(), target()),
        Node::store(
            "stack_depth",
            stack_size.clone(),
            Expr::add(depth, Expr::u32(1)),
        ),
        Node::assign("stack_size", Expr::add(stack_size.clone(), Expr::u32(1))),
    ];
    let has_room = Expr::lt(stack_size, Expr::u32(max_stack_slots));
    let visit_target = vec![
        mark_visited(source_idx.clone(), target()),
        Node::if_then_else(has_room, push_target, vec![Node::return_()]),
    ];
    let if_unvisited = Node::if_then(not_visited(source_idx, target()), visit_target);

    let first_edge = Expr::load("edge_offsets", node.clone());
    let end_edge = Expr::load("edge_offsets", Expr::add(node, Expr::u32(1)));
    Node::loop_for(
        "edge",
        first_edge,
        end_edge,
        vec![
            Node::let_bind("tgt_node", Expr::load("edge_targets", Expr::var("edge"))),
            Node::if_then(Expr::lt(target(), node_count()), vec![if_unvisited]),
        ],
    )
}
pub fn validate_edge_offsets(csr: &CsrGraph) -> Result<()> {
let expected_offsets = csr.node_count().checked_add(1).ok_or_else(|| Error::Csr {
message:
"CsrInvalid: node_count + 1 overflows usize. Fix: split the graph before DFS dispatch."
.to_string(),
})?;
if csr.offsets.len() != expected_offsets {
return Err(Error::Csr {
message: format!(
"CsrInvalid: edge_offsets length {} does not equal num_nodes + 1 ({expected_offsets}). Fix: rebuild CSR offsets before DFS dispatch.",
csr.offsets.len()
),
});
}
csr.validate()
}
pub fn validate_stack_capacity(csr: &CsrGraph, sources: &[u32], max_depth: u32) -> Result<()> {
validate_edge_offsets(csr)?;
let max_stack = usize::try_from(MAX_BFS_QUEUE).map_err(|error| Error::Csr {
message: format!(
"Overflow: context MAX_BFS_QUEUE value {MAX_BFS_QUEUE} cannot fit usize: {error}. Fix: lower MAX_BFS_QUEUE for this target platform."
),
})?;
for &source in sources {
let Ok(source_index) = usize::try_from(source) else {
continue;
};
if source_index >= csr.node_count() {
continue;
}
validate_source_stack(csr, source, max_depth, max_stack)?;
}
Ok(())
}
fn validate_source_stack(
csr: &CsrGraph,
source: u32,
max_depth: u32,
max_stack_slots: usize,
) -> Result<()> {
let node_count = csr.node_count();
let mut visited = vec![false; node_count];
let source_index = usize::try_from(source).map_err(|error| Error::Csr {
message: format!(
"InvalidEdge: source {source} cannot fit usize: {error}. Fix: use source node ids representable on this platform."
),
})?;
let mut stack = vec![(source, 0u32)];
visited[source_index] = true;
let mut max_stack = 1usize;
while let Some((node, depth)) = stack.pop() {
if depth >= max_depth {
continue;
}
let node_index = usize::try_from(node).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: queued node {node} cannot fit usize: {error}. Fix: rebuild CSR with platform-sized node ids."
),
})?;
let next_node_index = node_index.checked_add(1).ok_or_else(|| Error::Csr {
message:
"CsrInvalid: node_index + 1 overflows usize. Fix: rebuild CSR with fewer nodes."
.to_string(),
})?;
let start = usize::try_from(csr.offsets[node_index]).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: edge offset cannot fit usize: {error}. Fix: rebuild CSR with fewer edges."
),
})?;
let end = usize::try_from(csr.offsets[next_node_index]).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: edge offset cannot fit usize: {error}. Fix: rebuild CSR with fewer edges."
),
})?;
for &target in &csr.targets[start..end] {
let target_index = usize::try_from(target).map_err(|error| Error::Csr {
message: format!(
"CsrInvalid: target {target} cannot fit usize: {error}. Fix: rebuild CSR with platform-sized node ids."
),
})?;
if !visited[target_index] {
visited[target_index] = true;
stack.push((target, depth.saturating_add(1)));
if stack.len() > max_stack {
max_stack = stack.len();
}
if max_stack > max_stack_slots {
return Err(Error::Csr {
message: format!(
"DfsStackTooLarge: stack depth {max_stack} exceeds {max_stack_slots}. Fix: increase max_stack_slots or split the graph."
),
});
}
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir;
    use crate::lower::wgsl;
    use crate::ops::graph::csr::{to_csr, CsrGraph};

    #[test]
    pub(crate) fn dfs_spec_is_category_a_composition() {
        let spec = &Dfs::SPEC;
        assert_eq!(spec.id(), "graph.dfs");

        let program = spec.program().expect("must build program");
        assert!(!program.entry().is_empty());

        let errors = ir::validate(&program);
        assert!(errors.is_empty(), "validation failed: {errors:?}");
        wgsl::lower(&program).expect("must lower to WGSL");
    }

    #[test]
    pub(crate) fn dfs_program_with_custom_stack_size_builds() {
        let program = Dfs::program_with_stack_size(1024);
        // Both stack buffers must take the requested capacity.
        for name in ["stack", "stack_depth"] {
            assert_eq!(program.buffer(name).unwrap().count(), 1024);
        }
    }

    #[test]
    pub(crate) fn validate_edge_offsets_rejects_malformed_csr() {
        // Offsets end at 1, yet two targets are listed.
        let malformed = CsrGraph {
            offsets: vec![0, 1],
            targets: vec![0, 1],
            node_data: vec![0],
        };
        assert!(validate_edge_offsets(&malformed).is_err());
    }

    #[test]
    pub(crate) fn validate_stack_capacity_passes_for_small_graph() {
        let path = to_csr(3, &[(0, 1), (1, 2)]).unwrap();
        assert!(validate_stack_capacity(&path, &[0], 10).is_ok());
    }

    #[test]
    pub(crate) fn validate_stack_capacity_fails_for_deep_graph() {
        // A star: expanding node 0 pushes all 10_000 leaves at once.
        let mut spokes = Vec::new();
        for leaf in 1..10_001 {
            spokes.push((0, leaf));
        }
        let star = to_csr(10_001, &spokes).unwrap();
        assert!(validate_stack_capacity(&star, &[0], 10_001).is_err());
    }

    #[test]
    pub(crate) fn validate_stack_capacity_respects_max_depth() {
        // A 10_001-node chain explored only to depth 100.
        let mut links = Vec::new();
        for from in 0..10_000 {
            links.push((from, from + 1));
        }
        let chain = to_csr(10_001, &links).unwrap();
        assert!(validate_stack_capacity(&chain, &[0], 100).is_ok());
    }

    #[test]
    pub(crate) fn validate_stack_capacity_skips_out_of_range_sources() {
        let pair = to_csr(2, &[(0, 1)]).unwrap();
        assert!(validate_stack_capacity(&pair, &[0, 5], 10).is_ok());
    }
}