use crate::error::{HiveGpuError, Result};
use crate::traits::{GpuBackend, GpuContext};
use crate::types::{GpuCapabilities, GpuDeviceInfo, GpuMemoryStats};
use std::sync::Arc;
use tracing::debug;
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::cublas::CudaBlas;
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use cudarc::driver::{CudaDevice, result as cuda_result, sys as cuda_sys};
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
#[derive(Debug, Clone)]
/// Handle to one CUDA device together with a cuBLAS handle created on it.
///
/// Both fields are reference-counted, so `Clone` is cheap and all clones
/// share the same underlying device and cuBLAS state.
pub struct CudaContext {
    /// cudarc device handle for the selected ordinal.
    device: Arc<CudaDevice>,
    /// cuBLAS handle created on `device`; wrapped in `Arc` so the
    /// context stays cheaply `Clone`.
    blas: Arc<CudaBlas>,
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl CudaContext {
    /// Creates a context on the default CUDA device (ordinal 0).
    ///
    /// # Errors
    /// See [`Self::new_with_device`].
    pub fn new() -> Result<Self> {
        Self::new_with_device(0)
    }

    /// Creates a context on the CUDA device at `ordinal` and initializes a
    /// cuBLAS handle bound to that device.
    ///
    /// # Errors
    /// Returns `HiveGpuError::CudaError` if the device cannot be opened and
    /// `HiveGpuError::CublasError` if cuBLAS handle creation fails.
    pub fn new_with_device(ordinal: usize) -> Result<Self> {
        let device = CudaDevice::new(ordinal)
            .map_err(|e| HiveGpuError::CudaError(format!("CudaDevice::new({ordinal}): {e:?}")))?;
        let blas = CudaBlas::new(device.clone())
            .map_err(|e| HiveGpuError::CublasError(format!("CudaBlas::new: {e:?}")))?;
        debug!(
            "cuda context ready: ordinal={} name={:?}",
            ordinal,
            device.name().ok()
        );
        Ok(Self {
            device,
            blas: Arc::new(blas),
        })
    }

    /// Returns the number of CUDA devices visible to the driver.
    ///
    /// Initializes the driver first (`cuInit`), since `cuDeviceGetCount`
    /// requires an initialized driver. A negative count (defensive; not
    /// expected from the driver) is clamped to 0 before the cast.
    ///
    /// # Errors
    /// Returns `HiveGpuError::CudaError` if either driver call fails.
    pub fn device_count() -> Result<usize> {
        cuda_result::init().map_err(|e| HiveGpuError::CudaError(format!("cuInit: {e:?}")))?;
        let count = cuda_result::device::get_count()
            .map_err(|e| HiveGpuError::CudaError(format!("cuDeviceGetCount: {e:?}")))?;
        Ok(count.max(0) as usize)
    }

    /// Best-effort probe: true when the driver initializes and reports at
    /// least one device. Wrapped in `catch_unwind` so a panic inside the
    /// driver bindings (e.g. a missing CUDA library) reads as "unavailable"
    /// instead of crashing the caller.
    pub fn is_available() -> bool {
        std::panic::catch_unwind(|| {
            matches!(cuda_result::init(), Ok(()))
                && cuda_result::device::get_count()
                    .map(|c| c > 0)
                    .unwrap_or(false)
        })
        .unwrap_or(false)
    }

    /// The underlying cudarc device handle.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }

    /// The cuBLAS handle created on this context's device.
    pub fn blas(&self) -> &CudaBlas {
        &self.blas
    }

    /// Device ordinal as a `u32`.
    pub fn device_id(&self) -> u32 {
        self.device.ordinal() as u32
    }

    /// Human-readable device name, or `"unknown"` if the query fails.
    pub fn device_name(&self) -> String {
        self.device.name().unwrap_or_else(|_| "unknown".to_string())
    }

    /// Compute capability as `(major, minor)`; either component falls back
    /// to 0 if the attribute query fails.
    pub fn compute_capability(&self) -> (u32, u32) {
        let major = self
            .device
            .attribute(cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .unwrap_or(0) as u32;
        let minor = self
            .device
            .attribute(cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .unwrap_or(0) as u32;
        (major, minor)
    }

    /// Total device memory in bytes; 0 if the query fails.
    pub fn total_memory(&self) -> u64 {
        Self::mem_info(&self.device)
            .map(|(_, total)| total)
            .unwrap_or(0)
    }

    /// Currently free device memory in bytes; 0 if the query fails.
    pub fn available_memory(&self) -> u64 {
        Self::mem_info(&self.device)
            .map(|(free, _)| free)
            .unwrap_or(0)
    }

    /// True when the device's compute capability major version is >= 7
    /// (the minimum this crate requires; minor version is ignored).
    pub fn supports_required_features(&self) -> bool {
        self.compute_capability().0 >= 7
    }

    /// Queries `(free, total)` device memory in bytes.
    ///
    /// `cuMemGetInfo` reports on the context current on the calling thread,
    /// so the device is bound to this thread first.
    ///
    /// # Errors
    /// Returns `HiveGpuError::CudaError` if binding or the memory query fails.
    fn mem_info(device: &Arc<CudaDevice>) -> Result<(u64, u64)> {
        device
            .bind_to_thread()
            .map_err(|e| HiveGpuError::CudaError(format!("bind_to_thread: {e:?}")))?;
        let (free, total) = cuda_result::mem_get_info()
            .map_err(|e| HiveGpuError::CudaError(format!("cuMemGetInfo: {e:?}")))?;
        Ok((free as u64, total as u64))
    }

    /// Formats the device's PCI address as `domain:bus:device.function`
    /// (function fixed at 0). Returns `None` when the bus or device
    /// attribute query fails; a missing domain defaults to 0.
    fn pci_bus_id_string(&self) -> Option<String> {
        let bus = self
            .device
            .attribute(cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID)
            .ok()?;
        let dev = self
            .device
            .attribute(cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID)
            .ok()?;
        let domain = self
            .device
            .attribute(cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID)
            .unwrap_or(0);
        Some(format!("{domain:04x}:{bus:02x}:{dev:02x}.0"))
    }

    /// Driver version string such as `"CUDA 12.4"`, or `"CUDA unknown"` if
    /// the raw driver call fails. The encoded version is split as
    /// major = v / 1000, minor = (v % 1000) / 10 (CUDA's documented scheme).
    fn driver_version_string() -> String {
        let mut version: i32 = 0;
        // SAFETY: cuDriverGetVersion only writes through the provided
        // pointer; `version` is a valid, live i32 for the whole call.
        let status = unsafe { cuda_sys::lib().cuDriverGetVersion(&mut version) };
        if status == cuda_sys::CUresult::CUDA_SUCCESS {
            format!("CUDA {}.{}", version / 1000, (version % 1000) / 10)
        } else {
            "CUDA unknown".to_string()
        }
    }
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl GpuBackend for CudaContext {
    /// Snapshot of static device properties plus current VRAM usage.
    ///
    /// Infallible by design: any failed memory or attribute query falls
    /// back to zero rather than surfacing an error.
    fn device_info(&self) -> GpuDeviceInfo {
        let (free_bytes, total_bytes) = Self::mem_info(&self.device).unwrap_or((0, 0));
        let (cc_major, cc_minor) = self.compute_capability();

        let threads_per_block = self
            .device
            .attribute(cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
            .unwrap_or(0);
        let shared_per_block = self
            .device
            .attribute(
                cuda_sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
            )
            .unwrap_or(0);

        GpuDeviceInfo {
            name: self.device_name(),
            backend: "CUDA".to_string(),
            total_vram_bytes: total_bytes,
            available_vram_bytes: free_bytes,
            used_vram_bytes: total_bytes.saturating_sub(free_bytes),
            driver_version: Self::driver_version_string(),
            compute_capability: Some(format!("{cc_major}.{cc_minor}")),
            max_threads_per_block: threads_per_block as u32,
            max_shared_memory_per_block: shared_per_block as u64,
            device_id: self.device.ordinal() as i32,
            pci_bus_id: self.pci_bus_id_string(),
        }
    }

    /// Static capability flags for the CUDA backend: batch operations are
    /// supported, HNSW indexing is not.
    fn supports_operations(&self) -> GpuCapabilities {
        GpuCapabilities {
            supports_hnsw: false,
            supports_batch: true,
            max_dimension: 4096,
            max_batch_size: 100_000,
        }
    }

    /// Current VRAM usage. Failed queries read as an empty device
    /// (all-zero stats); utilization is used/total, or 0.0 when total
    /// is unknown to avoid dividing by zero.
    fn memory_stats(&self) -> GpuMemoryStats {
        let (free_bytes, total_bytes) = Self::mem_info(&self.device).unwrap_or((0, 0));
        let used_bytes = total_bytes.saturating_sub(free_bytes);
        let utilization = match total_bytes {
            0 => 0.0,
            t => used_bytes as f32 / t as f32,
        };
        GpuMemoryStats {
            total_allocated: used_bytes as usize,
            available: free_bytes as usize,
            utilization,
            // This context does not track individual buffer allocations.
            buffer_count: 0,
        }
    }
}
#[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
impl GpuContext for CudaContext {
    /// Builds a flat CUDA vector store backed by a shared clone of this
    /// context (cheap: both fields are `Arc`s).
    fn create_storage(
        &self,
        dimension: usize,
        metric: crate::types::GpuDistanceMetric,
    ) -> Result<Box<dyn crate::traits::GpuVectorStorage>> {
        let shared = Arc::new(self.clone());
        crate::cuda::vector_storage::CudaVectorStorage::new(shared, dimension, metric)
            .map(|storage| Box::new(storage) as Box<dyn crate::traits::GpuVectorStorage>)
    }

    /// The CUDA backend has no HNSW index, so the config is ignored and
    /// this simply delegates to `create_storage`.
    fn create_storage_with_config(
        &self,
        dimension: usize,
        metric: crate::types::GpuDistanceMetric,
        _config: crate::types::HnswConfig,
    ) -> Result<Box<dyn crate::traits::GpuVectorStorage>> {
        GpuContext::create_storage(self, dimension, metric)
    }

    /// Delegates to the `GpuBackend` implementation (UFCS disambiguates
    /// between the two traits' identically-named methods).
    fn memory_stats(&self) -> GpuMemoryStats {
        <Self as GpuBackend>::memory_stats(self)
    }

    /// Wraps the infallible `GpuBackend::device_info` in `Ok` to satisfy
    /// this trait's fallible signature.
    fn device_info(&self) -> Result<GpuDeviceInfo> {
        Ok(<Self as GpuBackend>::device_info(self))
    }
}