1use tracing::{debug, info, warn};
9
/// Dataset-size threshold for IVF-PQ: `HardwareProfile::recommend_ivf_pq`
/// returns `false` for any dataset with fewer vectors than this.
const MIN_VECTORS_FOR_IVF_PQ: usize = 5_000;

/// CPU core threshold for IVF-PQ on hosts without a GPU backend.
///
/// The comparison in `HardwareProfile::recommend_ivf_pq` is strict (`>`), so a
/// host needs *more* than this many cores to qualify; exactly this many does not.
const MIN_CORES_FOR_IVF_PQ: usize = 8;
16
/// Compute backend selected for this process by `detect_backend`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HardwareBackend {
    /// CPU-only execution; chosen when neither GPU probe succeeds.
    CpuSimd,
    /// NVIDIA GPU; chosen when the CUDA driver probe reports at least one device.
    NvidiaCuda,
    /// AMD GPU; chosen when the ROCm/HIP probe reports at least one device.
    AmdRocm,
}
27
28pub struct HardwareProfile {
30 pub backend: HardwareBackend,
32 pub has_cuda: bool,
34 pub has_rocm: bool,
36 pub cpu_logical_cores: usize,
38 pub has_avx2: bool,
40 pub has_avx512: bool,
42}
43
44impl HardwareProfile {
45 pub fn detect() -> Self {
47 let backend = detect_backend();
48 Self {
49 backend,
50 has_cuda: backend == HardwareBackend::NvidiaCuda,
51 has_rocm: backend == HardwareBackend::AmdRocm,
52 cpu_logical_cores: rayon::current_num_threads(),
53 has_avx2: detect_avx2(),
54 has_avx512: detect_avx512(),
55 }
56 }
57
58 pub fn recommend_ivf_pq(&self, n_vectors: usize) -> bool {
64 if n_vectors < MIN_VECTORS_FOR_IVF_PQ {
65 return false;
66 }
67 self.has_cuda || self.has_rocm || self.cpu_logical_cores > MIN_CORES_FOR_IVF_PQ
68 }
69}
70
71use std::sync::OnceLock;
74
// Process-wide cache of the backend chosen by `detect_backend`; the
// `OnceLock` ensures the value is initialised at most once.
static BACKEND: OnceLock<HardwareBackend> = OnceLock::new();
76
77pub fn detect_backend() -> HardwareBackend {
83 *BACKEND.get_or_init(|| {
84 let backend = if probe_rocm_driver() {
85 HardwareBackend::AmdRocm
86 } else if probe_cuda_driver() {
87 HardwareBackend::NvidiaCuda
88 } else {
89 HardwareBackend::CpuSimd
90 };
91 match backend {
92 HardwareBackend::AmdRocm => {
93 info!("ailake: GPU backend selected — AMD ROCm (hipBLAS SGEMM via libloading)");
94 }
95 HardwareBackend::NvidiaCuda => {
96 info!("ailake: GPU backend selected — NVIDIA CUDA (cuBLAS SGEMM via libloading)");
97 }
98 HardwareBackend::CpuSimd => {
99 info!(
100 "ailake: no GPU detected — using CPU SIMD backend (rayon + AVX2/NEON); \
101 to enable GPU acceleration install the NVIDIA CUDA runtime \
102 (libcudart + libcublas) or AMD ROCm (libamdhip64 + libhipblas)"
103 );
104 }
105 }
106 backend
107 })
108}
109
110pub fn detect_cuda() -> bool {
113 detect_backend() == HardwareBackend::NvidiaCuda
114}
115
116pub fn detect_rocm() -> bool {
118 detect_backend() == HardwareBackend::AmdRocm
119}
120
// File name of the NVIDIA CUDA driver library that `probe_cuda_driver` loads
// at runtime. On platforms other than Linux and Windows the name is empty,
// which makes the probe return `false` without attempting to load anything.
#[cfg(target_os = "linux")]
const CUDA_DRIVER_LIB: &str = "libcuda.so.1";
#[cfg(windows)]
const CUDA_DRIVER_LIB: &str = "nvcuda.dll";
#[cfg(not(any(target_os = "linux", windows)))]
const CUDA_DRIVER_LIB: &str = "";
129
// File name of the AMD HIP runtime library that `probe_rocm_driver` loads at
// runtime. On platforms other than Linux and Windows the name is empty, which
// makes the probe return `false` without attempting to load anything.
#[cfg(target_os = "linux")]
const ROCM_DRIVER_LIB: &str = "libamdhip64.so";
#[cfg(windows)]
const ROCM_DRIVER_LIB: &str = "amdhip64.dll";
#[cfg(not(any(target_os = "linux", windows)))]
const ROCM_DRIVER_LIB: &str = "";
136
// Status code returned by the dynamically loaded GPU entry points used by the
// probes below; the probes treat `0` as success and anything else as failure.
type GpuResult = i32;
139
140fn probe_cuda_driver() -> bool {
143 if CUDA_DRIVER_LIB.is_empty() {
144 return false;
145 }
146 let lib = match unsafe { libloading::Library::new(CUDA_DRIVER_LIB) } {
147 Ok(l) => l,
148 Err(e) => {
149 debug!(
150 "ailake: CUDA driver library `{}` not found ({}); \
151 GPU acceleration unavailable — install the NVIDIA CUDA driver to enable it",
152 CUDA_DRIVER_LIB, e
153 );
154 return false;
155 }
156 };
157
158 let cu_init: libloading::Symbol<unsafe extern "C" fn(u32) -> GpuResult> =
159 match unsafe { lib.get(b"cuInit\0") } {
160 Ok(f) => f,
161 Err(e) => {
162 warn!(
163 "ailake: `{}` loaded but `cuInit` symbol missing ({}); \
164 CUDA installation may be incomplete — falling back to CPU",
165 CUDA_DRIVER_LIB, e
166 );
167 return false;
168 }
169 };
170 let rc = unsafe { cu_init(0) };
171 if rc != 0 {
172 warn!(
173 "ailake: cuInit(0) returned error code {} — CUDA driver present but no usable GPU \
174 or driver not initialised; falling back to CPU SIMD",
175 rc
176 );
177 return false;
178 }
179
180 let cu_count: libloading::Symbol<unsafe extern "C" fn(*mut i32) -> GpuResult> =
181 match unsafe { lib.get(b"cuDeviceGetCount\0") } {
182 Ok(f) => f,
183 Err(e) => {
184 warn!(
185 "ailake: `cuDeviceGetCount` symbol missing in `{}` ({}); \
186 falling back to CPU",
187 CUDA_DRIVER_LIB, e
188 );
189 return false;
190 }
191 };
192 let mut count = 0i32;
193 let rc = unsafe { cu_count(&mut count) };
194 if rc == 0 && count == 0 {
195 warn!(
196 "ailake: CUDA driver initialised but no CUDA-capable devices found (count=0); \
197 falling back to CPU SIMD"
198 );
199 return false;
200 }
201 rc == 0 && count > 0
202}
203
204fn probe_rocm_driver() -> bool {
214 if ROCM_DRIVER_LIB.is_empty() {
215 return false;
216 }
217 let lib = match unsafe { libloading::Library::new(ROCM_DRIVER_LIB) } {
218 Ok(l) => l,
219 Err(e) => {
220 debug!(
221 "ailake: ROCm library `{}` not found ({}); \
222 AMD GPU acceleration unavailable — install the ROCm runtime to enable it",
223 ROCM_DRIVER_LIB, e
224 );
225 return false;
226 }
227 };
228
229 let hip_init: libloading::Symbol<unsafe extern "C" fn(u32) -> GpuResult> =
231 match unsafe { lib.get(b"hipInit\0") } {
232 Ok(f) => f,
233 Err(e) => {
234 warn!(
235 "ailake: `{}` loaded but `hipInit` symbol missing ({}); \
236 ROCm installation may be incomplete — falling back to CPU",
237 ROCM_DRIVER_LIB, e
238 );
239 return false;
240 }
241 };
242 let rc = unsafe { hip_init(0) };
243 if rc != 0 {
244 warn!(
245 "ailake: hipInit(0) returned error code {} — ROCm driver present but no usable GPU \
246 or driver not initialised; falling back to CPU SIMD",
247 rc
248 );
249 return false;
250 }
251
252 let hip_count: libloading::Symbol<unsafe extern "C" fn(*mut i32) -> GpuResult> =
253 match unsafe { lib.get(b"hipGetDeviceCount\0") } {
254 Ok(f) => f,
255 Err(e) => {
256 warn!(
257 "ailake: `hipGetDeviceCount` symbol missing in `{}` ({}); \
258 falling back to CPU",
259 ROCM_DRIVER_LIB, e
260 );
261 return false;
262 }
263 };
264 let mut count = 0i32;
265 let rc = unsafe { hip_count(&mut count) };
266 if rc == 0 && count == 0 {
267 warn!(
268 "ailake: ROCm driver initialised but no ROCm-capable devices found (count=0); \
269 falling back to CPU SIMD"
270 );
271 return false;
272 }
273 rc == 0 && count > 0
274}
275
/// Reports whether the running CPU supports AVX2.
///
/// Uses runtime feature detection on x86_64; always `false` on other
/// architectures.
fn detect_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    let supported = std::is_x86_feature_detected!("avx2");
    #[cfg(not(target_arch = "x86_64"))]
    let supported = false;

    supported
}
286
/// Reports whether the running CPU supports AVX-512 Foundation (`avx512f`).
///
/// Uses runtime feature detection on x86_64; always `false` on other
/// architectures.
fn detect_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    let supported = std::is_x86_feature_detected!("avx512f");
    #[cfg(not(target_arch = "x86_64"))]
    let supported = false;

    supported
}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a profile whose GPU flags agree with `backend`.
    ///
    /// `simd` sets both AVX flags; they do not influence the recommendation
    /// and are only filled in to complete the struct.
    fn profile(backend: HardwareBackend, cpu_logical_cores: usize, simd: bool) -> HardwareProfile {
        HardwareProfile {
            backend,
            has_cuda: backend == HardwareBackend::NvidiaCuda,
            has_rocm: backend == HardwareBackend::AmdRocm,
            cpu_logical_cores,
            has_avx2: simd,
            has_avx512: simd,
        }
    }

    /// Detection must complete on any host and report at least one core.
    #[test]
    fn detect_runs_without_panic() {
        let detected = HardwareProfile::detect();
        assert!(detected.cpu_logical_cores >= 1);
    }

    /// Below the vector threshold, even the strongest hardware gets HNSW.
    #[test]
    fn small_dataset_always_hnsw() {
        let strong = profile(HardwareBackend::NvidiaCuda, 64, true);
        assert!(!strong.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ - 1));
    }

    /// A CUDA GPU alone is enough once the dataset is large.
    #[test]
    fn large_dataset_cuda_picks_ivf_pq() {
        let cuda_only = profile(HardwareBackend::NvidiaCuda, 1, false);
        assert!(cuda_only.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// A ROCm GPU alone is enough once the dataset is large.
    #[test]
    fn large_dataset_rocm_picks_ivf_pq() {
        let rocm_only = profile(HardwareBackend::AmdRocm, 1, false);
        assert!(rocm_only.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// Without a GPU, strictly more cores than the threshold qualifies.
    #[test]
    fn large_dataset_many_cores_picks_ivf_pq() {
        let many_cores = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ + 1, false);
        assert!(many_cores.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// The core comparison is strict: exactly the threshold does not qualify.
    #[test]
    fn large_dataset_exactly_threshold_picks_hnsw() {
        let at_threshold = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ, false);
        assert!(!at_threshold.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// Without a GPU and with too few cores, HNSW is kept.
    #[test]
    fn large_dataset_weak_hardware_picks_hnsw() {
        let weak = profile(HardwareBackend::CpuSimd, MIN_CORES_FOR_IVF_PQ - 1, false);
        assert!(!weak.recommend_ivf_pq(MIN_VECTORS_FOR_IVF_PQ));
    }

    /// A CUDA profile has only the CUDA flag set.
    #[test]
    fn backend_consistency_cuda() {
        let cuda = profile(HardwareBackend::NvidiaCuda, 4, false);
        assert!(cuda.has_cuda);
        assert!(!cuda.has_rocm);
        assert_eq!(cuda.backend, HardwareBackend::NvidiaCuda);
    }

    /// A ROCm profile has only the ROCm flag set.
    #[test]
    fn backend_consistency_rocm() {
        let rocm = profile(HardwareBackend::AmdRocm, 4, false);
        assert!(!rocm.has_cuda);
        assert!(rocm.has_rocm);
        assert_eq!(rocm.backend, HardwareBackend::AmdRocm);
    }
}