mod shaders;
use std::sync::OnceLock;
use wgpu::util::DeviceExt;
use crate::simd_native;
/// Process-wide cache of the GPU availability probe.
///
/// Filled at most once, by the first call to `GpuAccelerator::is_available`,
/// with the result of `GpuAccelerator::new().is_some()`; later calls reuse the
/// stored value instead of probing again.
static GPU_AVAILABLE: OnceLock<bool> = OnceLock::new();
/// Handle to a wgpu device together with the pre-built cosine-similarity
/// compute pipeline.
///
/// Constructed by `GpuAccelerator::new`, which returns `None` when no adapter
/// or device can be obtained.
pub struct GpuAccelerator {
/// Device used to create buffers, bind groups and command encoders.
device: wgpu::Device,
/// Queue that recorded command buffers are submitted to.
queue: wgpu::Queue,
/// Compute pipeline built from the `batch_cosine` shader entry point.
cosine_pipeline: wgpu::ComputePipeline,
}
impl GpuAccelerator {
/// Creates an accelerator backed by a high-performance adapter, or returns
/// `None` when no adapter is found or the device request fails.
///
/// The adapter and device requests are driven to completion with
/// `pollster::block_on`, so this call blocks the current thread. The cosine
/// compute pipeline is compiled here once and reused by every later call to
/// `batch_cosine_similarity`.
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn new() -> Option<Self> {
    // Instance restricted to the backends chosen for this platform.
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: Self::preferred_backends(),
        ..Default::default()
    });

    // No surface is needed: the adapter is only used for compute work.
    let adapter_options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    };
    let adapter_request = instance.request_adapter(&adapter_options);
    let adapter = pollster::block_on(adapter_request)?;

    let device_descriptor = wgpu::DeviceDescriptor {
        label: Some("VelesDB GPU"),
        required_features: wgpu::Features::empty(),
        required_limits: wgpu::Limits::default(),
        memory_hints: wgpu::MemoryHints::Performance,
    };
    let device_request = adapter.request_device(&device_descriptor, None);
    // A failed device request is reported as "no GPU" rather than as an error.
    let (device, queue) = pollster::block_on(device_request).ok()?;

    let shader_source = wgpu::ShaderSource::Wgsl(shaders::COSINE_SHADER.into());
    let cosine_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Cosine Similarity Shader"),
        source: shader_source,
    });

    let quad_layout =
        super::helpers::create_quad_bind_group_layout(&device, "Cosine Bind Group Layout");
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Cosine Pipeline Layout"),
        bind_group_layouts: &[&quad_layout],
        push_constant_ranges: &[],
    });

    let cosine_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("Cosine Similarity Pipeline"),
        layout: Some(&pipeline_layout),
        module: &cosine_module,
        entry_point: Some("batch_cosine"),
        compilation_options: wgpu::PipelineCompilationOptions::default(),
        cache: None,
    });

    Some(Self {
        device,
        queue,
        cosine_pipeline,
    })
}
/// Selects the wgpu backends the instance is allowed to use.
///
/// On Linux, when neither `DISPLAY` nor `WAYLAND_DISPLAY` is set, only the
/// Vulkan backend is returned; in every other case all backends are allowed.
/// NOTE(review): presumably this keeps headless machines from probing
/// backends that need a display server — confirm against the original intent.
#[must_use]
fn preferred_backends() -> wgpu::Backends {
    #[cfg(target_os = "linux")]
    {
        // Checked in the same order as before: DISPLAY first, then Wayland.
        let headless = ["DISPLAY", "WAYLAND_DISPLAY"]
            .iter()
            .all(|name| std::env::var_os(name).is_none());
        if headless {
            return wgpu::Backends::VULKAN;
        }
    }
    wgpu::Backends::all()
}
/// Reports whether a `GpuAccelerator` could be constructed on this machine.
///
/// The first call probes by running `Self::new()` and stores the outcome in
/// `GPU_AVAILABLE`; every later call returns the stored value.
#[must_use]
pub fn is_available() -> bool {
    let probe = || Self::new().is_some();
    let available: &bool = GPU_AVAILABLE.get_or_init(probe);
    *available
}
#[allow(clippy::too_many_lines)]
pub fn batch_cosine_similarity(
&self,
vectors: &[f32],
query: &[f32],
dimension: usize,
) -> crate::error::Result<Vec<f32>> {
if dimension == 0 || vectors.is_empty() {
return Ok(Vec::new());
}
let num_vectors = vectors.len() / dimension;
if num_vectors == 0 {
return Ok(Vec::new());
}
if u32::try_from(dimension).is_err() {
return Err(crate::error::Error::GpuError(format!(
"dimension {dimension} exceeds u32::MAX"
)));
}
if u32::try_from(num_vectors).is_err() {
return Err(crate::error::Error::GpuError(format!(
"num_vectors {num_vectors} exceeds u32::MAX"
)));
}
let query_buffer = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Query Buffer"),
contents: bytemuck::cast_slice(query),
usage: wgpu::BufferUsages::STORAGE,
});
let vectors_buffer = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Vectors Buffer"),
contents: bytemuck::cast_slice(vectors),
usage: wgpu::BufferUsages::STORAGE,
});
let results_size = (num_vectors * std::mem::size_of::<f32>()) as u64;
let results_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Results Buffer"),
size: results_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Buffer"),
size: results_size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
#[allow(clippy::cast_possible_truncation)]
let params = [dimension as u32, num_vectors as u32];
let params_buffer = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Params Buffer"),
contents: bytemuck::cast_slice(¶ms),
usage: wgpu::BufferUsages::UNIFORM,
});
let bind_group_layout = self.cosine_pipeline.get_bind_group_layout(0);
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Cosine Bind Group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: query_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: vectors_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: results_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: params_buffer.as_entire_binding(),
},
],
});
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Cosine Encoder"),
});
{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Cosine Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(&self.cosine_pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
#[allow(clippy::cast_possible_truncation)]
let workgroups = num_vectors.div_ceil(256) as u32;
compute_pass.dispatch_workgroups(workgroups, 1, 1);
}
encoder.copy_buffer_to_buffer(&results_buffer, 0, &staging_buffer, 0, results_size);
self.queue.submit(std::iter::once(encoder.finish()));
super::helpers::readback_buffer::<f32>(&self.device, &staging_buffer, num_vectors)
.ok_or_else(|| {
crate::error::Error::GpuError("GPU map-async operation failed".to_string())
})
}
/// Computes the Euclidean distance between `query` and every complete
/// `dimension`-length row of `vectors`.
///
/// This method does not use the GPU device or queue: it forwards to
/// `batch_flat_simd` with `simd_native::euclidean_native` as the per-row
/// kernel. Returns one value per complete row, or an empty vector when
/// `dimension` is zero or `vectors` holds no complete row.
#[must_use]
pub fn batch_euclidean_distance(
    &self,
    vectors: &[f32],
    query: &[f32],
    dimension: usize,
) -> Vec<f32> {
    let metric: fn(&[f32], &[f32]) -> f32 = simd_native::euclidean_native;
    batch_flat_simd(vectors, query, dimension, metric)
}
/// Computes the dot product of `query` with every complete
/// `dimension`-length row of `vectors`.
///
/// This method does not use the GPU device or queue: it forwards to
/// `batch_flat_simd` with `simd_native::dot_product_native` as the per-row
/// kernel. Returns one value per complete row, or an empty vector when
/// `dimension` is zero or `vectors` holds no complete row.
#[must_use]
pub fn batch_dot_product(&self, vectors: &[f32], query: &[f32], dimension: usize) -> Vec<f32> {
    let metric: fn(&[f32], &[f32]) -> f32 = simd_native::dot_product_native;
    batch_flat_simd(vectors, query, dimension, metric)
}
}
/// Applies `distance_fn(query, row)` to each complete `dimension`-length row
/// of the flat `vectors` slice and collects the results in row order.
///
/// A trailing partial row is ignored. Returns an empty vector when
/// `dimension` is zero, `vectors` is empty, or `vectors` is shorter than one
/// row.
fn batch_flat_simd(
    vectors: &[f32],
    query: &[f32],
    dimension: usize,
    distance_fn: fn(&[f32], &[f32]) -> f32,
) -> Vec<f32> {
    // `chunks_exact` panics on a zero chunk size, so rule that out first.
    if dimension == 0 {
        return Vec::new();
    }
    // Yields exactly `vectors.len() / dimension` rows and drops any remainder.
    vectors
        .chunks_exact(dimension)
        .map(|row| distance_fn(query, row))
        .collect()
}