use std::sync::Arc;
use wgpu::util::DeviceExt;
use super::gpu_backend::GpuAccelerator;
/// GPU handle used by the product-quantization k-means assignment routines.
///
/// Wraps the `Arc<GpuAccelerator>` returned by [`GpuAccelerator::global`].
pub struct PqGpuContext {
    gpu: Arc<GpuAccelerator>,
}
impl PqGpuContext {
    /// Creates a context backed by the global GPU accelerator.
    ///
    /// Returns `None` whenever [`GpuAccelerator::global`] returns `None`; callers
    /// are expected to fall back to a CPU implementation in that case.
    #[must_use]
    pub fn new() -> Option<Self> {
        GpuAccelerator::global().map(|gpu| Self { gpu })
    }
}
/// Heuristic deciding whether a k-means assignment is large enough for the GPU.
///
/// The workload estimate is the product `n * k * subspace_dim`, computed with
/// saturating multiplication so enormous inputs count as "large" rather than
/// wrapping around. Returns `true` only when that estimate is strictly greater
/// than ten million.
#[must_use]
pub fn should_use_gpu(n: usize, k: usize, subspace_dim: usize) -> bool {
    const GPU_WORK_THRESHOLD: usize = 10_000_000;
    let estimated_work = [n, k, subspace_dim]
        .iter()
        .copied()
        .fold(1_usize, usize::saturating_mul);
    estimated_work > GPU_WORK_THRESHOLD
}
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn gpu_kmeans_assign(
ctx: &PqGpuContext,
sub_vectors: &[Vec<f32>],
centroids: &[Vec<f32>],
subspace_dim: usize,
) -> Option<Vec<usize>> {
if sub_vectors.is_empty() || centroids.is_empty() || subspace_dim == 0 {
return None;
}
if sub_vectors.iter().any(|v| v.len() != subspace_dim)
|| centroids.iter().any(|c| c.len() != subspace_dim)
{
return None;
}
let n = sub_vectors.len();
let k = centroids.len();
let flat_vectors = super::helpers::flatten_vecs(sub_vectors, subspace_dim);
let flat_centroids = super::helpers::flatten_vecs(centroids, subspace_dim);
let device = ctx.gpu.device();
let queue = ctx.gpu.queue();
let pipeline = ctx.gpu.kmeans_pipeline();
let buffers = create_kmeans_buffers(device, &flat_vectors, &flat_centroids, n, k, subspace_dim);
let bind_group_layout = pipeline.get_bind_group_layout(0);
let bind_group = create_kmeans_bind_group(device, &bind_group_layout, &buffers);
dispatch_and_readback(device, queue, pipeline, &bind_group, &buffers, n)
}
/// GPU buffers needed for one k-means assignment dispatch.
struct KmeansBuffers {
    /// Flattened input sub-vectors (STORAGE usage; bind-group binding 0).
    vectors: wgpu::Buffer,
    /// Flattened centroids (STORAGE usage; bind-group binding 1).
    centroids: wgpu::Buffer,
    /// One `u32` centroid index per sub-vector (STORAGE | COPY_SRC; binding 2).
    /// Presumably written by the shader, then copied into `staging`.
    assignments: wgpu::Buffer,
    /// CPU-mappable copy of `assignments` (MAP_READ | COPY_DST) used for readback.
    staging: wgpu::Buffer,
    /// `[n, k, subspace_dim, 0]` as `u32` values (UNIFORM usage; binding 3).
    params: wgpu::Buffer,
    /// Size in bytes of both `assignments` and `staging`.
    assignments_size: u64,
}
fn create_kmeans_buffers(
device: &wgpu::Device,
flat_vectors: &[f32],
flat_centroids: &[f32],
n: usize,
k: usize,
subspace_dim: usize,
) -> KmeansBuffers {
let vectors = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Vectors Buffer"),
contents: bytemuck::cast_slice(flat_vectors),
usage: wgpu::BufferUsages::STORAGE,
});
let centroids = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Centroids Buffer"),
contents: bytemuck::cast_slice(flat_centroids),
usage: wgpu::BufferUsages::STORAGE,
});
#[allow(clippy::cast_possible_truncation)]
let assignments_size = (n * std::mem::size_of::<u32>()) as u64;
let assignments = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Assignments Buffer"),
size: assignments_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Buffer"),
size: assignments_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
#[allow(clippy::cast_possible_truncation)]
let params_data = [n as u32, k as u32, subspace_dim as u32, 0_u32];
let params = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Params Buffer"),
contents: bytemuck::cast_slice(¶ms_data),
usage: wgpu::BufferUsages::UNIFORM,
});
KmeansBuffers {
vectors,
centroids,
assignments,
staging,
params,
assignments_size,
}
}
/// Creates the bind group that attaches the k-means buffers to `layout`.
///
/// Binding slots: 0 = vectors, 1 = centroids, 2 = assignments, 3 = params.
/// Every buffer is bound in its entirety.
fn create_kmeans_bind_group(
    device: &wgpu::Device,
    layout: &wgpu::BindGroupLayout,
    buffers: &KmeansBuffers,
) -> wgpu::BindGroup {
    let slots = [
        (0_u32, &buffers.vectors),
        (1, &buffers.centroids),
        (2, &buffers.assignments),
        (3, &buffers.params),
    ];
    let entries = slots.map(|(binding, buffer)| wgpu::BindGroupEntry {
        binding,
        resource: buffer.as_entire_binding(),
    });
    device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("PQ K-means Bind Group"),
        layout,
        entries: &entries,
    })
}
fn dispatch_and_readback(
device: &wgpu::Device,
queue: &wgpu::Queue,
pipeline: &wgpu::ComputePipeline,
bind_group: &wgpu::BindGroup,
buffers: &KmeansBuffers,
n: usize,
) -> Option<Vec<usize>> {
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("PQ K-means Encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("PQ K-means Pass"),
timestamp_writes: None,
});
pass.set_pipeline(pipeline);
pass.set_bind_group(0, bind_group, &[]);
#[allow(clippy::cast_possible_truncation)]
let workgroups = n.div_ceil(256) as u32;
pass.dispatch_workgroups(workgroups, 1, 1);
}
encoder.copy_buffer_to_buffer(
&buffers.assignments,
0,
&buffers.staging,
0,
buffers.assignments_size,
);
queue.submit(std::iter::once(encoder.finish()));
let assignments_u32 = super::helpers::readback_buffer::<u32>(device, &buffers.staging, n)?;
#[allow(clippy::cast_lossless)]
let assignments: Vec<usize> = assignments_u32.iter().map(|&a| a as usize).collect();
Some(assignments)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_use_gpu_threshold() {
        assert!(!should_use_gpu(100, 16, 8));
        assert!(should_use_gpu(10000, 256, 8));
        assert!(!should_use_gpu(10_000_000 / (256 * 8), 256, 8));
    }

    #[test]
    fn test_should_use_gpu_boundary_and_saturation() {
        // The threshold is strict: exactly ten million units of work stays on CPU.
        assert!(!should_use_gpu(10_000_000, 1, 1));
        assert!(should_use_gpu(10_000_001, 1, 1));
        // Saturating multiplication must not wrap around to a small value.
        assert!(should_use_gpu(usize::MAX, usize::MAX, usize::MAX));
        assert!(!should_use_gpu(0, usize::MAX, usize::MAX));
    }

    #[test]
    fn test_gpu_context_new_does_not_panic() {
        let _ctx = PqGpuContext::new();
    }

    #[test]
    fn test_gpu_kmeans_assign_matches_cpu() {
        let sub_vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.1, 0.9, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.9, 0.1],
            vec![1.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.0, 0.0],
            vec![0.0, 0.0, 0.5, 0.5],
        ];
        let centroids = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        // Reference: brute-force nearest centroid by squared Euclidean distance.
        let cpu_assignments: Vec<usize> = sub_vectors
            .iter()
            .map(|v| {
                centroids
                    .iter()
                    .enumerate()
                    .map(|(idx, c)| {
                        let dist: f32 =
                            v.iter().zip(c.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
                        (idx, dist)
                    })
                    .min_by(|a, b| a.1.total_cmp(&b.1))
                    .map(|(idx, _)| idx)
                    .unwrap()
            })
            .collect();
        if let Some(ctx) = PqGpuContext::new() {
            if let Some(gpu_assignments) = gpu_kmeans_assign(&ctx, &sub_vectors, &centroids, 4) {
                assert_eq!(
                    gpu_assignments.len(),
                    sub_vectors.len(),
                    "GPU must return one assignment per vector"
                );
                assert_eq!(
                    gpu_assignments, cpu_assignments,
                    "GPU assignments must match CPU"
                );
            }
        }
    }

    #[test]
    fn test_gpu_kmeans_assign_empty_input() {
        if let Some(ctx) = PqGpuContext::new() {
            assert!(gpu_kmeans_assign(&ctx, &[], &[vec![1.0]], 1).is_none());
            assert!(gpu_kmeans_assign(&ctx, &[vec![1.0]], &[], 1).is_none());
            assert!(gpu_kmeans_assign(&ctx, &[vec![1.0]], &[vec![1.0]], 0).is_none());
        }
    }

    #[test]
    fn test_gpu_kmeans_assign_dimension_mismatch_returns_none() {
        if let Some(ctx) = PqGpuContext::new() {
            let sub_vectors = vec![vec![1.0, 0.0, 0.0]];
            let centroids = vec![vec![1.0, 0.0, 0.0, 0.0]];
            assert!(
                gpu_kmeans_assign(&ctx, &sub_vectors, &centroids, 4).is_none(),
                "mismatched sub_vector dim must return None"
            );

            let sub_vectors = vec![vec![1.0, 0.0, 0.0, 0.0]];
            let centroids = vec![vec![1.0, 0.0]];
            assert!(
                gpu_kmeans_assign(&ctx, &sub_vectors, &centroids, 4).is_none(),
                "mismatched centroid dim must return None"
            );
        }
    }

    #[test]
    fn test_pq_context_shares_global_device() {
        let gpu_available = GpuAccelerator::is_available();
        let pq_ctx = PqGpuContext::new();
        assert_eq!(
            pq_ctx.is_some(),
            gpu_available,
            "PqGpuContext availability must match GpuAccelerator::is_available()"
        );
    }
}