use super::{GpuCsrBuffers, GpuDevice};
use crate::NodeId;
use anyhow::{Context, Result};
/// Uniform parameters handed to the BFS compute shader each level.
///
/// `#[repr(C)]` plus `Pod`/`Zeroable` let the struct be byte-copied into a
/// wgpu uniform buffer. The field order must match the corresponding WGSL
/// struct in `shaders/bfs_simple.wgsl` — TODO confirm against the shader.
/// `_padding` pads the struct to 16 bytes, satisfying uniform-buffer
/// alignment rules.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct BfsParams {
    // Total node count; bounds-checks thread ids in the shader (presumably).
    num_nodes: u32,
    // The BFS frontier level being relaxed this dispatch.
    current_level: u32,
    // Index of the traversal's source node.
    source_node: u32,
    _padding: u32,
}
/// Result of a GPU breadth-first search.
#[derive(Debug, Clone)]
pub struct GpuBfsResult {
    /// Hop distance from the source for every node, indexed by node id;
    /// `u32::MAX` is the sentinel for "unreachable".
    pub distances: Vec<u32>,
    /// Number of nodes with a finite distance (the source counts itself).
    pub visited_count: usize,
}
impl GpuBfsResult {
    /// Returns the BFS hop distance from the source to `node`.
    ///
    /// Yields `None` when the node id is out of range or the node was never
    /// reached (its stored distance is the `u32::MAX` sentinel).
    #[must_use]
    pub fn distance(&self, node: NodeId) -> Option<u32> {
        match self.distances.get(node.0 as usize) {
            Some(&d) if d != u32::MAX => Some(d),
            _ => None,
        }
    }

