/// Smallest vector count for which `HardwareProfile::recommend_ivf_pq` can
/// return `true`; any dataset below this size is always rejected.
const MIN_VECTORS_FOR_IVF_PQ: usize = 5_000;
/// CPU-only threshold used by `HardwareProfile::recommend_ivf_pq`: when no GPU
/// flag is set, the logical core count must be strictly greater than this value.
const MIN_CORES_FOR_IVF_PQ: usize = 8;
/// Compute backend selected for the current process by `detect_backend`.
///
/// Exactly one variant is chosen per process; the GPU variants are only
/// produced when the corresponding driver probe succeeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HardwareBackend {
    /// Fallback when neither GPU driver probe succeeded.
    CpuSimd,
    /// The CUDA driver probe succeeded (and the ROCm probe, tried first, did not).
    NvidiaCuda,
    /// The ROCm/HIP driver probe succeeded.
    AmdRocm,
}
/// Snapshot of the hardware capabilities relevant to index selection.
///
/// Build one with `HardwareProfile::detect`, or construct it field-by-field
/// (all fields are public) to describe hypothetical hardware.
///
/// Derives `Debug` and `Clone` so profiles can be logged, shown in assertion
/// failures, duplicated, and embedded in other `#[derive(Debug)]` types.
#[derive(Debug, Clone)]
pub struct HardwareProfile {
    /// Backend selected for this process.
    pub backend: HardwareBackend,
    /// `true` when `backend` is `HardwareBackend::NvidiaCuda` (as set by `detect`).
    pub has_cuda: bool,
    /// `true` when `backend` is `HardwareBackend::AmdRocm` (as set by `detect`).
    pub has_rocm: bool,
    /// Parallelism available to the CPU path; `detect` fills this from
    /// `rayon::current_num_threads()`.
    pub cpu_logical_cores: usize,
    /// Whether AVX2 was detected at runtime (always `false` off x86_64).
    pub has_avx2: bool,
    /// Whether AVX-512F was detected at runtime (always `false` off x86_64).
    pub has_avx512: bool,
}
impl HardwareProfile {
    /// Probes the current machine and returns its profile.
    ///
    /// The GPU flags are derived from the single backend reported by
    /// `detect_backend`, so at most one of `has_cuda` / `has_rocm` is set.
    pub fn detect() -> Self {
        let backend = detect_backend();
        let has_cuda = matches!(backend, HardwareBackend::NvidiaCuda);
        let has_rocm = matches!(backend, HardwareBackend::AmdRocm);
        let cpu_logical_cores = rayon::current_num_threads();
        let has_avx2 = detect_avx2();
        let has_avx512 = detect_avx512();

        Self {
            backend,
            has_cuda,
            has_rocm,
            cpu_logical_cores,
            has_avx2,
            has_avx512,
        }
    }

    /// Decides whether IVF-PQ should be used for a dataset of `n_vectors`.
    ///
    /// Returns `true` only when the dataset holds at least
    /// `MIN_VECTORS_FOR_IVF_PQ` vectors **and** the hardware is strong enough:
    /// either GPU flag is set, or the core count is strictly greater than
    /// `MIN_CORES_FOR_IVF_PQ`. The AVX flags do not take part in the decision.
    pub fn recommend_ivf_pq(&self, n_vectors: usize) -> bool {
        let dataset_large_enough = n_vectors >= MIN_VECTORS_FOR_IVF_PQ;
        let gpu_available = self.has_cuda || self.has_rocm;
        // Strictly greater: a machine with exactly the threshold does not qualify.
        let plenty_of_cores = self.cpu_logical_cores > MIN_CORES_FOR_IVF_PQ;

        dataset_large_enough && (gpu_available || plenty_of_cores)
    }
}
use std::sync::OnceLock;
/// Process-wide cache holding the backend chosen by `detect_backend`.
static BACKEND: OnceLock<HardwareBackend> = OnceLock::new();

