use std::sync::Arc;
use wgpu::util::DeviceExt;
/// WGSL source of the k-means assignment kernel (entry point `kmeans_assign`).
///
/// Bindings (group 0):
/// - binding 0: `vectors`, read-only storage, flat `f32` array indexed as
///   `vector_index * subspace_dim + d`.
/// - binding 1: `centroids`, read-only storage, flat `f32` array indexed as
///   `centroid_index * subspace_dim + d`.
/// - binding 2: `assignments`, read-write storage, one `u32` per vector.
/// - binding 3: `params`, uniform made of four `u32` values
///   (`num_vectors`, `num_centroids`, `subspace_dim`, padding).
///
/// Each invocation handles the vector at `global_invocation_id.x` and returns
/// immediately when that index is `>= num_vectors`. The workgroup size is 256,
/// so the host must dispatch `ceil(num_vectors / 256)` workgroups along x.
///
/// For its vector, an invocation accumulates the squared Euclidean distance to
/// every centroid and stores the index of the smallest one. The running minimum
/// starts at the largest finite `f32` and is replaced only on a strict `<`, so
/// on ties the lowest centroid index wins.
///
/// All offsets are computed in `u32`, so `num_vectors * subspace_dim` and
/// `num_centroids * subspace_dim` must fit in 32 bits.
const PQ_KMEANS_ASSIGN_SHADER: &str = r"
struct Params {
num_vectors: u32,
num_centroids: u32,
subspace_dim: u32,
_padding: u32,
}
@group(0) @binding(0) var<storage, read> vectors: array<f32>;
@group(0) @binding(1) var<storage, read> centroids: array<f32>;
@group(0) @binding(2) var<storage, read_write> assignments: array<u32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn kmeans_assign(@builtin(global_invocation_id) id: vec3<u32>) {
let idx = id.x;
if (idx >= params.num_vectors) { return; }
let sd = params.subspace_dim;
let k = params.num_centroids;
let vec_offset = idx * sd;
var best_dist: f32 = 3.4028235e+38;
var best_idx: u32 = 0u;
for (var c: u32 = 0u; c < k; c = c + 1u) {
let cent_offset = c * sd;
var dist: f32 = 0.0;
for (var d: u32 = 0u; d < sd; d = d + 1u) {
let diff = vectors[vec_offset + d] - centroids[cent_offset + d];
dist = dist + diff * diff;
}
if (dist < best_dist) {
best_dist = dist;
best_idx = c;
}
}
assignments[idx] = best_idx;
}
";
/// GPU state needed to run the PQ k-means assignment kernel.
///
/// Built by [`PqGpuContext::new`] and borrowed by [`gpu_kmeans_assign`] for
/// every dispatch. All handles are stored behind `Arc`.
pub struct PqGpuContext {
/// Device used to create the buffers, bind group and command encoder of each dispatch.
device: Arc<wgpu::Device>,
/// Queue the encoded compute pass and readback copy are submitted to.
queue: Arc<wgpu::Queue>,
/// Compute pipeline built from the `kmeans_assign` entry point of the assignment shader.
pipeline: Arc<wgpu::ComputePipeline>,
/// Layout the per-dispatch bind group (bindings 0 through 3) is created against.
bind_group_layout: Arc<wgpu::BindGroupLayout>,
}
impl PqGpuContext {
    /// Builds a GPU context for PQ k-means assignment.
    ///
    /// The setup runs on a freshly spawned thread that is joined before this
    /// function returns.
    ///
    /// # Returns
    /// `Some(context)` on success. `None` when the setup itself yields `None`
    /// (no adapter, or the device request failed) or when the setup thread
    /// panicked.
    #[must_use]
    pub fn new() -> Option<Self> {
        let init_thread = std::thread::spawn(Self::new_sync);
        match init_thread.join() {
            Ok(context) => context,
            // A panic during GPU setup is reported as "no GPU available".
            Err(_panic_payload) => None,
        }
    }

    /// Blocking GPU setup: instance, adapter, device/queue, shader module,
    /// bind group layout and compute pipeline.
    ///
    /// Returns `None` if no adapter is found or the device request fails.
    // One linear setup sequence; splitting it would only scatter the descriptors.
    #[allow(clippy::too_many_lines)]
    fn new_sync() -> Option<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        // Ask for a high-performance adapter without a surface or forced fallback.
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let adapter = pollster::block_on(instance.request_adapter(&adapter_options))?;

