#![allow(clippy::doc_markdown)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::similar_names)]
use wgpu::util::DeviceExt;
use super::gpu_csr::CsrGraph;
pub(super) const MAX_CANDIDATES_PER_ITER: u32 = 8192;
pub(super) struct TraversalBuffers {
pub(super) csr_offsets: wgpu::Buffer,
pub(super) csr_neighbors: wgpu::Buffer,
pub(super) vectors: wgpu::Buffer,
pub(super) query: wgpu::Buffer,
pub(super) frontier_a_ids: wgpu::Buffer,
pub(super) frontier_a_dists: wgpu::Buffer,
pub(super) frontier_b_ids: wgpu::Buffer,
pub(super) frontier_b_dists: wgpu::Buffer,
pub(super) candidates: wgpu::Buffer,
pub(super) candidate_dists: wgpu::Buffer,
pub(super) visited: wgpu::Buffer,
pub(super) counters: wgpu::Buffer,
pub(super) select_counters: wgpu::Buffer,
pub(super) expand_params: wgpu::Buffer,
pub(super) distance_params: wgpu::Buffer,
pub(super) staging_ids: wgpu::Buffer,
pub(super) staging_dists: wgpu::Buffer,
pub(super) candidates_sentinel: wgpu::Buffer,
pub(super) candidates_byte_size: usize,
pub(super) frontier_ids_sentinel: wgpu::Buffer,
pub(super) frontier_dists_sentinel: wgpu::Buffer,
pub(super) ef: usize,
}
impl TraversalBuffers {
#[allow(clippy::too_many_arguments)]
pub(super) fn new(
device: &wgpu::Device,
csr: &CsrGraph,
vectors_flat: &[f32],
query: &[f32],
entry_node: usize,
entry_distance: f32,
ef: usize,
dimension: usize,
) -> Self {
debug_assert!(
(csr.num_nodes as usize)
.checked_mul(dimension)
.is_some_and(|p| u32::try_from(p).is_ok()),
"GPU traversal requires num_nodes * dimension <= u32::MAX \
(got {} * {}); use should_traverse_gpu() to gate the caller",
csr.num_nodes,
dimension,
);
let csr_offsets = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("CSR Offsets"),
contents: bytemuck::cast_slice(&csr.offsets),
usage: wgpu::BufferUsages::STORAGE,
});
let csr_neighbors = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("CSR Neighbors"),
contents: bytemuck::cast_slice(&csr.neighbors),
usage: wgpu::BufferUsages::STORAGE,
});
let vectors = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Vectors"),
contents: bytemuck::cast_slice(vectors_flat),
usage: wgpu::BufferUsages::STORAGE,
});
let query_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Query"),
contents: bytemuck::cast_slice(query),
usage: wgpu::BufferUsages::STORAGE,
});
let visited_words = (csr.num_nodes as usize).div_ceil(32);
let mut visited_data = vec![0u32; visited_words.max(1)];
#[allow(clippy::cast_possible_truncation)]
{
let entry_u32 = entry_node as u32;
if (entry_u32 as usize) < csr.num_nodes as usize {
let word_idx = (entry_u32 / 32) as usize;
let bit_idx = entry_u32 % 32;
visited_data[word_idx] |= 1u32 << bit_idx;
}
}
let frontier_buf_size = (ef * std::mem::size_of::<u32>()) as u64;
let frontier_dists_size = (ef * std::mem::size_of::<f32>()) as u64;
let mut initial_frontier = vec![u32::MAX; ef];
#[allow(clippy::cast_possible_truncation)]
{
initial_frontier[0] = entry_node as u32;
}
let frontier_a_ids = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Frontier A IDs"),
contents: bytemuck::cast_slice(&initial_frontier),
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
});
let mut initial_dists = vec![f32::MAX; ef];
initial_dists[0] = entry_distance;
let frontier_a_dists = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Frontier A Dists"),
contents: bytemuck::cast_slice(&initial_dists),
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
});
let frontier_b_ids = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Frontier B IDs"),
size: frontier_buf_size,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let frontier_b_dists = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Frontier B Dists"),
size: frontier_dists_size,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let max_cand = MAX_CANDIDATES_PER_ITER as usize;
let candidates_byte_size = max_cand * std::mem::size_of::<u32>();
let candidates = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Candidates"),
size: candidates_byte_size as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let sentinel_data = vec![u32::MAX; max_cand];
let candidates_sentinel = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Candidates Sentinel"),
contents: bytemuck::cast_slice(&sentinel_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
});
let candidate_dists = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Candidate Dists"),
size: (max_cand * std::mem::size_of::<f32>()) as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let visited = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Visited Bitset"),
contents: bytemuck::cast_slice(&visited_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let counters = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Expand Counters"),
size: 4,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let select_counters = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Select Counters"),
size: 4,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
#[allow(clippy::cast_possible_truncation)]
let expand_params_data = [
ef as u32, MAX_CANDIDATES_PER_ITER, csr.num_nodes, 0u32, ];
let expand_params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Expand Params"),
contents: bytemuck::cast_slice(&expand_params_data),
usage: wgpu::BufferUsages::UNIFORM,
});
#[allow(clippy::cast_possible_truncation)]
let distance_params_data = [dimension as u32, MAX_CANDIDATES_PER_ITER];
let distance_params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Distance Params"),
contents: bytemuck::cast_slice(&distance_params_data),
usage: wgpu::BufferUsages::UNIFORM,
});
let staging_ids = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging IDs"),
size: frontier_buf_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let staging_dists = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Dists"),
size: frontier_dists_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let frontier_ids_sentinel_data = vec![u32::MAX; ef];
let frontier_ids_sentinel = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Frontier IDs Sentinel"),
contents: bytemuck::cast_slice(&frontier_ids_sentinel_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
});
let frontier_dists_sentinel_data = vec![f32::MAX; ef];
let frontier_dists_sentinel =
device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Frontier Dists Sentinel"),
contents: bytemuck::cast_slice(&frontier_dists_sentinel_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
});
Self {
csr_offsets,
csr_neighbors,
vectors,
query: query_buf,
frontier_a_ids,
frontier_a_dists,
frontier_b_ids,
frontier_b_dists,
candidates,
candidate_dists,
visited,
counters,
select_counters,
expand_params,
distance_params,
staging_ids,
staging_dists,
candidates_sentinel,
candidates_byte_size,
frontier_ids_sentinel,
frontier_dists_sentinel,
ef,
}
}
pub(super) fn create_expand_bind_group(
&self,
device: &wgpu::Device,
pipeline: &wgpu::ComputePipeline,
) -> wgpu::BindGroup {
let layout = pipeline.get_bind_group_layout(0);
device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Expand BG"),
layout: &layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.csr_offsets.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.csr_neighbors.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.frontier_a_ids.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: self.candidates.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: self.visited.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: self.counters.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 6,
resource: self.expand_params.as_entire_binding(),
},
],
})
}
pub(super) fn create_distance_bind_group(
&self,
device: &wgpu::Device,
pipeline: &wgpu::ComputePipeline,
) -> wgpu::BindGroup {
let layout = pipeline.get_bind_group_layout(0);
device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Distance BG"),
layout: &layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.query.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.vectors.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.candidates.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: self.candidate_dists.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: self.distance_params.as_entire_binding(),
},
],
})
}
pub(super) fn create_select_bind_group(
&self,
device: &wgpu::Device,
pipeline: &wgpu::ComputePipeline,
ef: usize,
) -> wgpu::BindGroup {
let layout = pipeline.get_bind_group_layout(0);
#[allow(clippy::cast_possible_truncation)]
let select_params_data = [
MAX_CANDIDATES_PER_ITER, ef as u32, 0u32, 0u32, ];
let select_params_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Select Params"),
contents: bytemuck::cast_slice(&select_params_data),
usage: wgpu::BufferUsages::UNIFORM,
});
device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Select BG"),
layout: &layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.candidates.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.candidate_dists.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: self.frontier_b_ids.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: self.frontier_b_dists.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: self.select_counters.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: select_params_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 6,
resource: self.frontier_a_ids.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 7,
resource: self.frontier_a_dists.as_entire_binding(),
},
],
})
}
}