mod bfs;
use self::bfs::bfs_reachability::bfs_reachability;
use crate::bytemuck_safe::{safe_bytes_of_slice, safe_cast_slice};
use crate::runtime::cache::{BufferPool, PooledBuffer};
use bytemuck::{Pod, Zeroable};
use vyre::error::{Error, Result};
use vyre::ops::graph::bfs::Bfs;
use vyre::ops::graph::csr::CsrGraph;
/// Number of invocations per workgroup for the graph BFS kernel.
/// NOTE(review): presumably used to size dispatches in `bfs` — confirm.
pub const WORKGROUP_SIZE: u32 = 64;
/// Capacity, in rows, of the GPU findings buffer (one million findings).
/// NOTE(review): presumably the value passed as `max_findings` to the readback helpers — confirm.
pub const MAX_FINDINGS: usize = 1_000 * 1_000;
/// One BFS finding row as laid out in the GPU findings buffer.
///
/// `#[repr(C)]` with four `u32` fields (16 bytes, no padding, as `Pod` requires).
/// The readback path reinterprets the mapped bytes as `&[GpuFinding]` via
/// `safe_cast_slice`, so the field order IS the wire layout and must not change.
/// NOTE(review): assumed to mirror the finding struct in the lowered BFS WGSL — confirm.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuFinding {
    /// Node the reachability search started from (tuple element 0 in `read_finding_rows`).
    start_node: u32,
    /// Node that was reached (tuple element 1 in `read_finding_rows`).
    sink_node: u32,
    /// Depth at which `sink_node` was reached (tuple element 2 in `read_finding_rows`).
    depth: u32,
    /// Not surfaced by `read_finding_rows`.
    /// NOTE(review): presumably the index into the `sources` slice — confirm.
    source_idx: u32,
}
/// Dispatch parameters uploaded for the BFS kernel (bound at binding 6 by `create_bind_group`).
///
/// `#[repr(C)]`, eight `u32` words = 32 bytes; field order is the GPU layout and
/// must not change.
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub(crate) struct GpuParams {
    /// Number of BFS start nodes.
    num_sources: u32,
    /// Number of nodes in the CSR graph.
    num_nodes: u32,
    /// Capacity of the findings buffer, in rows.
    max_findings: u32,
    /// Maximum traversal depth.
    max_depth: u32,
    /// NOTE(review): presumably the number of `u32` words of the visited bitset
    /// (binding 7) reserved per source — confirm against the BFS program.
    words_per_source: u32,
    // Explicit padding words so the struct is a whole 32 bytes with no implicit
    // padding (required for `Pod`). NOTE(review): presumably also to meet GPU
    // buffer alignment rules for the params binding — confirm.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
pub(crate) fn validate_inputs(csr: &CsrGraph, sources: &[u32]) -> Result<()> {
csr.validate()?;
for &source in sources {
let source_index = usize::try_from(source).map_err(|err| Error::Dataflow {
message: format!(
"source node {source} cannot fit usize: {err}. Fix: reject this source on this platform."
),
})?;
if source_index >= csr.node_count() {
return Err(Error::Dataflow {
message: format!(
"source node {source} is outside node_count {}. Fix: remove invalid sources before GPU dispatch.",
csr.node_count()
),
});
}
}
Ok(())
}
pub(crate) fn checked_u32(value: usize, label: &str) -> Result<u32> {
u32::try_from(value).map_err(|source| Error::Dataflow {
message: format!("{label} value {value} exceeds u32::MAX: {source}. Fix: split the graph input before GPU dispatch."),
})
}
pub fn read_finding_count(
device: &wgpu::Device,
readback_count_buffer: &wgpu::Buffer,
max_findings: u32,
submission: wgpu::SubmissionIndex,
) -> Result<u32> {
let slice = readback_count_buffer.slice(0..4);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
match device.poll(wgpu::Maintain::wait_for(submission)) {
wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
}
rx.recv()
.map_err(|error| Error::Dataflow {
message: format!("GPU finding count map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
})?
.map_err(|error| Error::Dataflow {
message: format!("GPU finding count map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
})?;
let mapped = slice.get_mapped_range();
let count = u32::from_ne_bytes(mapped[..4].try_into().map_err(|source| {
Error::Dataflow {
message: format!("GPU finding count readback was not four bytes: {source}. Fix: inspect the count readback buffer size."),
}
})?)
.min(max_findings);
drop(mapped);
readback_count_buffer.unmap();
Ok(count)
}
pub fn read_finding_rows(
device: &wgpu::Device,
readback_buffer: &wgpu::Buffer,
count: u32,
submission: wgpu::SubmissionIndex,
) -> Result<Vec<(u32, u32, u32)>> {
let finding_byte_len = u64::from(count)
* u64::try_from(std::mem::size_of::<GpuFinding>()).map_err(|source| Error::Dataflow {
message: format!(
"GpuFinding size cannot fit u64: {source}. Fix: run on a supported target."
),
})?;
let slice = readback_buffer.slice(0..finding_byte_len);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = tx.send(result);
});
match device.poll(wgpu::Maintain::wait_for(submission)) {
wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
}
rx.recv()
.map_err(|error| Error::Dataflow {
message: format!("GPU findings map channel closed: {error}. Fix: keep the readback receiver alive until map_async completes."),
})?
.map_err(|error| Error::Dataflow {
message: format!("GPU findings map failed: {error:?}. Fix: check for device loss, adapter timeout, or invalid readback buffer usage."),
})?;
let mapped = slice.get_mapped_range();
let findings: &[GpuFinding] = safe_cast_slice(&mapped).map_err(|error| Error::Dataflow {
message: format!(
"safe cast failed: {error}. Fix: ensure the readback buffer size matches GpuFinding layout."
),
})?;
let results = findings
.iter()
.take(usize::try_from(count).map_err(|source| Error::Dataflow {
message: format!("finding count {count} cannot fit usize: {source}. Fix: reject this readback on this platform."),
})?)
.map(|finding| (finding.start_node, finding.sink_node, finding.depth))
.collect();
drop(mapped);
readback_buffer.unmap();
Ok(results)
}
/// Reads the BFS results produced by `submission`.
///
/// First reads the finding counter from `readback_count_buffer`, clamped to
/// `max_findings`. Only when that count is non-zero does it read that many rows
/// from `readback_buffer`.
///
/// # Errors
/// Propagates errors from `read_finding_count` and `read_finding_rows`.
pub(crate) fn read_findings(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    readback_count_buffer: &wgpu::Buffer,
    max_findings: u32,
    submission: wgpu::SubmissionIndex,
) -> Result<Vec<(u32, u32, u32)>> {
    let row_count =
        read_finding_count(device, readback_count_buffer, max_findings, submission.clone())?;
    if row_count > 0 {
        read_finding_rows(device, readback_buffer, row_count, submission)
    } else {
        Ok(Vec::new())
    }
}
pub(crate) fn create_pipeline(
device: &wgpu::Device,
queue_slots: u32,
) -> Result<wgpu::ComputePipeline> {
vyre::ops::registry::gate::verify_certificate(Bfs::SPEC.id()).map_err(|source| {
Error::Dataflow {
message: source.to_string(),
}
})?;
let source = vyre::lower::wgsl::lower(&Bfs::program_with_queue_size(queue_slots)).map_err(
|source| Error::Dataflow {
message: format!("failed to lower graph BFS IR to WGSL: {source}. Fix: validate the canonical BFS Program and lowerer before dispatch."),
},
)?;
crate::runtime::compile_compute_pipeline(device, "graph bfs pipeline", &source, "main")
.map_err(|source| Error::Dataflow {
message: format!(
"failed to compile graph BFS pipeline: {source}. Fix: repair the BFS WGSL or GPU runtime configuration."
),
})
}
/// Uploads `data` into a buffer acquired from the global `BufferPool` and returns it.
///
/// The buffer is created with `usage | COPY_DST`; `COPY_DST` is needed for the
/// initial `write_buffer` upload. Two adjustments are made to the uploaded bytes:
/// * An empty `data` slice is replaced by 16 zero bytes, so the buffer is never
///   zero-sized.
/// * A byte length that is not a multiple of 4 is zero-padded up to the next
///   multiple of 4. `Queue::write_buffer` rejects copy sizes that are not a
///   multiple of `wgpu::COPY_BUFFER_ALIGNMENT`, which a generic `T: Pod`
///   (e.g. `u8`) can produce. Already-aligned inputs are uploaded unchanged.
///
/// # Errors
/// Returns `Error::Dataflow` if the byte length cannot be represented as `u64`;
/// propagates errors from the buffer pool.
pub(crate) fn create_buffer<T: Pod>(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    label: &'static str,
    data: &[T],
    usage: wgpu::BufferUsages,
) -> Result<PooledBuffer> {
    // Mirrors `wgpu::COPY_BUFFER_ALIGNMENT` (4 bytes) as a `usize`.
    const COPY_ALIGNMENT: usize = 4;
    let contents = safe_bytes_of_slice(data);
    let effective: std::borrow::Cow<'_, [u8]> = if contents.is_empty() {
        std::borrow::Cow::Borrowed(&[0u8; 16][..])
    } else if contents.len() % COPY_ALIGNMENT == 0 {
        std::borrow::Cow::Borrowed(contents)
    } else {
        let mut padded = contents.to_vec();
        padded.resize(contents.len().next_multiple_of(COPY_ALIGNMENT), 0);
        std::borrow::Cow::Owned(padded)
    };
    let size = u64::try_from(effective.len()).map_err(|source| Error::Dataflow {
        message: format!(
            "buffer `{label}` has {} bytes that cannot fit u64: {source}. Fix: split the graph input before dispatch.",
            effective.len()
        ),
    })?;
    let buffer =
        BufferPool::global().acquire(device, label, size, usage | wgpu::BufferUsages::COPY_DST)?;
    queue.write_buffer(&buffer, 0, &effective);
    Ok(buffer)
}
/// Borrowed GPU buffers for the BFS bind group (group 0). Each field is bound in
/// full by `create_bind_group` at the binding index noted on it.
pub(crate) struct BfsBindGroupInputs<'buf> {
    /// Binding 0. NOTE(review): presumably the CSR node array — confirm.
    pub(crate) node_buffer: &'buf wgpu::Buffer,
    /// Binding 1. NOTE(review): presumably the CSR edge-offset array — confirm.
    pub(crate) offset_buffer: &'buf wgpu::Buffer,
    /// Binding 2. NOTE(review): presumably the CSR edge-target array — confirm.
    pub(crate) target_buffer: &'buf wgpu::Buffer,
    /// Binding 3: BFS start nodes.
    pub(crate) source_buffer: &'buf wgpu::Buffer,
    /// Binding 4: output rows (`GpuFinding`).
    pub(crate) findings_buffer: &'buf wgpu::Buffer,
    /// Binding 5: output finding counter.
    pub(crate) finding_count_buffer: &'buf wgpu::Buffer,
    /// Binding 6: dispatch parameters (`GpuParams`).
    pub(crate) params_buffer: &'buf wgpu::Buffer,
    /// Binding 7: visited-node scratch state.
    pub(crate) visited_buffer: &'buf wgpu::Buffer,
}
/// Creates the bind group for group 0 of the BFS pipeline.
///
/// Each buffer in `inputs` is bound in full at its position in this order:
/// node = 0, offset = 1, target = 2, source = 3, findings = 4,
/// finding count = 5, params = 6, visited = 7. The layout is queried from the
/// pipeline itself, so it matches whatever the compiled shader declares.
pub(crate) fn create_bind_group(
    device: &wgpu::Device,
    pipeline: &wgpu::ComputePipeline,
    inputs: &BfsBindGroupInputs<'_>,
) -> wgpu::BindGroup {
    let group_layout = pipeline.get_bind_group_layout(0);
    // Array position == binding index.
    let bound_buffers: [&wgpu::Buffer; 8] = [
        inputs.node_buffer,
        inputs.offset_buffer,
        inputs.target_buffer,
        inputs.source_buffer,
        inputs.findings_buffer,
        inputs.finding_count_buffer,
        inputs.params_buffer,
        inputs.visited_buffer,
    ];
    let entries: Vec<wgpu::BindGroupEntry<'_>> = (0u32..)
        .zip(bound_buffers)
        .map(|(binding, buffer)| wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        })
        .collect();
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("graph bfs bind group"),
        layout: &group_layout,
        entries: &entries,
    })
}
#[expect(
    clippy::too_many_arguments,
    reason = "GPU dispatch helpers pass distinct wgpu handles whose grouping would hide binding roles"
)]
/// Records the BFS dispatch plus the readback copies in a fresh command
/// encoder, submits it, and returns the readback buffers with the submission index.
///
/// The function does the following:
/// 1. Acquires two `MAP_READ | COPY_DST` buffers from the global pool:
///    `findings_size` bytes for the findings and 4 bytes for the counter.
/// 2. Records one compute pass dispatching `workgroup_count` x 1 x 1 workgroups.
/// 3. Copies `findings_size` bytes of `findings_buffer` and 4 bytes of
///    `finding_count_buffer` into the readback buffers.
/// 4. Submits the commands.
///
/// `encoder` must be `None`. A caller-owned encoder cannot be supported because
/// the readback helpers need the `SubmissionIndex` of the submitted work. Passing
/// `Some` is rejected *before* any buffer is acquired or any command is
/// recorded, so the caller's encoder is left untouched.
///
/// # Errors
/// Returns `Error::Dataflow` when `encoder` is `Some`; propagates pool
/// acquisition errors.
pub(crate) fn dispatch_and_copy(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    pipeline: &wgpu::ComputePipeline,
    bind_group: &wgpu::BindGroup,
    findings_buffer: &wgpu::Buffer,
    finding_count_buffer: &wgpu::Buffer,
    findings_size: u64,
    workgroup_count: u32,
    encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<(PooledBuffer, PooledBuffer, wgpu::SubmissionIndex)> {
    // Reject up front: recording into the caller's encoder and then failing would
    // leave it holding commands that target pooled buffers released on return.
    if encoder.is_some() {
        return Err(Error::Dataflow {
            message: "dispatch_and_copy was called with an external encoder. Submit the caller-owned command encoder before readback. Fix: call with `None` for immediate submit behavior.".to_string(),
        });
    }
    let pool = BufferPool::global();
    let readback_buffer = pool.acquire(
        device,
        "graph findings readback",
        findings_size,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let readback_count_buffer = pool.acquire(
        device,
        "graph finding count readback",
        4,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("graph bfs encoder"),
    });
    {
        // Scoped so the pass (which borrows the encoder) ends before the copies.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("graph bfs pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.dispatch_workgroups(workgroup_count, 1, 1);
    }
    encoder.copy_buffer_to_buffer(findings_buffer, 0, &readback_buffer, 0, findings_size);
    encoder.copy_buffer_to_buffer(finding_count_buffer, 0, &readback_count_buffer, 0, 4);
    let submission = queue.submit(Some(encoder.finish()));
    Ok((readback_buffer, readback_count_buffer, submission))
}
/// Runs GPU BFS reachability using the device and queue returned by
/// `crate::runtime::cached_device`.
///
/// This is a thin wrapper around `bfs::bfs_reachability`, passing `None` as its
/// final argument. NOTE(review): presumably "no caller-supplied encoder" — confirm in `bfs`.
///
/// # Errors
/// Propagates failures from obtaining the cached device and from the BFS run.
pub fn bfs_reachability_cached(
    csr: &CsrGraph,
    sources: &[u32],
    max_depth: u32,
) -> Result<Vec<(u32, u32, u32)>> {
    let (gpu_device, gpu_queue) = crate::runtime::cached_device()?;
    bfs_reachability(gpu_device, gpu_queue, csr, sources, max_depth, None)
}