use tracing::{debug, info, warn};
/// Dataset-size threshold consulted by `HardwareProfile::recommend_ivf_pq`: with fewer
/// vectors than this, IVF-PQ is never recommended regardless of hardware.
const MIN_VECTORS_FOR_IVF_PQ: usize = 5_000;
/// CPU-core threshold consulted by `HardwareProfile::recommend_ivf_pq` when no GPU flag is
/// set. The comparison there is strict: the profile must report *more than* this many cores.
const MIN_CORES_FOR_IVF_PQ: usize = 8;
/// Compute backend chosen by `detect_backend`.
///
/// Exactly one variant is selected per process; the choice is cached after the first probe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HardwareBackend {
/// Fallback selected when neither the ROCm nor the CUDA probe succeeds.
CpuSimd,
/// Selected when the ROCm probe fails and the CUDA driver probe succeeds.
NvidiaCuda,
/// Selected when the ROCm probe succeeds (ROCm is probed before CUDA).
AmdRocm,
}
/// Snapshot of the hardware capabilities used to pick an index strategy.
///
/// Normally built with `HardwareProfile::detect`; all fields are public, so a profile can
/// also be constructed by hand (the unit tests do this).
pub struct HardwareProfile {
/// Backend selected for this process.
pub backend: HardwareBackend,
/// `detect` sets this to `backend == HardwareBackend::NvidiaCuda`.
pub has_cuda: bool,
/// `detect` sets this to `backend == HardwareBackend::AmdRocm`.
pub has_rocm: bool,
/// Parallelism available to the CPU path. `detect` fills this from
/// `rayon::current_num_threads()`, i.e. the size of the current rayon thread pool.
// NOTE(review): this may differ from the machine's logical core count if the rayon pool
// is configured with a different size — confirm which of the two is intended.
pub cpu_logical_cores: usize,
/// Result of the runtime AVX2 check (always `false` off x86_64).
pub has_avx2: bool,
/// Result of the runtime AVX-512F check (always `false` off x86_64).
pub has_avx512: bool,
}
impl HardwareProfile {
    /// Builds a profile for the current process.
    ///
    /// The backend comes from [`detect_backend`] (cached after the first call), the GPU
    /// flags are derived from that backend, the core count is taken from rayon's current
    /// thread pool, and the SIMD flags come from runtime CPU feature detection.
    pub fn detect() -> Self {
        let selected = detect_backend();
        let has_cuda = matches!(selected, HardwareBackend::NvidiaCuda);
        let has_rocm = matches!(selected, HardwareBackend::AmdRocm);
        let cpu_logical_cores = rayon::current_num_threads();
        let has_avx2 = detect_avx2();
        let has_avx512 = detect_avx512();
        Self {
            backend: selected,
            has_cuda,
            has_rocm,
            cpu_logical_cores,
            has_avx2,
            has_avx512,
        }
    }

    /// Returns `true` when IVF-PQ is recommended for a dataset of `n_vectors` vectors.
    ///
    /// IVF-PQ is recommended only when the dataset holds at least
    /// `MIN_VECTORS_FOR_IVF_PQ` vectors **and** either a GPU flag is set or the profile
    /// reports strictly more than `MIN_CORES_FOR_IVF_PQ` cores.
    pub fn recommend_ivf_pq(&self, n_vectors: usize) -> bool {
        let dataset_large_enough = n_vectors >= MIN_VECTORS_FOR_IVF_PQ;
        let gpu_available = self.has_cuda || self.has_rocm;
        // Strict comparison: exactly MIN_CORES_FOR_IVF_PQ cores does not qualify.
        let plenty_of_cores = self.cpu_logical_cores > MIN_CORES_FOR_IVF_PQ;
        dataset_large_enough && (gpu_available || plenty_of_cores)
    }
}
use std::sync::OnceLock;
/// Process-wide cache holding the backend chosen by the first `detect_backend` call.
static BACKEND: OnceLock<HardwareBackend> = OnceLock::new();