        // No optional features and default limits.
        let device_descriptor = wgpu::DeviceDescriptor {
            label: Some("VelesDB PQ K-means"),
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::default(),
            memory_hints: wgpu::MemoryHints::Performance,
        };
        let device_request = adapter.request_device(&device_descriptor, None);
        let (device, queue) = pollster::block_on(device_request).ok()?;

        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("PQ K-means Assignment Shader"),
            source: wgpu::ShaderSource::Wgsl(PQ_KMEANS_ASSIGN_SHADER.into()),
        });

        let bind_group_layout =
            super::helpers::create_quad_bind_group_layout(&device, "PQ K-means Bind Group Layout");

        let layout_descriptor = wgpu::PipelineLayoutDescriptor {
            label: Some("PQ K-means Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        };
        let pipeline_layout = device.create_pipeline_layout(&layout_descriptor);

        let pipeline_descriptor = wgpu::ComputePipelineDescriptor {
            label: Some("PQ K-means Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("kmeans_assign"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        };
        let pipeline = device.create_compute_pipeline(&pipeline_descriptor);

        let context = Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            pipeline: Arc::new(pipeline),
            bind_group_layout: Arc::new(bind_group_layout),
        };
        Some(context)
    }
}
/// Decides whether a k-means assignment step is large enough to run on the GPU.
///
/// The workload is estimated as `n * k * subspace_dim` (one multiply-accumulate
/// per vector, centroid and dimension) and compared against a fixed threshold
/// of 10,000,000.
///
/// The product is computed with saturating arithmetic: inputs whose true
/// product exceeds `usize::MAX` count as "large" instead of panicking in debug
/// builds or wrapping to a small value in release builds.
///
/// # Arguments
/// * `n` - number of vectors to assign.
/// * `k` - number of centroids.
/// * `subspace_dim` - dimensionality of each vector and centroid.
///
/// # Returns
/// `true` when the estimated workload is strictly greater than the threshold.
#[must_use]
pub fn should_use_gpu(n: usize, k: usize, subspace_dim: usize) -> bool {
    /// Number of distance terms above which the GPU path is chosen.
    const GPU_WORK_THRESHOLD: usize = 10_000_000;

    n.saturating_mul(k).saturating_mul(subspace_dim) > GPU_WORK_THRESHOLD
}
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn gpu_kmeans_assign(
ctx: &PqGpuContext,
sub_vectors: &[Vec<f32>],
centroids: &[Vec<f32>],
subspace_dim: usize,
) -> Option<Vec<usize>> {
if sub_vectors.is_empty() || centroids.is_empty() || subspace_dim == 0 {
return None;
}
if sub_vectors.iter().any(|v| v.len() != subspace_dim)
|| centroids.iter().any(|c| c.len() != subspace_dim)
{
return None;
}
let n = sub_vectors.len();
let k = centroids.len();
let flat_vectors = super::helpers::flatten_vecs(sub_vectors, subspace_dim);
let flat_centroids = super::helpers::flatten_vecs(centroids, subspace_dim);
let device = &ctx.device;
let queue = &ctx.queue;
let vectors_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Vectors Buffer"),
contents: bytemuck::cast_slice(&flat_vectors),
usage: wgpu::BufferUsages::STORAGE,
});
let centroids_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Centroids Buffer"),
contents: bytemuck::cast_slice(&flat_centroids),
usage: wgpu::BufferUsages::STORAGE,
});
let assignments_size = (n * std::mem::size_of::<u32>()) as u64;
let assignments_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Assignments Buffer"),
size: assignments_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Buffer"),
size: assignments_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
#[allow(clippy::cast_possible_truncation)]
let params = [n as u32, k as u32, subspace_dim as u32, 0_u32];
let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Params Buffer"),
contents: bytemuck::cast_slice(¶ms),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("PQ K-means Bind Group"),
layout: &ctx.bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: vectors_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: centroids_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: assignments_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: params_buffer.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("PQ K-means Encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("PQ K-means Pass"),
timestamp_writes: None,
});
pass.set_pipeline(&ctx.pipeline);
pass.set_bind_group(0, &bind_group, &[]);
#[allow(clippy::cast_possible_truncation)]
let workgroups = n.div_ceil(256) as u32;
pass.dispatch_workgroups(workgroups, 1, 1);
}
encoder.copy_buffer_to_buffer(&assignments_buffer, 0, &staging_buffer, 0, assignments_size);
queue.submit(std::iter::once(encoder.finish()));
let assignments_u32 = super::helpers::readback_buffer::<u32>(device, &staging_buffer, n)?;
let assignments: Vec<usize> = assignments_u32.iter().map(|&a| a as usize).collect();
Some(assignments)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the GPU k-means assignment path.
    //!
    //! Every GPU-dependent test is skipped silently when
    //! `PqGpuContext::new()` returns `None`.

    use super::*;

    /// The GPU heuristic is strict: workloads at or below the threshold stay on the CPU.
    #[test]
    fn test_should_use_gpu_threshold() {
        assert!(!should_use_gpu(100, 16, 8));
        assert!(should_use_gpu(10000, 256, 8));
        assert!(!should_use_gpu(10_000_000 / (256 * 8), 256, 8));
    }

    /// Context creation must return `Some` or `None`, never panic.
    #[test]
    fn test_gpu_context_new_does_not_panic() {
        let _ctx = PqGpuContext::new();
    }

    /// GPU assignments must equal a straightforward CPU nearest-centroid search,
    /// including on tied distances (both sides keep the first minimum).
    #[test]
    fn test_gpu_kmeans_assign_matches_cpu() {
        let sub_vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.1, 0.9, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.9, 0.1],
            vec![1.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
            vec![0.5, 0.5, 0.0, 0.0],
            vec![0.0, 0.0, 0.5, 0.5],
        ];
        let centroids = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];

        // CPU reference: squared Euclidean distance, first minimum wins.
        let cpu_assignments: Vec<usize> = sub_vectors
            .iter()
            .map(|v| {
                centroids
                    .iter()
                    .enumerate()
                    .map(|(idx, c)| {
                        let dist: f32 =
                            v.iter().zip(c.iter()).map(|(a, b)| (a - b) * (a - b)).sum();
                        (idx, dist)
                    })
                    .min_by(|a, b| {
                        a.1.partial_cmp(&b.1)
                            .expect("distances of finite inputs are never NaN")
                    })
                    .map(|(idx, _)| idx)
                    .expect("centroid list is non-empty")
            })
            .collect();

        if let Some(ctx) = PqGpuContext::new() {
            if let Some(gpu_assignments) = gpu_kmeans_assign(&ctx, &sub_vectors, &centroids, 4) {
                assert_eq!(
                    gpu_assignments.len(),
                    sub_vectors.len(),
                    "GPU must return one assignment per vector"
                );
                assert_eq!(
                    gpu_assignments, cpu_assignments,
                    "GPU assignments must match CPU"
                );
            }
        }
    }

    /// Empty vectors, empty centroids and a zero dimension are all rejected.
    #[test]
    fn test_gpu_kmeans_assign_empty_input() {
        if let Some(ctx) = PqGpuContext::new() {
            assert!(gpu_kmeans_assign(&ctx, &[], &[vec![1.0]], 1).is_none());
            assert!(gpu_kmeans_assign(&ctx, &[vec![1.0]], &[], 1).is_none());
            assert!(gpu_kmeans_assign(&ctx, &[vec![1.0]], &[vec![1.0]], 0).is_none());
        }
    }

    /// A sub-vector whose length differs from `subspace_dim` is rejected.
    #[test]
    fn test_gpu_kmeans_assign_dimension_mismatch_returns_none() {
        if let Some(ctx) = PqGpuContext::new() {
            let sub_vectors = vec![vec![1.0, 0.0, 0.0]];
            let centroids = vec![vec![1.0, 0.0, 0.0, 0.0]];
            assert!(
                gpu_kmeans_assign(&ctx, &sub_vectors, &centroids, 4).is_none(),
                "mismatched sub_vector dim must return None"
            );
        }
    }
}