Skip to main content

ailake_index/
hardware.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Hardware capability detection for adaptive index selection.
//!
//! Detection priority: AMD ROCm → NVIDIA CUDA → CPU SIMD.
//! AMD is checked first because ROCm installations often provide a CUDA compatibility
//! layer (`libcuda.so.1`), which would otherwise be misreported as an NVIDIA GPU.
7
8use tracing::{debug, info, warn};
9
/// Logical-core count that a GPU-less machine must strictly exceed before IVF-PQ is
/// recommended: 9 or more cores qualify, exactly 8 does not.
const MIN_CORES_FOR_IVF_PQ: usize = 8;

/// Smallest dataset size (number of vectors, inclusive) for which IVF-PQ training is worth
/// its k-means + PQ-codebook overhead; below it `recommend_ivf_pq` always returns false.
const MIN_VECTORS_FOR_IVF_PQ: usize = 5_000;
16
/// Compute backend selected for this process by `detect_backend`.
///
/// Only one backend is active at a time. A GPU variant is chosen only after the matching
/// driver library loaded, initialised successfully, and reported at least one device.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HardwareBackend {
    /// Fallback when no usable GPU driver was found: CPU kernels (rayon + SIMD).
    CpuSimd,
    /// NVIDIA GPU reachable through the CUDA driver library.
    NvidiaCuda,
    /// AMD GPU reachable through the ROCm/HIP runtime library.
    AmdRocm,
}
27
28/// Detected hardware capabilities.
29pub struct HardwareProfile {
30    /// Active GPU backend (highest priority GPU, or CpuSimd when none found).
31    pub backend: HardwareBackend,
32    /// CUDA-capable GPU found at runtime (NVIDIA only; requires `gpu` feature at compile time).
33    pub has_cuda: bool,
34    /// AMD ROCm/HIP GPU found at runtime.
35    pub has_rocm: bool,
36    /// Logical CPU cores available to rayon's global thread pool.
37    pub cpu_logical_cores: usize,
38    /// x86_64: AVX2 support detected via CPUID.
39    pub has_avx2: bool,
40    /// x86_64: AVX-512F support detected via CPUID.
41    pub has_avx512: bool,
42}
43
44impl HardwareProfile {
45    /// Probe the current machine's capabilities.
46    pub fn detect() -> Self {
47        let backend = detect_backend();
48        Self {
49            backend,
50            has_cuda: backend == HardwareBackend::NvidiaCuda,
51            has_rocm: backend == HardwareBackend::AmdRocm,
52            cpu_logical_cores: rayon::current_num_threads(),
53            has_avx2: detect_avx2(),
54            has_avx512: detect_avx512(),
55        }
56    }
57
58    /// True when IVF-PQ training is justified for a dataset of `n_vectors` vectors.
59    ///
60    /// Returns false when:
61    /// - `n_vectors < MIN_VECTORS_FOR_IVF_PQ` (k-means clusters would be meaningless)
62    /// - Neither GPU (CUDA or ROCm) nor a sufficiently parallel CPU is available
63    pub fn recommend_ivf_pq(&self, n_vectors: usize) -> bool {
64        if n_vectors < MIN_VECTORS_FOR_IVF_PQ {
65            return false;
66        }
67        self.has_cuda || self.has_rocm || self.cpu_logical_cores > MIN_CORES_FOR_IVF_PQ
68    }
69}
70
71// ── Backend detection ─────────────────────────────────────────────────────────
72
73use std::sync::OnceLock;
74
75static BACKEND: OnceLock<HardwareBackend> = OnceLock::new();
76
77/// Returns the best available compute backend, probing once per process.
78///
79/// Priority: AMD ROCm > NVIDIA CUDA > CPU SIMD.
80/// AMD is checked first to correctly identify ROCm machines that also expose
81/// a CUDA compatibility layer.
82pub fn detect_backend() -> HardwareBackend {
83    *BACKEND.get_or_init(|| {
84        let backend = if probe_rocm_driver() {
85            HardwareBackend::AmdRocm
86        } else if probe_cuda_driver() {
87            HardwareBackend::NvidiaCuda
88        } else {
89            HardwareBackend::CpuSimd
90        };
91        match backend {
92            HardwareBackend::AmdRocm => {
93                info!("ailake: GPU backend selected — AMD ROCm (hipBLAS SGEMM via libloading)");
94            }
95            HardwareBackend::NvidiaCuda => {
96                info!("ailake: GPU backend selected — NVIDIA CUDA (cuBLAS SGEMM via libloading)");
97            }
98            HardwareBackend::CpuSimd => {
99                info!(
100                    "ailake: no GPU detected — using CPU SIMD backend (rayon + AVX2/NEON); \
101                     to enable GPU acceleration install the NVIDIA CUDA runtime \
102                     (libcudart + libcublas) or AMD ROCm (libamdhip64 + libhipblas)"
103                );
104            }
105        }
106        backend
107    })
108}
109
110/// True only when an NVIDIA CUDA GPU is the active backend.
111/// Returns false on AMD ROCm machines (even those with a CUDA compat layer).
112pub fn detect_cuda() -> bool {
113    detect_backend() == HardwareBackend::NvidiaCuda
114}
115
116/// True only when an AMD ROCm/HIP GPU is the active backend.
117pub fn detect_rocm() -> bool {
118    detect_backend() == HardwareBackend::AmdRocm
119}
120
// ── Library names ─────────────────────────────────────────────────────────────