/// Runs the driver probes in priority order (ROCm first, then CUDA), logs the outcome
/// once, and returns the selected backend. The CUDA probe only runs if the ROCm probe fails.
fn probe_and_announce_backend() -> HardwareBackend {
    if probe_rocm_driver() {
        info!("ailake: GPU backend selected — AMD ROCm (hipBLAS SGEMM via libloading)");
        return HardwareBackend::AmdRocm;
    }
    if probe_cuda_driver() {
        info!("ailake: GPU backend selected — NVIDIA CUDA (cuBLAS SGEMM via libloading)");
        return HardwareBackend::NvidiaCuda;
    }
    info!(
        "ailake: no GPU detected — using CPU SIMD backend (rayon + AVX2/NEON); \
         to enable GPU acceleration install the NVIDIA CUDA runtime \
         (libcudart + libcublas) or AMD ROCm (libamdhip64 + libhipblas)"
    );
    HardwareBackend::CpuSimd
}

/// Returns the compute backend for this process.
///
/// The first call probes the GPU drivers and logs which backend was selected; the result
/// is stored in `BACKEND`, so later calls return the cached value without probing again.
pub fn detect_backend() -> HardwareBackend {
    *BACKEND.get_or_init(probe_and_announce_backend)
}
/// Returns `true` when the backend selected by [`detect_backend`] is NVIDIA CUDA.
pub fn detect_cuda() -> bool {
    matches!(detect_backend(), HardwareBackend::NvidiaCuda)
}
/// Returns `true` when the backend selected by [`detect_backend`] is AMD ROCm.
pub fn detect_rocm() -> bool {
    matches!(detect_backend(), HardwareBackend::AmdRocm)
}
// File name of the CUDA driver library that `probe_cuda_driver` loads at runtime.
// On platforms other than Linux and Windows the name is empty, which makes the probe
// return `false` without attempting to load anything.
#[cfg(target_os = "linux")]
const CUDA_DRIVER_LIB: &str = "libcuda.so.1";
#[cfg(windows)]
const CUDA_DRIVER_LIB: &str = "nvcuda.dll";
#[cfg(not(any(target_os = "linux", windows)))]
const CUDA_DRIVER_LIB: &str = "";
// File name of the HIP library that `probe_rocm_driver` loads at runtime.
// Empty on platforms other than Linux and Windows, which disables the ROCm probe there.
#[cfg(target_os = "linux")]
const ROCM_DRIVER_LIB: &str = "libamdhip64.so";
#[cfg(windows)]
const ROCM_DRIVER_LIB: &str = "amdhip64.dll";
#[cfg(not(any(target_os = "linux", windows)))]
const ROCM_DRIVER_LIB: &str = "";
/// Status code returned by the dynamically loaded driver functions; the probes treat
/// `0` as success and any other value as failure.
type GpuResult = i32;
/// Probes for a usable NVIDIA GPU by loading the CUDA driver library at runtime.
///
/// Steps, in order: load `CUDA_DRIVER_LIB`, resolve and call `cuInit(0)`, then resolve and
/// call `cuDeviceGetCount`. Returns `true` only if every step succeeds and at least one
/// device is reported. Every failure path logs a diagnostic and returns `false`, so the
/// caller can fall back to another backend. The library handle is local and is dropped
/// when the function returns.
fn probe_cuda_driver() -> bool {
    // No driver library name is defined for this platform: nothing to probe.
    if CUDA_DRIVER_LIB.is_empty() {
        return false;
    }

    // SAFETY: loading a shared library runs its initialisation routines; this assumes the
    // system-installed CUDA driver library is trustworthy.
    let lib = match unsafe { libloading::Library::new(CUDA_DRIVER_LIB) } {
        Ok(l) => l,
        Err(e) => {
            // A missing driver is the normal case on GPU-less machines, hence `debug!`.
            debug!(
                "ailake: CUDA driver library `{}` not found ({}); \
                 GPU acceleration unavailable — install the NVIDIA CUDA driver to enable it",
                CUDA_DRIVER_LIB, e
            );
            return false;
        }
    };

    // SAFETY: the declared signature is assumed to match the driver's
    // `cuInit(unsigned int)` entry point.
    let cu_init: libloading::Symbol<unsafe extern "C" fn(u32) -> GpuResult> =
        match unsafe { lib.get(b"cuInit\0") } {
            Ok(f) => f,
            Err(e) => {
                warn!(
                    "ailake: `{}` loaded but `cuInit` symbol missing ({}); \
                     CUDA installation may be incomplete — falling back to CPU",
                    CUDA_DRIVER_LIB, e
                );
                return false;
            }
        };

    // SAFETY: `cu_init` was resolved from the still-loaded `lib` and is called with the
    // signature declared above.
    let rc = unsafe { cu_init(0) };
    if rc != 0 {
        warn!(
            "ailake: cuInit(0) returned error code {} — CUDA driver present but no usable GPU \
             or driver not initialised; falling back to CPU SIMD",
            rc
        );
        return false;
    }

    // SAFETY: the declared signature is assumed to match the driver's
    // `cuDeviceGetCount(int *)` entry point.
    let cu_count: libloading::Symbol<unsafe extern "C" fn(*mut i32) -> GpuResult> =
        match unsafe { lib.get(b"cuDeviceGetCount\0") } {
            Ok(f) => f,
            Err(e) => {
                warn!(
                    "ailake: `cuDeviceGetCount` symbol missing in `{}` ({}); \
                     falling back to CPU",
                    CUDA_DRIVER_LIB, e
                );
                return false;
            }
        };

    let mut count = 0i32;
    // SAFETY: `count` is a live, writable `i32` for the duration of the call.
    let rc = unsafe { cu_count(&mut count) };
    if rc != 0 {
        // Previously this path returned `false` without any diagnostic.
        warn!(
            "ailake: cuDeviceGetCount returned error code {} — unable to enumerate CUDA \
             devices; falling back to CPU SIMD",
            rc
        );
        return false;
    }
    if count == 0 {
        warn!(
            "ailake: CUDA driver initialised but no CUDA-capable devices found (count=0); \
             falling back to CPU SIMD"
        );
        return false;
    }
    count > 0
}
/// Probes for a usable AMD GPU by loading the HIP runtime library at runtime.
///
/// Steps, in order: load `ROCM_DRIVER_LIB`, resolve and call `hipInit(0)`, then resolve and
/// call `hipGetDeviceCount`. Returns `true` only if every step succeeds and at least one
/// device is reported. Every failure path logs a diagnostic and returns `false`, so the
/// caller can fall back to another backend. The library handle is local and is dropped
/// when the function returns.
fn probe_rocm_driver() -> bool {
    // No HIP library name is defined for this platform: nothing to probe.
    if ROCM_DRIVER_LIB.is_empty() {
        return false;
    }

    // SAFETY: loading a shared library runs its initialisation routines; this assumes the
    // system-installed HIP library is trustworthy.
    let lib = match unsafe { libloading::Library::new(ROCM_DRIVER_LIB) } {
        Ok(l) => l,
        Err(e) => {
            // A missing runtime is the normal case on machines without ROCm, hence `debug!`.
            debug!(
                "ailake: ROCm library `{}` not found ({}); \
                 AMD GPU acceleration unavailable — install the ROCm runtime to enable it",
                ROCM_DRIVER_LIB, e
            );
            return false;
        }
    };

    // SAFETY: the declared signature is assumed to match the runtime's
    // `hipInit(unsigned int)` entry point.
    let hip_init: libloading::Symbol<unsafe extern "C" fn(u32) -> GpuResult> =
        match unsafe { lib.get(b"hipInit\0") } {
            Ok(f) => f,
            Err(e) => {
                warn!(
                    "ailake: `{}` loaded but `hipInit` symbol missing ({}); \
                     ROCm installation may be incomplete — falling back to CPU",
                    ROCM_DRIVER_LIB, e
                );
                return false;
            }
        };

    // SAFETY: `hip_init` was resolved from the still-loaded `lib` and is called with the
    // signature declared above.
    let rc = unsafe { hip_init(0) };
    if rc != 0 {
        warn!(
            "ailake: hipInit(0) returned error code {} — ROCm driver present but no usable GPU \
             or driver not initialised; falling back to CPU SIMD",
            rc
        );
        return false;
    }

    // SAFETY: the declared signature is assumed to match the runtime's
    // `hipGetDeviceCount(int *)` entry point.
    let hip_count: libloading::Symbol<unsafe extern "C" fn(*mut i32) -> GpuResult> =
        match unsafe { lib.get(b"hipGetDeviceCount\0") } {
            Ok(f) => f,
            Err(e) => {
                warn!(
                    "ailake: `hipGetDeviceCount` symbol missing in `{}` ({}); \
                     falling back to CPU",
                    ROCM_DRIVER_LIB, e
                );
                return false;
            }
        };

    let mut count = 0i32;
    // SAFETY: `count` is a live, writable `i32` for the duration of the call.
    let rc = unsafe { hip_count(&mut count) };
    if rc != 0 {
        // Previously this path returned `false` without any diagnostic.
        warn!(
            "ailake: hipGetDeviceCount returned error code {} — unable to enumerate ROCm \
             devices; falling back to CPU SIMD",
            rc
        );
        return false;
    }
    if count == 0 {
        warn!(
            "ailake: ROCm driver initialised but no ROCm-capable devices found (count=0); \
             falling back to CPU SIMD"
        );
        return false;
    }
    count > 0
}
/// Reports whether the running CPU supports AVX2.
///
/// Uses runtime feature detection on x86_64; on every other architecture the answer is
/// always `false`.
fn detect_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    let supported = std::is_x86_feature_detected!("avx2");
    #[cfg(not(target_arch = "x86_64"))]
    let supported = false;
    supported
}
/// Reports whether the running CPU supports the AVX-512 Foundation (`avx512f`) extension.
///
/// Uses runtime feature detection on x86_64; on every other architecture the answer is
/// always `false`.
fn detect_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    let supported = std::is_x86_feature_detected!("avx512f");
    #[cfg(not(target_arch = "x86_64"))]
    let supported = false;
    supported
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hand-made profile: the GPU flags mirror `backend`, and both SIMD flags are
    /// set to `simd`.
    fn profile(backend: HardwareBackend, cpu_logical_cores: usize, simd: bool) -> HardwareProfile {
        HardwareProfile {
            backend,
            has_cuda: backend == HardwareBackend::NvidiaCuda,
            has_rocm: backend == HardwareBackend::AmdRocm,
            cpu_logical_cores,
            has_avx2: simd,
            has_avx512: simd,
        }
    }

    /// Real detection must complete and report at least one core.
    #[test]
    fn detect_runs_without_panic() {
        let detected = HardwareProfile::detect();
        assert!(detected.cpu_logical_cores >= 1);
    }

    /// Below the vector threshold, even the strongest hardware does not get IVF-PQ.
    #[test]
    fn small_dataset_always_hnsw() {
        let strong = profile(HardwareBackend::NvidiaCuda, 64, true);
        assert!(!strong.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ - 1));
    }

    /// A CUDA GPU is enough on its own once the dataset reaches the threshold.
    #[test]
    fn large_dataset_cuda_picks_ivf_pq() {
        let cuda = profile(HardwareBackend::NvidiaCuda, 1, false);
        assert!(cuda.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// A ROCm GPU is enough on its own once the dataset reaches the threshold.
    #[test]
    fn large_dataset_rocm_picks_ivf_pq() {
        let rocm = profile(HardwareBackend::AmdRocm, 1, false);
        assert!(rocm.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// Without a GPU, strictly more cores than the threshold qualifies.
    #[test]
    fn large_dataset_many_cores_picks_ivf_pq() {
        let many_cores = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ + 1, false);
        assert!(many_cores.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// Exactly the core threshold does not qualify (the comparison is strict).
    #[test]
    fn large_dataset_exactly_threshold_picks_hnsw() {
        let at_threshold = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ, false);
        assert!(!at_threshold.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// Fewer cores than the threshold and no GPU does not qualify.
    #[test]
    fn large_dataset_weak_hardware_picks_hnsw() {
        let weak = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ - 1, false);
        assert!(!weak.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// A CUDA profile carries the CUDA flag only.
    #[test]
    fn backend_consistency_cuda() {
        let cuda = profile(HardwareBackend::NvidiaCuda, 4, false);
        assert!(cuda.has_cuda);
        assert!(!cuda.has_rocm);
        assert_eq!(cuda.backend, HardwareBackend::NvidiaCuda);
    }

    /// A ROCm profile carries the ROCm flag only.
    #[test]
    fn backend_consistency_rocm() {
        let rocm = profile(HardwareBackend::AmdRocm, 4, false);
        assert!(!rocm.has_cuda);
        assert!(rocm.has_rocm);
        assert_eq!(rocm.backend, HardwareBackend::AmdRocm);
    }
}