/// Returns the compute backend for this process.
///
/// The driver probes run on the first call only; the result is stored in
/// `BACKEND` and returned unchanged by every later call. ROCm is probed before
/// CUDA, and the CUDA probe is skipped when the ROCm probe succeeds. When
/// neither probe succeeds the CPU backend is reported.
pub fn detect_backend() -> HardwareBackend {
    /// Runs the driver probes in priority order.
    fn select_backend() -> HardwareBackend {
        if probe_rocm_driver() {
            return HardwareBackend::AmdRocm;
        }
        if probe_cuda_driver() {
            return HardwareBackend::NvidiaCuda;
        }
        HardwareBackend::CpuSimd
    }

    *BACKEND.get_or_init(select_backend)
}
/// Reports whether the backend selected by `detect_backend` is CUDA.
///
/// Because `detect_backend` yields a single backend, this is `false` whenever
/// the ROCm backend was selected instead.
pub fn detect_cuda() -> bool {
    matches!(detect_backend(), HardwareBackend::NvidiaCuda)
}
/// Reports whether the backend selected by `detect_backend` is ROCm.
pub fn detect_rocm() -> bool {
    matches!(detect_backend(), HardwareBackend::AmdRocm)
}
/// File name of the CUDA driver library that `probe_cuda_driver` loads on Linux.
#[cfg(target_os = "linux")]
const CUDA_DRIVER_LIB: &str = "libcuda.so.1";
/// File name of the CUDA driver library that `probe_cuda_driver` loads on Windows.
#[cfg(windows)]
const CUDA_DRIVER_LIB: &str = "nvcuda.dll";
/// No CUDA driver library is configured for other platforms; the empty string
/// makes `probe_cuda_driver` return `false` without loading anything.
#[cfg(not(any(target_os = "linux", windows)))]
const CUDA_DRIVER_LIB: &str = "";
/// File name of the HIP runtime library that `probe_rocm_driver` loads on Linux.
#[cfg(target_os = "linux")]
const ROCM_DRIVER_LIB: &str = "libamdhip64.so";
/// File name of the HIP runtime library that `probe_rocm_driver` loads on Windows.
#[cfg(windows)]
const ROCM_DRIVER_LIB: &str = "amdhip64.dll";
/// No HIP runtime library is configured for other platforms; the empty string
/// makes `probe_rocm_driver` return `false` without loading anything.
#[cfg(not(any(target_os = "linux", windows)))]
const ROCM_DRIVER_LIB: &str = "";
/// Raw status code returned by the CUDA and HIP entry points that the driver
/// probes call; the probes treat `0` as success and anything else as failure.
type GpuResult = i32;
/// Loads a GPU driver library, initialises it, and asks how many devices it sees.
///
/// * `lib_name` – library file name handed to `libloading::Library::new`.
/// * `init_symbol` – NUL-terminated name of the driver's `init(flags) -> status`
///   entry point; it is called with flags `0`.
/// * `count_symbol` – NUL-terminated name of the driver's
///   `get_device_count(*mut i32) -> status` entry point.
///
/// Returns `None` when `lib_name` is empty, when the library or either symbol
/// cannot be loaded, or when either call reports a non-zero status. Otherwise
/// returns the device count written by the driver. Nothing is reported to the
/// caller about *why* a probe failed; this is a best-effort capability check.
///
/// The library handle is local to this function and is released when it returns.
fn probe_driver_device_count(
    lib_name: &str,
    init_symbol: &[u8],
    count_symbol: &[u8],
) -> Option<i32> {
    type InitFn = unsafe extern "C" fn(u32) -> GpuResult;
    type CountFn = unsafe extern "C" fn(*mut i32) -> GpuResult;

    // An empty name marks a platform with no known driver library.
    if lib_name.is_empty() {
        return None;
    }

    // SAFETY: loading a library runs its load-time initialisers. The only
    // names passed here are the vendor driver libraries named by this file's
    // `*_DRIVER_LIB` constants, which are trusted to be safe to load.
    let lib = unsafe { libloading::Library::new(lib_name) }.ok()?;

    // SAFETY: `InitFn` matches the C signature of both `cuInit` and `hipInit`
    // (one `unsigned int` flags argument, integer status result).
    let init: libloading::Symbol<InitFn> = unsafe { lib.get(init_symbol) }.ok()?;
    // SAFETY: the symbol was resolved with the matching signature above and
    // `lib` is still loaded.
    let init_status = unsafe { init(0) };
    if init_status != 0 {
        return None;
    }

    // The count entry point is only looked up once initialisation succeeded.
    // SAFETY: `CountFn` matches the C signature of both `cuDeviceGetCount` and
    // `hipGetDeviceCount` (one `int*` out-parameter, integer status result).
    let get_count: libloading::Symbol<CountFn> = unsafe { lib.get(count_symbol) }.ok()?;
    let mut device_count = 0i32;
    // SAFETY: `device_count` is a live, writable `i32` for the whole call.
    let count_status = unsafe { get_count(&mut device_count) };

    (count_status == 0).then_some(device_count)
}

/// Returns `true` when the CUDA driver library loads, initialises, and reports
/// at least one device. Any failure along the way yields `false`.
fn probe_cuda_driver() -> bool {
    probe_driver_device_count(CUDA_DRIVER_LIB, b"cuInit\0", b"cuDeviceGetCount\0")
        .is_some_and(|count| count > 0)
}