/// Status code returned by CUDA driver / HIP API calls; `0` means success.
type GpuResult = i32;

// Linux: shared-object names handed to the dynamic loader.
#[cfg(target_os = "linux")]
const CUDA_DRIVER_LIB: &str = "libcuda.so.1";
#[cfg(target_os = "linux")]
const ROCM_DRIVER_LIB: &str = "libamdhip64.so";

// Windows: DLL names handed to the dynamic loader.
#[cfg(windows)]
const CUDA_DRIVER_LIB: &str = "nvcuda.dll";
#[cfg(windows)]
const ROCM_DRIVER_LIB: &str = "amdhip64.dll";

// Any other OS: an empty name makes both driver probes return false immediately.
#[cfg(not(any(target_os = "linux", windows)))]
const CUDA_DRIVER_LIB: &str = "";
#[cfg(not(any(target_os = "linux", windows)))]
const ROCM_DRIVER_LIB: &str = "";
139
140// ── CUDA probe ────────────────────────────────────────────────────────────────
141
142fn probe_cuda_driver() -> bool {
143    if CUDA_DRIVER_LIB.is_empty() {
144        return false;
145    }
146    let lib = match unsafe { libloading::Library::new(CUDA_DRIVER_LIB) } {
147        Ok(l) => l,
148        Err(e) => {
149            debug!(
150                "ailake: CUDA driver library `{}` not found ({}); \
151                 GPU acceleration unavailable — install the NVIDIA CUDA driver to enable it",
152                CUDA_DRIVER_LIB, e
153            );
154            return false;
155        }
156    };
157
158    let cu_init: libloading::Symbol<unsafe extern "C" fn(u32) -> GpuResult> =
159        match unsafe { lib.get(b"cuInit\0") } {
160            Ok(f) => f,
161            Err(e) => {
162                warn!(
163                    "ailake: `{}` loaded but `cuInit` symbol missing ({}); \
164                     CUDA installation may be incomplete — falling back to CPU",
165                    CUDA_DRIVER_LIB, e
166                );
167                return false;
168            }
169        };
170    let rc = unsafe { cu_init(0) };
171    if rc != 0 {
172        warn!(
173            "ailake: cuInit(0) returned error code {} — CUDA driver present but no usable GPU \
174             or driver not initialised; falling back to CPU SIMD",
175            rc
176        );
177        return false;
178    }
179
180    let cu_count: libloading::Symbol<unsafe extern "C" fn(*mut i32) -> GpuResult> =
181        match unsafe { lib.get(b"cuDeviceGetCount\0") } {
182            Ok(f) => f,
183            Err(e) => {
184                warn!(
185                    "ailake: `cuDeviceGetCount` symbol missing in `{}` ({}); \
186                     falling back to CPU",
187                    CUDA_DRIVER_LIB, e
188                );
189                return false;
190            }
191        };
192    let mut count = 0i32;
193    let rc = unsafe { cu_count(&mut count) };
194    if rc == 0 && count == 0 {
195        warn!(
196            "ailake: CUDA driver initialised but no CUDA-capable devices found (count=0); \
197             falling back to CPU SIMD"
198        );
199        return false;
200    }
201    rc == 0 && count > 0
202}
203
204// ── ROCm probe ────────────────────────────────────────────────────────────────
205
206/// Dynamically probe the AMD HIP runtime library.
207///
208/// Uses `hipInit(0)` + `hipGetDeviceCount` via libloading — no ROCm toolkit
209/// required at compile time. Returns false when:
210/// - `libamdhip64.so` is not installed
211/// - `hipInit` fails (no GPU / driver not loaded)
212/// - no ROCm-capable devices found
213fn probe_rocm_driver() -> bool {
214    if ROCM_DRIVER_LIB.is_empty() {
215        return false;
216    }
217    let lib = match unsafe { libloading::Library::new(ROCM_DRIVER_LIB) } {
218        Ok(l) => l,
219        Err(e) => {
220            debug!(
221                "ailake: ROCm library `{}` not found ({}); \
222                 AMD GPU acceleration unavailable — install the ROCm runtime to enable it",
223                ROCM_DRIVER_LIB, e
224            );
225            return false;
226        }
227    };
228
229    // hipInit(0) must succeed before any other HIP driver call.
230    let hip_init: libloading::Symbol<unsafe extern "C" fn(u32) -> GpuResult> =
231        match unsafe { lib.get(b"hipInit\0") } {
232            Ok(f) => f,
233            Err(e) => {
234                warn!(
235                    "ailake: `{}` loaded but `hipInit` symbol missing ({}); \
236                     ROCm installation may be incomplete — falling back to CPU",
237                    ROCM_DRIVER_LIB, e
238                );
239                return false;
240            }
241        };
242    let rc = unsafe { hip_init(0) };
243    if rc != 0 {
244        warn!(
245            "ailake: hipInit(0) returned error code {} — ROCm driver present but no usable GPU \
246             or driver not initialised; falling back to CPU SIMD",
247            rc
248        );
249        return false;
250    }
251
252    let hip_count: libloading::Symbol<unsafe extern "C" fn(*mut i32) -> GpuResult> =
253        match unsafe { lib.get(b"hipGetDeviceCount\0") } {
254            Ok(f) => f,
255            Err(e) => {
256                warn!(
257                    "ailake: `hipGetDeviceCount` symbol missing in `{}` ({}); \
258                     falling back to CPU",
259                    ROCM_DRIVER_LIB, e
260                );
261                return false;
262            }
263        };
264    let mut count = 0i32;
265    let rc = unsafe { hip_count(&mut count) };
266    if rc == 0 && count == 0 {
267        warn!(
268            "ailake: ROCm driver initialised but no ROCm-capable devices found (count=0); \
269             falling back to CPU SIMD"
270        );
271        return false;
272    }
273    rc == 0 && count > 0
274}
275
// ── SIMD detection ────────────────────────────────────────────────────────────

