#![allow(clippy::similar_names)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::too_many_lines)]
use std::sync::{Arc, OnceLock};
use std::time::Instant;
use super::gpu_backend::GpuAccelerator;
use super::gpu_csr::CsrGraph;
use super::gpu_traversal_buffers::{TraversalBuffers, MAX_CANDIDATES_PER_ITER};
use super::gpu_traversal_pipelines as pipelines;
/// CPU reference implementation of the GPU traversal distance kernels, used to
/// seed the traversal with the entry node's distance.
///
/// Smaller values mean closer:
/// - `Cosine`: `1 - dot / (|query| * |entry_vec|)`, or `1.0` when the norm product is zero.
/// - `Euclidean`: *squared* L2 distance (no square root is taken).
/// - `DotProduct`: the negated inner product.
/// - `Hamming` / `Jaccard`: not supported on this path; always `f32::MAX`.
///
/// Both slices are expected to have equal length. This is only checked in debug
/// builds; in release builds the extra elements of the longer slice are ignored.
#[must_use]
fn gpu_distance_cpu_fallback(
    query: &[f32],
    entry_vec: &[f32],
    metric: crate::distance::DistanceMetric,
) -> f32 {
    use crate::distance::DistanceMetric;
    debug_assert_eq!(query.len(), entry_vec.len());

    // Fresh element-wise pairing for each metric arm.
    let pairs = || query.iter().zip(entry_vec);

    match metric {
        DistanceMetric::Cosine => {
            let (dot, query_sq, entry_sq) = pairs().fold(
                (0.0_f32, 0.0_f32, 0.0_f32),
                |(dot, q_sq, e_sq), (q, e)| (dot + q * e, q_sq + q * q, e_sq + e * e),
            );
            let norm_product = (query_sq * entry_sq).sqrt();
            // A zero-length vector has no direction; treat it as orthogonal.
            if norm_product == 0.0 {
                1.0
            } else {
                1.0 - dot / norm_product
            }
        }
        DistanceMetric::Euclidean => pairs().map(|(q, e)| (q - e).powi(2)).sum::<f32>(),
        DistanceMetric::DotProduct => -pairs().map(|(q, e)| q * e).sum::<f32>(),
        DistanceMetric::Hamming | DistanceMetric::Jaccard => f32::MAX,
    }
}
/// Number of expand/distance/select rounds to run for a given `ef_search`.
///
/// Wider beams get fewer rounds: 20 for `ef_search <= 64`, stepping down to
/// 10 above 512. (Presumably wider frontiers cover more of the graph per
/// round — confirm against benchmarks before retuning.)
#[must_use]
fn adaptive_gpu_iterations(ef_search: usize) -> u32 {
    // (inclusive upper bound on ef_search, iterations) — checked in order.
    const SCHEDULE: [(usize, u32); 4] = [(64, 20), (128, 18), (256, 15), (512, 12)];
    const WIDE_BEAM_ITERATIONS: u32 = 10;

    SCHEDULE
        .iter()
        .find(|&&(upper, _)| ef_search <= upper)
        .map_or(WIDE_BEAM_ITERATIONS, |&(_, iterations)| iterations)
}
/// Decides whether layer-0 traversal should run on the GPU.
///
/// Two conditions must both hold:
/// - there are strictly more than 500,000 vectors, and
/// - the total element count `num_vectors * dimension` neither overflows
///   `usize` nor exceeds `u32::MAX`. Presumably this is because the GPU shaders
///   address the flat vector buffer with 32-bit offsets (confirm in the shaders).
#[must_use]
pub fn should_traverse_gpu(num_vectors: usize, dimension: usize) -> bool {
    // At or below this many vectors the CPU path is preferred.
    const CPU_ONLY_UP_TO: usize = 500_000;

    let large_enough = num_vectors > CPU_ONLY_UP_TO;
    large_enough
        && matches!(
            num_vectors.checked_mul(dimension).map(u32::try_from),
            Some(Ok(_))
        )
}
/// Timing and cache information for a single GPU traversal.
///
/// `Display` renders a one-line summary such as
/// `GpuTraversal(iters=3, cache=HIT, upload=1.00ms, compute=2.50ms, total=3.50ms)`.
#[derive(Debug, Clone, PartialEq)]
pub struct GpuTraversalStats {
    /// Number of traversal iterations executed.
    pub iterations: u32,
    /// Whether a cached resource was reused. NOTE(review): what is cached is not
    /// visible here (presumably GPU-side graph/vector buffers) — confirm.
    pub cache_hit: bool,
    /// Time spent uploading data to the GPU, in milliseconds.
    pub upload_ms: f64,
    /// Time spent on GPU compute, in milliseconds.
    pub compute_ms: f64,
    /// End-to-end time, in milliseconds.
    pub total_ms: f64,
}

