use super::context::CudaContext;
use crate::error::{HiveGpuError, Result};
use crate::traits::GpuVectorStorage;
use crate::types::{GpuDistanceMetric, GpuSearchResult, GpuVector};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::{debug, info};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::cublas::{Gemv, GemvConfig, sys as cublas_sys};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::driver::{CudaSlice, DevicePtr, DevicePtrMut, result as cuda_result};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
const MIN_INITIAL_VECTORS: usize = 1024;
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
const MIN_INITIAL_BYTES: usize = 1024 * 1024;
/// GPU-resident vector store: all vector data lives in one contiguous CUDA
/// buffer (`storage`), laid out as rows of `dimension` f32 components each.
/// Removal is logical — indices are tombstoned in `removed_indices` and the
/// device slots are neither compacted nor reused.
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
pub struct CudaVectorStorage {
    /// Shared handle to the CUDA device / cuBLAS context.
    context: Arc<CudaContext>,
    /// Device buffer holding `buffer_capacity * dimension` f32 slots.
    storage: CudaSlice<f32>,
    /// Number of vectors the device buffer can hold before regrowing.
    buffer_capacity: usize,
    /// Number of slots ever written — includes logically removed vectors
    /// (the trait method `vector_count()` subtracts tombstones).
    vector_count: usize,
    /// Components per vector; fixed at construction.
    dimension: usize,
    /// Distance metric applied when scoring searches.
    metric: GpuDistanceMetric,
    /// Vector id -> slot index in the device buffer.
    vector_id_map: HashMap<String, usize>,
    /// Reverse mapping: slot index -> vector id.
    index_to_id: Vec<String>,
    /// Slot indices that have been logically removed (tombstones).
    removed_indices: HashSet<usize>,
    /// Per-id metadata payloads (host-side only).
    payloads: HashMap<String, HashMap<String, String>>,
    /// Cached squared L2 norm per slot; used by Cosine/Euclidean scoring.
    norms_sq: Vec<f32>,
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl std::fmt::Debug for CudaVectorStorage {
    /// Manual impl that prints only the cheap bookkeeping fields; the device
    /// buffer, id maps, and payloads are deliberately omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CudaVectorStorage")
            .field("vector_count", &self.vector_count)
            .field("buffer_capacity", &self.buffer_capacity)
            .field("dimension", &self.dimension)
            .field("metric", &self.metric)
            .field("removed", &self.removed_indices.len())
            .finish()
    }
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl CudaVectorStorage {
    /// Creates an empty store for `dimension`-component vectors, pre-allocating
    /// a zeroed device buffer sized by the larger of `MIN_INITIAL_VECTORS` and
    /// however many vectors fit in `MIN_INITIAL_BYTES`.
    ///
    /// # Errors
    /// * `InvalidConfiguration` — `dimension == 0`, or the slot count overflows.
    /// * `CudaError` — the device allocation fails.
    pub fn new(
        context: Arc<CudaContext>,
        dimension: usize,
        metric: GpuDistanceMetric,
    ) -> Result<Self> {
        if dimension == 0 {
            return Err(HiveGpuError::InvalidConfiguration(
                "dimension must be > 0".to_string(),
            ));
        }
        // Guarantee the initial buffer holds at least MIN_INITIAL_BYTES of data.
        let min_vectors_by_bytes =
            (MIN_INITIAL_BYTES / (dimension * std::mem::size_of::<f32>())).max(1);
        let capacity = MIN_INITIAL_VECTORS.max(min_vectors_by_bytes);
        let slots = capacity
            .checked_mul(dimension)
            .ok_or_else(|| HiveGpuError::InvalidConfiguration("capacity overflow".to_string()))?;
        let storage = context
            .device()
            .alloc_zeros::<f32>(slots)
            .map_err(|e| HiveGpuError::CudaError(format!("alloc_zeros({slots}): {e:?}")))?;
        debug!(
            "cuda storage created: dim={} capacity={} bytes={}",
            dimension,
            capacity,
            slots * std::mem::size_of::<f32>()
        );
        Ok(Self {
            context,
            storage,
            buffer_capacity: capacity,
            vector_count: 0,
            dimension,
            metric,
            vector_id_map: HashMap::new(),
            index_to_id: Vec::new(),
            removed_indices: HashSet::new(),
            payloads: HashMap::new(),
            norms_sq: Vec::new(),
        })
    }

    /// Checks a candidate vector against the store's invariants: exact
    /// dimension, non-empty id of at most 256 bytes, id not already present,
    /// and every component finite.
    fn validate_vector(&self, vector: &GpuVector) -> Result<()> {
        if vector.data.len() != self.dimension {
            return Err(HiveGpuError::DimensionMismatch {
                expected: self.dimension,
                actual: vector.data.len(),
            });
        }
        if vector.id.is_empty() {
            return Err(HiveGpuError::InvalidConfiguration(
                "vector id must be non-empty".to_string(),
            ));
        }
        // str::len() measures bytes, not chars — the message reflects that.
        if vector.id.len() > 256 {
            return Err(HiveGpuError::InvalidConfiguration(
                "vector id must be <= 256 bytes".to_string(),
            ));
        }
        if self.vector_id_map.contains_key(&vector.id) {
            return Err(HiveGpuError::InvalidConfiguration(format!(
                "duplicate vector id: {}",
                vector.id
            )));
        }
        for (i, &v) in vector.data.iter().enumerate() {
            if !v.is_finite() {
                return Err(HiveGpuError::InvalidConfiguration(format!(
                    "non-finite component at index {i} in vector {}",
                    vector.id
                )));
            }
        }
        Ok(())
    }

    /// Grows the device buffer so it can hold `vector_count + additional`
    /// vectors, copying the live region into the new allocation.
    ///
    /// The growth factor tapers as the buffer grows (2x, then 1.5x, then
    /// 1.2x) to bound over-allocation for large collections. Growth uses
    /// pure integer arithmetic: the previous `usize -> f32` cast silently
    /// lost precision above 2^24 vectors.
    fn ensure_capacity(&mut self, additional: usize) -> Result<()> {
        let required = self
            .vector_count
            .checked_add(additional)
            .ok_or_else(|| HiveGpuError::InvalidConfiguration("capacity overflow".to_string()))?;
        if required <= self.buffer_capacity {
            return Ok(());
        }
        let mut new_capacity = self.buffer_capacity;
        while new_capacity < required {
            // Integer equivalents of the 2.0 / 1.5 / 1.2 factors; saturating
            // arithmetic keeps the loop terminating even near usize::MAX.
            let grown = if new_capacity < 1_000 {
                new_capacity.saturating_mul(2)
            } else if new_capacity < 10_000 {
                new_capacity.saturating_add(new_capacity / 2)
            } else {
                new_capacity.saturating_add(new_capacity / 5)
            };
            new_capacity = grown.max(required);
        }
        let slots = new_capacity
            .checked_mul(self.dimension)
            .ok_or_else(|| HiveGpuError::InvalidConfiguration("slots overflow".to_string()))?;
        let mut new_buffer = self
            .context
            .device()
            .alloc_zeros::<f32>(slots)
            .map_err(|e| HiveGpuError::CudaError(format!("alloc_zeros({slots}): {e:?}")))?;
        if self.vector_count > 0 {
            let live_bytes = self.vector_count * self.dimension * std::mem::size_of::<f32>();
            // SAFETY: both pointers are valid device allocations and the new
            // buffer holds at least `live_bytes` (new_capacity >= vector_count).
            unsafe {
                cuda_result::memcpy_dtod_sync(
                    *new_buffer.device_ptr_mut(),
                    *self.storage.device_ptr(),
                    live_bytes,
                )
            }
            .map_err(|e| HiveGpuError::CudaError(format!("memcpy_dtod_sync: {e:?}")))?;
        }
        info!(
            "cuda storage expand: {} -> {} vectors ({:.2} MiB)",
            self.buffer_capacity,
            new_capacity,
            (slots * std::mem::size_of::<f32>()) as f64 / (1024.0 * 1024.0)
        );
        self.storage = new_buffer;
        self.buffer_capacity = new_capacity;
        Ok(())
    }

    /// Computes the raw dot product of `query` against every stored vector
    /// (including tombstoned slots) with a single cuBLAS SGEMV call.
    ///
    /// The buffer is treated as a column-major `dimension x vector_count`
    /// matrix (each stored vector is one column, `lda = dimension`), so the
    /// transposed product `A^T * q` yields one dot product per vector.
    ///
    /// # Errors
    /// `DimensionMismatch` for a wrong-length query, `InvalidConfiguration`
    /// for non-finite components, `CudaError`/`CublasError` on device failure.
    fn gpu_scores(&self, query: &[f32]) -> Result<Vec<f32>> {
        if self.vector_count == 0 {
            return Ok(Vec::new());
        }
        if query.len() != self.dimension {
            return Err(HiveGpuError::DimensionMismatch {
                expected: self.dimension,
                actual: query.len(),
            });
        }
        for (i, &v) in query.iter().enumerate() {
            if !v.is_finite() {
                return Err(HiveGpuError::InvalidConfiguration(format!(
                    "non-finite query component at index {i}"
                )));
            }
        }
        let device = self.context.device();
        let query_dev = device
            .htod_copy(query.to_vec())
            .map_err(|e| HiveGpuError::CudaError(format!("htod_copy query: {e:?}")))?;
        let mut scores_dev = device
            .alloc_zeros::<f32>(self.vector_count)
            .map_err(|e| HiveGpuError::CudaError(format!("alloc_zeros scores: {e:?}")))?;
        let cfg = GemvConfig::<f32> {
            trans: cublas_sys::cublasOperation_t::CUBLAS_OP_T,
            m: self.dimension as i32,
            n: self.vector_count as i32,
            alpha: 1.0,
            lda: self.dimension as i32,
            incx: 1,
            beta: 0.0,
            incy: 1,
        };
        // SAFETY: buffer extents match the GEMV config — `storage` holds at
        // least m*n floats, `query_dev` holds m, `scores_dev` holds n.
        unsafe {
            self.context
                .blas()
                .gemv(cfg, &self.storage, &query_dev, &mut scores_dev)
        }
        .map_err(|e| HiveGpuError::CublasError(format!("sgemv: {e:?}")))?;
        let scores = device
            .dtoh_sync_copy(&scores_dev)
            .map_err(|e| HiveGpuError::CudaError(format!("dtoh_sync_copy scores: {e:?}")))?;
        Ok(scores)
    }

    /// Converts raw dot products in `raw_scores` into metric scores in place:
    /// * `DotProduct` — left unchanged.
    /// * `Cosine` — divided by both norms (0.0 when either norm is zero),
    ///   using the cached `norms_sq` for the stored side.
    /// * `Euclidean` — squared distance ||v||^2 - 2*v.q + ||q||^2, clamped
    ///   at 0 to absorb floating-point cancellation.
    fn apply_metric(&self, raw_scores: &mut [f32], query: &[f32]) {
        let query_norm_sq = dot_self(query);
        match self.metric {
            GpuDistanceMetric::DotProduct => {}
            GpuDistanceMetric::Cosine => {
                let q_norm = query_norm_sq.sqrt();
                for (i, s) in raw_scores.iter_mut().enumerate() {
                    let v_norm = self.norms_sq[i].sqrt();
                    let denom = q_norm * v_norm;
                    *s = if denom > 0.0 { *s / denom } else { 0.0 };
                }
            }
            GpuDistanceMetric::Euclidean => {
                for (i, s) in raw_scores.iter_mut().enumerate() {
                    *s = (self.norms_sq[i] - 2.0 * *s + query_norm_sq).max(0.0);
                }
            }
        }
    }

    /// Drops tombstoned slots, orders by score (ascending distance for
    /// Euclidean, descending similarity otherwise), keeps the best `limit`,
    /// and maps slots back to their ids. Euclidean squared distances are
    /// converted to a (0, 1] similarity via 1 / (1 + sqrt(d)).
    fn select_top_k(&self, mut scored: Vec<(usize, f32)>, limit: usize) -> Vec<GpuSearchResult> {
        scored.retain(|(idx, _)| !self.removed_indices.contains(idx));
        match self.metric {
            GpuDistanceMetric::Euclidean => {
                scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
            }
            _ => {
                scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            }
        }
        scored.truncate(limit);
        scored
            .into_iter()
            .map(|(index, score)| {
                let id = self.index_to_id[index].clone();
                let similarity = match self.metric {
                    GpuDistanceMetric::Euclidean => 1.0 / (1.0 + score.sqrt()),
                    _ => score,
                };
                GpuSearchResult {
                    id,
                    score: similarity,
                    index,
                }
            })
            .collect()
    }
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl GpuVectorStorage for CudaVectorStorage {
    /// Appends a batch of vectors and returns their slot indices.
    ///
    /// The entire batch is validated (dimension, id rules, duplicates against
    /// both the store and the batch itself) before any state changes, so an
    /// error leaves the logical contents unchanged. Data is staged on the
    /// device once and copied device-to-device into the live region.
    fn add_vectors(&mut self, vectors: &[GpuVector]) -> Result<Vec<usize>> {
        if vectors.is_empty() {
            return Ok(Vec::new());
        }
        let mut seen = HashSet::with_capacity(vectors.len());
        for v in vectors {
            self.validate_vector(v)?;
            if !seen.insert(v.id.as_str()) {
                return Err(HiveGpuError::InvalidConfiguration(format!(
                    "duplicate vector id within batch: {}",
                    v.id
                )));
            }
        }
        self.ensure_capacity(vectors.len())?;
        // Flatten the batch so it can be uploaded with one host-to-device copy.
        let mut flat = Vec::with_capacity(vectors.len() * self.dimension);
        for v in vectors {
            flat.extend_from_slice(&v.data);
        }
        let staging = self
            .context
            .device()
            .htod_copy(flat)
            .map_err(|e| HiveGpuError::CudaError(format!("htod_copy batch: {e:?}")))?;
        let bytes = vectors.len() * self.dimension * std::mem::size_of::<f32>();
        let offset_bytes = (self.vector_count * self.dimension * std::mem::size_of::<f32>()) as u64;
        // SAFETY: ensure_capacity guarantees the destination range
        // [offset_bytes, offset_bytes + bytes) lies inside `storage`;
        // `staging` holds exactly `bytes` of freshly uploaded data.
        unsafe {
            let dst = *self.storage.device_ptr() + offset_bytes;
            cuda_result::memcpy_dtod_sync(dst, *staging.device_ptr(), bytes)
        }
        .map_err(|e| HiveGpuError::CudaError(format!("memcpy_dtod_sync batch: {e:?}")))?;
        // Device copy succeeded — commit the host-side bookkeeping.
        let mut indices = Vec::with_capacity(vectors.len());
        for v in vectors {
            let index = self.vector_count;
            self.vector_id_map.insert(v.id.clone(), index);
            self.index_to_id.push(v.id.clone());
            self.payloads.insert(v.id.clone(), v.metadata.clone());
            self.norms_sq.push(dot_self(&v.data));
            self.vector_count += 1;
            indices.push(index);
        }
        Ok(indices)
    }

    /// Scores `query` against every stored vector on the GPU and returns up
    /// to `limit` results, best first (tombstoned vectors are excluded).
    fn search(&self, query: &[f32], limit: usize) -> Result<Vec<GpuSearchResult>> {
        if limit == 0 || self.vector_count == 0 {
            return Ok(Vec::new());
        }
        let mut scores = self.gpu_scores(query)?;
        self.apply_metric(&mut scores, query);
        let scored: Vec<(usize, f32)> = scores.into_iter().enumerate().collect();
        Ok(self.select_top_k(scored, limit))
    }

    /// Logically removes the given ids by tombstoning their slot indices.
    ///
    /// All ids are resolved before any removal is applied, so an unknown id
    /// fails with `VectorNotFound` and leaves the store untouched (previously
    /// a mid-batch failure left earlier ids removed).
    fn remove_vectors(&mut self, ids: &[String]) -> Result<()> {
        let mut indices = Vec::with_capacity(ids.len());
        for id in ids {
            match self.vector_id_map.get(id) {
                Some(&index) => indices.push(index),
                None => return Err(HiveGpuError::VectorNotFound(id.clone())),
            }
        }
        for (id, index) in ids.iter().zip(indices) {
            self.removed_indices.insert(index);
            self.payloads.remove(id);
        }
        Ok(())
    }

    /// Number of live (non-tombstoned) vectors.
    fn vector_count(&self) -> usize {
        self.vector_count.saturating_sub(self.removed_indices.len())
    }

    /// Components per stored vector.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Copies a single vector back from the device by id; returns `Ok(None)`
    /// for unknown or tombstoned ids.
    fn get_vector(&self, id: &str) -> Result<Option<GpuVector>> {
        let Some(&index) = self.vector_id_map.get(id) else {
            return Ok(None);
        };
        if self.removed_indices.contains(&index) {
            return Ok(None);
        }
        let offset = index * self.dimension;
        // SAFETY: `index < vector_count`, so the source range lies inside the
        // live region of the device buffer; `dst` is exactly `dimension` floats.
        let data = unsafe {
            let src = *self.storage.device_ptr() + (offset * std::mem::size_of::<f32>()) as u64;
            let mut dst = vec![0f32; self.dimension];
            cuda_result::memcpy_dtoh_sync(&mut dst, src)
                .map_err(|e| HiveGpuError::CudaError(format!("memcpy_dtoh_sync: {e:?}")))?;
            dst
        };
        let metadata = self.payloads.get(id).cloned().unwrap_or_default();
        Ok(Some(GpuVector {
            id: id.to_string(),
            data,
            metadata,
        }))
    }

    /// Logically empties the store. The device allocation is kept for reuse;
    /// stale device data is simply overwritten by subsequent adds.
    fn clear(&mut self) -> Result<()> {
        self.vector_count = 0;
        self.buffer_capacity = self.buffer_capacity.max(MIN_INITIAL_VECTORS);
        self.vector_id_map.clear();
        self.index_to_id.clear();
        self.removed_indices.clear();
        self.payloads.clear();
        self.norms_sq.clear();
        Ok(())
    }
}
/// Dot product of `v` with itself, i.e. the squared L2 norm.
#[inline]
fn dot_self(v: &[f32]) -> f32 {
    v.iter().fold(0.0_f32, |acc, &x| acc + x * x)
}