/// Runtime AVX2 check; on non-x86_64 targets the probe is compiled out and this is false.
fn detect_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx2") {
            return true;
        }
    }
    false
}
286
/// Runtime AVX-512 Foundation (`avx512f`) check; false on non-x86_64 targets.
fn detect_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        if std::is_x86_feature_detected!("avx512f") {
            return true;
        }
    }
    false
}
295
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    const ALL_BACKENDS: [HardwareBackend; 3] = [
        HardwareBackend::CpuSimd,
        HardwareBackend::NvidiaCuda,
        HardwareBackend::AmdRocm,
    ];

    /// Builds a profile whose `has_cuda`/`has_rocm` flags are derived from `backend` the
    /// same way `HardwareProfile::detect` derives them, so tests never construct a state
    /// `detect` could not produce. SIMD flags do not influence `recommend_ivf_pq`.
    fn profile(backend: HardwareBackend, cpu_logical_cores: usize) -> HardwareProfile {
        HardwareProfile {
            backend,
            has_cuda: backend == HardwareBackend::NvidiaCuda,
            has_rocm: backend == HardwareBackend::AmdRocm,
            cpu_logical_cores,
            has_avx2: false,
            has_avx512: false,
        }
    }

    #[test]
    fn detect_runs_without_panic() {
        let p = HardwareProfile::detect();
        assert!(p.cpu_logical_cores >= 1);
    }

    #[test]
    fn detected_flags_agree_with_backend() {
        let p = HardwareProfile::detect();
        assert_eq!(p.has_cuda, p.backend == HardwareBackend::NvidiaCuda);
        assert_eq!(p.has_rocm, p.backend == HardwareBackend::AmdRocm);
        assert!(!(p.has_cuda && p.has_rocm));
        // The backend is cached per process, so every accessor must agree with `detect`.
        assert_eq!(detect_backend(), p.backend);
        assert_eq!(detect_cuda(), p.has_cuda);
        assert_eq!(detect_rocm(), p.has_rocm);
    }

    #[test]
    fn small_dataset_always_hnsw() {
        for backend in ALL_BACKENDS {
            let p = profile(backend, 64);
            assert!(!p.recommend_ivf_pq(0));
            assert!(!p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ - 1));
        }
    }

    #[test]
    fn vector_threshold_is_inclusive() {
        for backend in ALL_BACKENDS {
            let p = profile(backend, MIN_CORES_FOR_IVF_PQ + 1);
            assert!(p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
            assert!(p.recommend_ivf_pq(usize::MAX));
        }
    }

    #[test]
    fn large_dataset_cuda_picks_ivf_pq() {
        let p = profile(HardwareBackend::NvidiaCuda, 1);
        assert!(p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    #[test]
    fn large_dataset_rocm_picks_ivf_pq() {
        let p = profile(HardwareBackend::AmdRocm, 1);
        assert!(p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    #[test]
    fn large_dataset_many_cores_picks_ivf_pq() {
        let p = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ + 1);
        assert!(p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    #[test]
    fn large_dataset_exactly_threshold_picks_hnsw() {
        // The core threshold is strict (> 8), so exactly 8 cores without a GPU → HNSW.
        let p = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ);
        assert!(!p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    #[test]
    fn large_dataset_weak_hardware_picks_hnsw() {
        let p = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ - 1);
        assert!(!p.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }
}