    /// True when the traversal reached `node`.
    #[must_use]
    pub fn is_reachable(&self, node: NodeId) -> bool {
        self.distance(node).is_some()
    }
}
/// Reads the first 4 bytes of `buffer` back to the CPU as a native-endian
/// `u32` (used here to poll the per-level "updated" flag).
///
/// The value is routed through a small `MAP_READ | COPY_DST` staging buffer,
/// since storage buffers cannot be mapped for CPU access directly; `buffer`
/// must therefore have `COPY_SRC` usage.
///
/// # Errors
/// Fails if the staging buffer cannot be created, the map callback never
/// fires, or the mapping itself reports an error.
async fn read_buffer_u32(device: &GpuDevice, buffer: &wgpu::Buffer) -> Result<u32> {
    let staging_buffer = device.create_buffer(
        "Staging Buffer",
        4,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let mut encoder = device
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
    encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, 4);
    device.queue().submit(Some(encoder.finish()));
    let buffer_slice = staging_buffer.slice(..);
    // `map_async`'s callback only runs during device polling; bridge it into
    // async land with a oneshot channel we can await on.
    let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    // Block until the copy has executed and the map callback has been invoked.
    device.device().poll(wgpu::Maintain::Wait);
    rx.receive()
        .await
        .context("Failed to receive map result")?
        .context("Buffer mapping failed")?;
    let data = buffer_slice.get_mapped_range();
    let value = u32::from_ne_bytes(data[0..4].try_into()?);
    // The mapped view must be dropped before the buffer may be unmapped.
    drop(data);
    staging_buffer.unmap();
    Ok(value)
}
/// Copies the full distance array (`num_nodes` `u32`s) from `distances_buffer`
/// back to the CPU.
///
/// Same staging-buffer readback dance as [`read_buffer_u32`], just sized for
/// the whole array; `distances_buffer` must have `COPY_SRC` usage.
///
/// # Errors
/// Fails if the staging buffer cannot be created or the async map fails.
async fn read_distances(
    device: &GpuDevice,
    distances_buffer: &wgpu::Buffer,
    num_nodes: usize,
) -> Result<Vec<u32>> {
    let size = (num_nodes * std::mem::size_of::<u32>()) as u64;
    let staging_buffer = device.create_buffer(
        "Distances Staging",
        size,
        wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
    )?;
    let mut encoder = device
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
    encoder.copy_buffer_to_buffer(distances_buffer, 0, &staging_buffer, 0, size);
    device.queue().submit(Some(encoder.finish()));
    let buffer_slice = staging_buffer.slice(..);
    // Bridge the polling-driven map callback into an awaitable future.
    let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = tx.send(result);
    });
    // Wait for the copy to finish and the map callback to run.
    device.device().poll(wgpu::Maintain::Wait);
    rx.receive()
        .await
        .context("Failed to receive map result")?
        .context("Buffer mapping failed")?;
    let data = buffer_slice.get_mapped_range();
    let distances: Vec<u32> = bytemuck::cast_slice(&data).to_vec();
    // Drop the mapped view before unmapping, as wgpu requires.
    drop(data);
    staging_buffer.unmap();
    Ok(distances)
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::cast_possible_truncation)]
pub async fn gpu_bfs(
device: &GpuDevice,
buffers: &GpuCsrBuffers,
source: NodeId,
) -> Result<GpuBfsResult> {
const SHADER: &str = include_str!("shaders/bfs_simple.wgsl");
let shader_module = device
.device()
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("BFS Shader"),
source: wgpu::ShaderSource::Wgsl(SHADER.into()),
});
let bind_group_layout =
device
.device()
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("BFS Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device
.device()
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("BFS Pipeline Layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let compute_pipeline =
device
.device()
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("BFS Pipeline"),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: "bfs_level",
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
let num_nodes = buffers.num_nodes();
let params_buffer = device.create_buffer_init(
"BFS Params",
bytemuck::bytes_of(&BfsParams {
num_nodes: num_nodes as u32,
current_level: 0,
source_node: source.0,
_padding: 0,
}),
wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
)?;
let mut initial_distances = vec![u32::MAX; num_nodes];
if (source.0 as usize) < num_nodes {
initial_distances[source.0 as usize] = 0;
}
let distances_buffer = device.create_buffer_init(
"BFS Distances",
bytemuck::cast_slice(&initial_distances),
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
)?;
let updated_buffer = device.create_buffer_init(
"BFS Updated Flag",
bytemuck::bytes_of(&0u32),
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC,
)?;
let bind_group = device
.device()
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("BFS Bind Group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: params_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: buffers.row_offsets.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: buffers.col_indices.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: distances_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: updated_buffer.as_entire_binding(),
},
],
});
let workgroup_size = 256;
let num_workgroups = (num_nodes as u32).div_ceil(workgroup_size).max(1);
for level in 0..num_nodes {
device
.queue()
.write_buffer(&updated_buffer, 0, bytemuck::bytes_of(&0u32));
device.queue().write_buffer(
¶ms_buffer,
0,
bytemuck::bytes_of(&BfsParams {
num_nodes: num_nodes as u32,
current_level: level as u32,
source_node: source.0,
_padding: 0,
}),
);
let mut encoder = device
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("BFS Command Encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("BFS Compute Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&compute_pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
}
device.queue().submit(Some(encoder.finish()));
device.device().poll(wgpu::Maintain::Wait);
let updated_value = read_buffer_u32(device, &updated_buffer).await?;
if updated_value == 0 {
break;
}
}
let distances = read_distances(device, &distances_buffer, num_nodes).await?;
let visited_count = distances.iter().filter(|&&d| d != u32::MAX).count();
Ok(GpuBfsResult {
distances,
visited_count,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CsrGraph;

    /// A chain 0 -> 1 -> 2 must yield hop counts 0, 1, 2 from node 0.
    #[tokio::test]
    async fn test_gpu_bfs_simple_chain() {
        if !GpuDevice::is_gpu_available().await {
            eprintln!("⚠️ Skipping test_gpu_bfs_simple_chain: GPU not available");
            return;
        }
        let device = GpuDevice::new().await.unwrap();
        let mut graph = CsrGraph::new();
        graph.add_edge(NodeId(0), NodeId(1), 1.0).unwrap();
        graph.add_edge(NodeId(1), NodeId(2), 1.0).unwrap();
        let gpu_buffers = GpuCsrBuffers::from_csr_graph(&device, &graph).unwrap();
        let bfs = gpu_bfs(&device, &gpu_buffers, NodeId(0)).await.unwrap();
        for (node, hops) in [(NodeId(0), 0), (NodeId(1), 1), (NodeId(2), 2)] {
            assert_eq!(bfs.distance(node), Some(hops));
        }
    }

    /// Nodes in a separate component must be reported as unreachable.
    #[tokio::test]
    async fn test_gpu_bfs_disconnected() {
        if !GpuDevice::is_gpu_available().await {
            eprintln!("⚠️ Skipping test_gpu_bfs_disconnected: GPU not available");
            return;
        }
        let device = GpuDevice::new().await.unwrap();
        let mut graph = CsrGraph::new();
        graph.add_edge(NodeId(0), NodeId(1), 1.0).unwrap();
        graph.add_edge(NodeId(2), NodeId(2), 1.0).unwrap();
        let gpu_buffers = GpuCsrBuffers::from_csr_graph(&device, &graph).unwrap();
        let bfs = gpu_bfs(&device, &gpu_buffers, NodeId(0)).await.unwrap();
        assert_eq!(bfs.distance(NodeId(0)), Some(0));
        assert!(!bfs.is_reachable(NodeId(2)));
    }

    /// The accessors work without a GPU: `u32::MAX` means unreachable.
    #[test]
    fn test_gpu_bfs_result_api() {
        let bfs = GpuBfsResult {
            distances: vec![0, 1, u32::MAX],
            visited_count: 2,
        };
        assert_eq!(bfs.distance(NodeId(0)), Some(0));
        assert!(bfs.is_reachable(NodeId(0)));
        assert_eq!(bfs.distance(NodeId(1)), Some(1));
        assert!(bfs.is_reachable(NodeId(1)));
        assert_eq!(bfs.distance(NodeId(2)), None);
        assert!(!bfs.is_reachable(NodeId(2)));
    }
}