use crate::engine::dataflow::{
checked_u32, create_bind_group, create_buffer, create_pipeline, dispatch_and_copy,
read_findings, validate_inputs, BfsBindGroupInputs, GpuFinding, GpuParams, MAX_FINDINGS,
WORKGROUP_SIZE,
};
use vyre::error::{Error, Result};
use vyre::ops::graph::csr::CsrGraph;
use crate::runtime::cache::BufferPool;
#[expect(
clippy::too_many_lines,
reason = "this function owns one GPU BFS dispatch and keeps resource sizing adjacent to validation"
)]
pub fn bfs_reachability(
device: &wgpu::Device,
queue: &wgpu::Queue,
csr: &CsrGraph,
sources: &[u32],
max_depth: u32,
command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<Vec<(u32, u32, u32)>> {
validate_inputs(csr, sources)?;
if sources.is_empty() || csr.node_count() == 0 {
return Ok(Vec::new());
}
let node_count = checked_u32(csr.node_count(), "node_count")?;
let max_findings = checked_u32(
sources
.len()
.checked_mul(csr.node_count())
.ok_or_else(|| Error::Dataflow {
message: "source_count * node_count overflows usize. Fix: reduce source count or graph size before GPU dispatch.".to_string(),
})?,
"max_findings",
)?;
let max_findings_usize = usize::try_from(max_findings).map_err(|source| Error::Dataflow {
message: format!("max_findings {max_findings} cannot fit usize: {source}. Fix: reject this dispatch on this platform."),
})?;
if max_findings_usize > MAX_FINDINGS {
return Err(Error::Dataflow {
message: format!(
"max_findings {max_findings_usize} exceeds limit {MAX_FINDINGS}. Fix: split sources or graph partitions before GPU dispatch."
),
});
}
if max_findings == 0 {
return Ok(Vec::new());
}
let params = GpuParams {
num_sources: checked_u32(sources.len(), "source_count")?,
num_nodes: node_count,
max_findings,
max_depth,
words_per_source: node_count.div_ceil(32),
_pad0: 0,
_pad1: 0,
_pad2: 0,
};
let queue_slots = crate::engine::bfs_queue_slots_for_device::bfs_queue_slots_for_device(device, None);
let pipeline = create_pipeline(device, queue_slots)?;
let node_buffer = create_buffer(
device,
queue,
"graph node data",
&csr.node_data,
wgpu::BufferUsages::STORAGE,
)?;
let offset_buffer = create_buffer(
device,
queue,
"graph offsets",
&csr.offsets,
wgpu::BufferUsages::STORAGE,
)?;
let target_buffer = create_buffer(
device,
queue,
"graph targets",
&csr.targets,
wgpu::BufferUsages::STORAGE,
)?;
let source_buffer = create_buffer(
device,
queue,
"graph sources",
sources,
wgpu::BufferUsages::STORAGE,
)?;
let params_buffer = create_buffer(
device,
queue,
"graph params",
&[params],
wgpu::BufferUsages::STORAGE,
)?;
let finding_size =
u64::try_from(std::mem::size_of::<GpuFinding>()).map_err(|source| Error::Dataflow {
message: format!(
"GpuFinding size cannot fit u64: {source}. Fix: run on a supported target."
),
})?;
let findings_size =
u64::from(max_findings)
.checked_mul(finding_size)
.ok_or_else(|| Error::Dataflow {
message: "findings buffer size overflow. Fix: split sources or graph partitions before GPU dispatch.".to_string(),
})?;
let pool = BufferPool::global();
let findings_buffer = pool.acquire(
device,
"graph findings",
findings_size,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
)?;
let finding_count_buffer = create_buffer(
device,
queue,
"graph finding count",
&[0u32],
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
)?;
let visited_size = u64::from(params.num_sources) * u64::from(params.words_per_source) * 4;
let visited_bytes = usize::try_from(visited_size.max(16)).map_err(|source| Error::Dataflow {
message: format!("visited buffer size cannot fit usize: {source}. Fix: reject this dispatch on this platform."),
})?;
let visited_buffer = pool.acquire(
device,
"graph visited",
visited_size.max(16),
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
)?;
queue.write_buffer(&visited_buffer, 0, &vec![0u8; visited_bytes]);
let bind_group_inputs = BfsBindGroupInputs {
node_buffer: &node_buffer,
offset_buffer: &offset_buffer,
target_buffer: &target_buffer,
source_buffer: &source_buffer,
findings_buffer: &findings_buffer,
finding_count_buffer: &finding_count_buffer,
params_buffer: ¶ms_buffer,
visited_buffer: &visited_buffer,
};
let bind_group = create_bind_group(
device,
&pipeline,
&bind_group_inputs,
);
let (readback_buffer, readback_count_buffer, submission) = dispatch_and_copy(
device,
queue,
&pipeline,
&bind_group,
&findings_buffer,
&finding_count_buffer,
findings_size,
params.num_sources.div_ceil(WORKGROUP_SIZE),
command_encoder,
)?;
let findings = read_findings(
device,
&readback_buffer,
&readback_count_buffer,
max_findings,
submission,
)?;
pool.release(node_buffer);
pool.release(offset_buffer);
pool.release(target_buffer);
pool.release(source_buffer);
pool.release(params_buffer);
pool.release(findings_buffer);
pool.release(finding_count_buffer);
pool.release(visited_buffer);
pool.release(readback_buffer);
pool.release(readback_count_buffer);
Ok(findings)
}