/// Returns `true` when the HIP runtime library loads, initialises, and reports
/// at least one device. Any failure along the way yields `false`.
fn probe_rocm_driver() -> bool {
    probe_driver_device_count(ROCM_DRIVER_LIB, b"hipInit\0", b"hipGetDeviceCount\0")
        .is_some_and(|count| count > 0)
}
/// Reports whether the running CPU supports AVX2.
///
/// Uses runtime feature detection on x86_64; on every other architecture the
/// answer is always `false`.
fn detect_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    fn runtime_check() -> bool {
        std::is_x86_feature_detected!("avx2")
    }

    #[cfg(not(target_arch = "x86_64"))]
    fn runtime_check() -> bool {
        false
    }

    runtime_check()
}
/// Reports whether the running CPU supports AVX-512 Foundation (`avx512f`).
///
/// Uses runtime feature detection on x86_64; on every other architecture the
/// answer is always `false`.
fn detect_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    fn runtime_check() -> bool {
        std::is_x86_feature_detected!("avx512f")
    }

    #[cfg(not(target_arch = "x86_64"))]
    fn runtime_check() -> bool {
        false
    }

    runtime_check()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a profile whose GPU flags mirror `backend` and whose two AVX
    /// flags are both set to `wide_simd`.
    fn profile(
        backend: HardwareBackend,
        cpu_logical_cores: usize,
        wide_simd: bool,
    ) -> HardwareProfile {
        HardwareProfile {
            backend,
            has_cuda: backend == HardwareBackend::NvidiaCuda,
            has_rocm: backend == HardwareBackend::AmdRocm,
            cpu_logical_cores,
            has_avx2: wide_simd,
            has_avx512: wide_simd,
        }
    }

    /// Detection must complete and report at least one core.
    #[test]
    fn detect_runs_without_panic() {
        let detected = HardwareProfile::detect();
        assert!(detected.cpu_logical_cores >= 1);
    }

    /// Below the vector threshold even the strongest hardware is rejected.
    #[test]
    fn small_dataset_always_hnsw() {
        let strong = profile(HardwareBackend::NvidiaCuda, 64, true);
        let recommended = strong.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ - 1);
        assert!(!recommended);
    }

    /// A CUDA GPU is sufficient on its own once the dataset is large enough.
    #[test]
    fn large_dataset_cuda_picks_ivf_pq() {
        let cuda_only = profile(HardwareBackend::NvidiaCuda, 1, false);
        assert!(cuda_only.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// A ROCm GPU is sufficient on its own once the dataset is large enough.
    #[test]
    fn large_dataset_rocm_picks_ivf_pq() {
        let rocm_only = profile(HardwareBackend::AmdRocm, 1, false);
        assert!(rocm_only.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// CPU-only hardware qualifies when it exceeds the core threshold.
    #[test]
    fn large_dataset_many_cores_picks_ivf_pq() {
        let many_cores = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ + 1, false);
        assert!(many_cores.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// The core comparison is strict: exactly the threshold does not qualify.
    #[test]
    fn large_dataset_exactly_threshold_picks_hnsw() {
        let at_threshold = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ, false);
        let recommended = at_threshold.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ);
        assert!(!recommended);
    }

    /// CPU-only hardware below the core threshold does not qualify.
    #[test]
    fn large_dataset_weak_hardware_picks_hnsw() {
        let weak = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ - 1, false);
        let recommended = weak.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ);
        assert!(!recommended);
    }

    /// A hand-built CUDA profile exposes a consistent backend and flag pair.
    #[test]
    fn backend_consistency_cuda() {
        let cuda = HardwareProfile {
            backend: HardwareBackend::NvidiaCuda,
            has_cuda: true,
            has_rocm: false,
            cpu_logical_cores: 4,
            has_avx2: false,
            has_avx512: false,
        };
        assert_eq!(
            (cuda.backend, cuda.has_cuda, cuda.has_rocm),
            (HardwareBackend::NvidiaCuda, true, false)
        );
    }

    /// A hand-built ROCm profile exposes a consistent backend and flag pair.
    #[test]
    fn backend_consistency_rocm() {
        let rocm = HardwareProfile {
            backend: HardwareBackend::AmdRocm,
            has_cuda: false,
            has_rocm: true,
            cpu_logical_cores: 4,
            has_avx2: false,
            has_avx512: false,
        };
        assert_eq!(
            (rocm.backend, rocm.has_cuda, rocm.has_rocm),
            (HardwareBackend::AmdRocm, false, true)
        );
    }
}