impl std::fmt::Display for GpuTraversalStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cache = if self.cache_hit { "HIT" } else { "MISS" };
        write!(
            f,
            "GpuTraversal(iters={}, cache={cache}, upload={:.2}ms, compute={:.2}ms, total={:.2}ms)",
            self.iterations, self.upload_ms, self.compute_ms, self.total_ms,
        )
    }
}
/// Compiled compute pipelines for running HNSW layer-0 traversal on the GPU.
///
/// One distance pipeline is compiled per GPU-supported metric (cosine,
/// squared Euclidean, dot product); `search_layer0` picks the matching one.
pub struct GpuTraversalContext {
    /// Shared GPU handle; provides the device and queue used for all passes.
    gpu: Arc<GpuAccelerator>,
    /// Expand step of each iteration (presumably gathers neighbours of the
    /// current frontier into the candidate buffer — confirm in the shader).
    expand_pipeline: wgpu::ComputePipeline,
    /// Candidate distance kernel for `DistanceMetric::Cosine`.
    distance_cosine_pipeline: wgpu::ComputePipeline,
    /// Candidate distance kernel for `DistanceMetric::Euclidean` (squared L2).
    distance_euclidean_sq_pipeline: wgpu::ComputePipeline,
    /// Candidate distance kernel for `DistanceMetric::DotProduct`.
    distance_dot_pipeline: wgpu::ComputePipeline,
    /// Select step: writes the next frontier into frontier B, which is then
    /// copied back to frontier A.
    select_pipeline: wgpu::ComputePipeline,
}
impl GpuTraversalContext {
#[must_use]
pub fn global() -> Option<Arc<Self>> {
static INSTANCE: OnceLock<Option<Arc<GpuTraversalContext>>> = OnceLock::new();
INSTANCE
.get_or_init(|| GpuTraversalContext::new().map(Arc::new))
.clone()
}
#[must_use]
pub fn new() -> Option<Self> {
let gpu = GpuAccelerator::global()?;
let device = gpu.device();
let expand_pipeline = pipelines::compile_expand_pipeline(device);
let distance_cosine_pipeline = pipelines::compile_traversal_distance_pipeline(
device,
super::gpu_backend::shaders::TRAVERSAL_COSINE_SHADER,
"traversal_cosine",
"Traversal Cosine",
);
let distance_euclidean_sq_pipeline = pipelines::compile_traversal_distance_pipeline(
device,
super::gpu_backend::shaders::TRAVERSAL_EUCLIDEAN_SQ_SHADER,
"traversal_euclidean_sq",
"Traversal Euclidean Sq",
);
let distance_dot_pipeline = pipelines::compile_traversal_distance_pipeline(
device,
super::gpu_backend::shaders::TRAVERSAL_DOT_PRODUCT_SHADER,
"traversal_dot",
"Traversal Dot Product",
);
let select_pipeline = pipelines::compile_select_pipeline(device);
Some(Self {
gpu,
expand_pipeline,
distance_cosine_pipeline,
distance_euclidean_sq_pipeline,
distance_dot_pipeline,
select_pipeline,
})
}
#[allow(clippy::too_many_arguments)]
pub fn search_layer0(
&self,
csr: &CsrGraph,
vectors_flat: &[f32],
query: &[f32],
entry_node: usize,
k: usize,
ef_search: usize,
dimension: usize,
metric: crate::distance::DistanceMetric,
) -> Vec<(usize, f32)> {
if csr.is_empty() || query.is_empty() || dimension == 0 {
return Vec::new();
}
let total_start = Instant::now();
if let Some(results) = self.search_layer0_inner(
csr,
vectors_flat,
query,
entry_node,
k,
ef_search,
dimension,
metric,
) {
let total_ms = total_start.elapsed().as_secs_f64() * 1000.0;
tracing::debug!(
k,
ef_search,
num_results = results.len(),
total_ms = format!("{total_ms:.2}"),
"GPU traversal completed"
);
results
} else {
tracing::warn!("GPU traversal failed, returning empty results for CPU fallback");
Vec::new()
}
}
#[allow(clippy::too_many_arguments)]
fn search_layer0_inner(
&self,
csr: &CsrGraph,
vectors_flat: &[f32],
query: &[f32],
entry_node: usize,
k: usize,
ef_search: usize,
dimension: usize,
metric: crate::distance::DistanceMetric,
) -> Option<Vec<(usize, f32)>> {
let device = self.gpu.device();
let queue = self.gpu.queue();
let ef = ef_search.max(k);
let max_iterations = adaptive_gpu_iterations(ef_search);
let distance_pipeline = match metric {
crate::distance::DistanceMetric::Cosine => &self.distance_cosine_pipeline,
crate::distance::DistanceMetric::Euclidean => &self.distance_euclidean_sq_pipeline,
crate::distance::DistanceMetric::DotProduct => &self.distance_dot_pipeline,
_ => return None,
};
let entry_offset = entry_node.checked_mul(dimension).filter(|end| {
end.checked_add(dimension)
.is_some_and(|e| e <= vectors_flat.len())
})?;
let entry_vec = &vectors_flat[entry_offset..entry_offset + dimension];
let entry_distance = gpu_distance_cpu_fallback(query, entry_vec, metric);
let upload_start = Instant::now();
let buffers = TraversalBuffers::new(
device,
csr,
vectors_flat,
query,
entry_node,
entry_distance,
ef,
dimension,
);
let _upload_ms = upload_start.elapsed().as_secs_f64() * 1000.0;
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("HNSW Traversal Encoder"),
});
for _iteration in 0..max_iterations {
self.encode_expand_pass(&mut encoder, &buffers);
self.encode_distance_pass(&mut encoder, distance_pipeline, &buffers);
self.encode_select_pass(&mut encoder, &buffers, ef);
}
let result_count = k.min(ef);
#[allow(clippy::cast_possible_truncation)]
let result_ids_size = (result_count * std::mem::size_of::<u32>()) as u64;
#[allow(clippy::cast_possible_truncation)]
let result_dists_size = (result_count * std::mem::size_of::<f32>()) as u64;
encoder.copy_buffer_to_buffer(
&buffers.frontier_a_ids,
0,
&buffers.staging_ids,
0,
result_ids_size,
);
encoder.copy_buffer_to_buffer(
&buffers.frontier_a_dists,
0,
&buffers.staging_dists,
0,
result_dists_size,
);
queue.submit(std::iter::once(encoder.finish()));
let result_ids =
super::helpers::readback_buffer::<u32>(device, &buffers.staging_ids, result_count)?;
let result_dists =
super::helpers::readback_buffer::<f32>(device, &buffers.staging_dists, result_count)?;
let mut results: Vec<(usize, f32)> = result_ids
.iter()
.zip(result_dists.iter())
.filter(|(&id, &dist)| id != u32::MAX && dist < f32::MAX)
.map(|(&id, &dist)| (id as usize, dist))
.collect();
results.sort_by(|a, b| a.1.total_cmp(&b.1));
results.truncate(k);
Some(results)
}
fn encode_expand_pass(&self, encoder: &mut wgpu::CommandEncoder, buffers: &TraversalBuffers) {
encoder.clear_buffer(&buffers.counters, 0, None);
encoder.copy_buffer_to_buffer(
&buffers.candidates_sentinel,
0,
&buffers.candidates,
0,
buffers.candidates_byte_size as u64,
);
let bind_group = buffers.create_expand_bind_group(self.gpu.device(), &self.expand_pipeline);
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Expand Pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.expand_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
#[allow(clippy::cast_possible_truncation)]
let workgroups = buffers.ef.div_ceil(256) as u32;
pass.dispatch_workgroups(workgroups.max(1), 1, 1);
}
fn encode_distance_pass(
&self,
encoder: &mut wgpu::CommandEncoder,
pipeline: &wgpu::ComputePipeline,
buffers: &TraversalBuffers,
) {
let bind_group = buffers.create_distance_bind_group(self.gpu.device(), pipeline);
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Distance Pass"),
timestamp_writes: None,
});
pass.set_pipeline(pipeline);
pass.set_bind_group(0, &bind_group, &[]);
let workgroups = MAX_CANDIDATES_PER_ITER.div_ceil(256);
pass.dispatch_workgroups(workgroups, 1, 1);
}
fn encode_select_pass(
&self,
encoder: &mut wgpu::CommandEncoder,
buffers: &TraversalBuffers,
ef: usize,
) {
encoder.clear_buffer(&buffers.select_counters, 0, None);
let frontier_bytes = (ef * std::mem::size_of::<u32>()) as u64;
encoder.copy_buffer_to_buffer(
&buffers.frontier_ids_sentinel,
0,
&buffers.frontier_b_ids,
0,
frontier_bytes,
);
encoder.copy_buffer_to_buffer(
&buffers.frontier_dists_sentinel,
0,
&buffers.frontier_b_dists,
0,
frontier_bytes, );
let bind_group =
buffers.create_select_bind_group(self.gpu.device(), &self.select_pipeline, ef);
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Select Pass"),
timestamp_writes: None,
});
pass.set_pipeline(&self.select_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(1, 1, 1);
let frontier_bytes = (ef * std::mem::size_of::<u32>()) as u64;
encoder.copy_buffer_to_buffer(
&buffers.frontier_b_ids,
0,
&buffers.frontier_a_ids,
0,
frontier_bytes,
);
let dists_bytes = (ef * std::mem::size_of::<f32>()) as u64;
encoder.copy_buffer_to_buffer(
&buffers.frontier_b_dists,
0,
&buffers.frontier_a_dists,
0,
dists_bytes,
);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_traverse_gpu_threshold() {
        // The GPU path needs strictly more than 500_000 vectors.
        let cases = [
            (0, false),
            (100_000, false),
            (500_000, false),
            (500_001, true),
            (1_000_000, true),
        ];
        for (num_vectors, expected) in cases {
            assert_eq!(
                should_traverse_gpu(num_vectors, 128),
                expected,
                "num_vectors = {num_vectors}"
            );
        }
    }

    #[test]
    fn test_should_traverse_gpu_u32_offset_correctness_gate() {
        let cases: [(usize, usize, bool); 4] = [
            // 7.68e9 elements: too many for u32 offsets.
            (10_000_000, 768, false),
            // 3.84e9 elements: still fits.
            (5_000_000, 768, true),
            // Largest multiple of 128 that fits in u32.
            ((u32::MAX as usize) / 128, 128, true),
            // num_vectors * dimension overflows usize.
            (usize::MAX / 2, 4, false),
        ];
        for (num_vectors, dimension, expected) in cases {
            assert_eq!(
                should_traverse_gpu(num_vectors, dimension),
                expected,
                "num_vectors = {num_vectors}, dimension = {dimension}"
            );
        }
    }

    #[test]
    fn test_gpu_traversal_context_new_no_panic() {
        // Construction must not panic, whether or not a GPU is present.
        let _ = GpuTraversalContext::new();
    }

    #[test]
    fn test_search_empty_csr_returns_empty() {
        let Some(ctx) = GpuTraversalContext::new() else {
            return; // No GPU available; nothing to exercise.
        };
        let empty_graph = CsrGraph {
            offsets: vec![0],
            neighbors: vec![],
            num_nodes: 0,
            max_degree: 0,
            total_edges: 0,
        };
        let query = [1.0_f32, 0.0, 0.0];
        let hits = ctx.search_layer0(
            &empty_graph,
            &[],
            &query,
            0,
            10,
            64,
            3,
            crate::distance::DistanceMetric::Cosine,
        );
        assert!(hits.is_empty());
    }

    #[test]
    fn test_search_unsupported_metric_returns_empty() {
        let Some(ctx) = GpuTraversalContext::new() else {
            return; // No GPU available; nothing to exercise.
        };
        let self_loop_graph = CsrGraph {
            offsets: vec![0, 1],
            neighbors: vec![0],
            num_nodes: 1,
            max_degree: 1,
            total_edges: 1,
        };
        let unit_x = [1.0_f32, 0.0, 0.0];
        let hits = ctx.search_layer0(
            &self_loop_graph,
            &unit_x,
            &unit_x,
            0,
            10,
            64,
            3,
            crate::distance::DistanceMetric::Hamming,
        );
        assert!(hits.is_empty());
    }
}