1#![allow(unsafe_code)]
12
13use crate::error::{InferenceError, MlError, Result};
14use tracing::{debug, info};
15
/// GPU compute backends this module knows how to probe and enumerate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuBackend {
    /// NVIDIA CUDA (driver API, loaded dynamically at runtime).
    Cuda,
    /// AMD ROCm / HIP.
    Rocm,
    /// Microsoft DirectML (Windows only).
    DirectMl,
    /// Cross-platform Vulkan.
    Vulkan,
    /// WebGPU (browser / wasm32 targets only).
    WebGpu,
    /// Apple Metal (macOS only).
    Metal,
    /// OpenCL.
    OpenCl,
}
34
impl GpuBackend {
    /// Human-readable display name of this backend.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::Cuda => "CUDA",
            Self::Rocm => "ROCm",
            Self::DirectMl => "DirectML",
            Self::Vulkan => "Vulkan",
            Self::WebGpu => "WebGPU",
            Self::Metal => "Metal",
            Self::OpenCl => "OpenCL",
        }
    }

    /// Whether this backend's runtime appears usable on the current machine.
    ///
    /// Each `check_*` helper is a cheap probe (library load / platform query);
    /// when the corresponding cargo feature is disabled the helper returns
    /// `false` without probing anything.
    #[must_use]
    pub fn is_available(&self) -> bool {
        match self {
            Self::Cuda => check_cuda_available(),
            Self::Rocm => check_rocm_available(),
            Self::DirectMl => check_directml_available(),
            Self::Vulkan => check_vulkan_available(),
            Self::WebGpu => check_webgpu_available(),
            Self::Metal => check_metal_available(),
            Self::OpenCl => check_opencl_available(),
        }
    }
}
64
65#[derive(Debug, Clone)]
67pub struct GpuDevice {
68 pub id: usize,
70 pub name: String,
72 pub total_memory: usize,
74 pub free_memory: usize,
76 pub compute_capability: String,
78 pub backend: GpuBackend,
80}
81
82impl GpuDevice {
83 #[must_use]
85 pub fn memory_utilization(&self) -> f32 {
86 if self.total_memory > 0 {
87 let used = self.total_memory.saturating_sub(self.free_memory);
88 (used as f32 / self.total_memory as f32) * 100.0
89 } else {
90 0.0
91 }
92 }
93
94 #[must_use]
96 pub fn has_sufficient_memory(&self, required_bytes: usize) -> bool {
97 self.free_memory >= required_bytes
98 }
99}
100
/// Configuration controlling GPU selection and execution behavior.
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Restrict selection to this backend; `None` accepts any backend.
    pub backend: Option<GpuBackend>,
    /// Select a specific device id; `None` picks the device with the most
    /// free memory.
    pub device_id: Option<usize>,
    /// Enable mixed-precision execution where supported.
    pub mixed_precision: bool,
    /// Allow use of tensor cores where supported.
    pub tensor_cores: bool,
    /// Grow GPU memory on demand instead of reserving up front.
    pub memory_growth: bool,
    /// Fraction of device memory the process may use (builder clamps to
    /// 0.1–1.0; direct construction is not clamped).
    pub memory_fraction: f32,
}

impl Default for GpuConfig {
    fn default() -> Self {
        Self {
            backend: None,
            device_id: None,
            mixed_precision: true,
            tensor_cores: true,
            memory_growth: true,
            memory_fraction: 0.9,
        }
    }
}

impl GpuConfig {
    /// Start building a `GpuConfig` fluently.
    #[must_use]
    pub fn builder() -> GpuConfigBuilder {
        GpuConfigBuilder::default()
    }
}
138
/// Fluent builder for [`GpuConfig`]; fields left unset fall back to the
/// defaults of `GpuConfig::default()` when `build()` is called.
#[derive(Debug, Default)]
pub struct GpuConfigBuilder {
    backend: Option<GpuBackend>,
    device_id: Option<usize>,
    mixed_precision: Option<bool>,
    tensor_cores: Option<bool>,
    memory_growth: Option<bool>,
    memory_fraction: Option<f32>,
}
149
150impl GpuConfigBuilder {
151 #[must_use]
153 pub fn backend(mut self, backend: GpuBackend) -> Self {
154 self.backend = Some(backend);
155 self
156 }
157
158 #[must_use]
160 pub fn device_id(mut self, id: usize) -> Self {
161 self.device_id = Some(id);
162 self
163 }
164
165 #[must_use]
167 pub fn mixed_precision(mut self, enable: bool) -> Self {
168 self.mixed_precision = Some(enable);
169 self
170 }
171
172 #[must_use]
174 pub fn tensor_cores(mut self, enable: bool) -> Self {
175 self.tensor_cores = Some(enable);
176 self
177 }
178
179 #[must_use]
181 pub fn memory_growth(mut self, enable: bool) -> Self {
182 self.memory_growth = Some(enable);
183 self
184 }
185
186 #[must_use]
188 pub fn memory_fraction(mut self, fraction: f32) -> Self {
189 self.memory_fraction = Some(fraction.clamp(0.1, 1.0));
190 self
191 }
192
193 #[must_use]
195 pub fn build(self) -> GpuConfig {
196 GpuConfig {
197 backend: self.backend,
198 device_id: self.device_id,
199 mixed_precision: self.mixed_precision.unwrap_or(true),
200 tensor_cores: self.tensor_cores.unwrap_or(true),
201 memory_growth: self.memory_growth.unwrap_or(true),
202 memory_fraction: self.memory_fraction.unwrap_or(0.9),
203 }
204 }
205}
206
/// Enumerate GPU devices across all probed backends.
///
/// Only backends whose runtime reports itself available are enumerated.
///
/// NOTE(review): `OpenCl` and `WebGpu` are absent from this probe list even
/// though enumerators for both exist in this module — possibly deliberate, to
/// avoid double-listing hardware that is also visible through CUDA/ROCm/
/// Vulkan (no deduplication is performed here). Confirm the omission is
/// intended.
///
/// # Errors
///
/// Propagates the first enumeration failure from any available backend.
pub fn list_devices() -> Result<Vec<GpuDevice>> {
    info!("Enumerating GPU devices");

    let mut devices = Vec::new();

    for backend in &[
        GpuBackend::Cuda,
        GpuBackend::Rocm,
        GpuBackend::Metal,
        GpuBackend::Vulkan,
        GpuBackend::DirectMl,
    ] {
        if backend.is_available() {
            devices.extend(list_devices_for_backend(*backend)?);
        }
    }

    info!("Found {} GPU device(s)", devices.len());
    Ok(devices)
}
232
/// Dispatch device enumeration to the implementation for `backend`.
///
/// Each `enumerate_*` function returns an empty list when its cargo feature
/// (or target platform) is not enabled.
fn list_devices_for_backend(backend: GpuBackend) -> Result<Vec<GpuDevice>> {
    debug!("Enumerating devices for backend: {}", backend.name());

    match backend {
        GpuBackend::Cuda => enumerate_cuda_devices(),
        GpuBackend::Metal => enumerate_metal_devices(),
        GpuBackend::Vulkan => enumerate_vulkan_devices(),
        GpuBackend::OpenCl => enumerate_opencl_devices(),
        GpuBackend::Rocm => enumerate_rocm_devices(),
        GpuBackend::DirectMl => enumerate_directml_devices(),
        GpuBackend::WebGpu => enumerate_webgpu_devices(),
    }
}
247
248pub fn select_device(config: &GpuConfig) -> Result<GpuDevice> {
253 let devices = list_devices()?;
254
255 if devices.is_empty() {
256 return Err(MlError::Inference(InferenceError::GpuNotAvailable {
257 message: "No GPU devices found".to_string(),
258 }));
259 }
260
261 let filtered: Vec<_> = if let Some(backend) = config.backend {
263 devices
264 .into_iter()
265 .filter(|d| d.backend == backend)
266 .collect()
267 } else {
268 devices
269 };
270
271 if filtered.is_empty() {
272 return Err(MlError::Inference(InferenceError::GpuNotAvailable {
273 message: "No GPU devices match the specified backend".to_string(),
274 }));
275 }
276
277 let device = if let Some(id) = config.device_id {
279 filtered.into_iter().find(|d| d.id == id).ok_or_else(|| {
280 MlError::Inference(InferenceError::GpuNotAvailable {
281 message: format!("Device ID {} not found", id),
282 })
283 })?
284 } else {
285 filtered
286 .into_iter()
287 .max_by_key(|d| d.free_memory)
288 .ok_or_else(|| {
289 MlError::Inference(InferenceError::GpuNotAvailable {
290 message: "Failed to select GPU device".to_string(),
291 })
292 })?
293 };
294
295 info!(
296 "Selected GPU: {} (free memory: {} MB)",
297 device.name,
298 device.free_memory / (1024 * 1024)
299 );
300
301 Ok(device)
302}
303
/// Whether the CUDA driver library can be loaded.
///
/// Always `false` when the `cuda` feature is disabled.
fn check_cuda_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        cuda_check_runtime()
    }
    #[cfg(not(feature = "cuda"))]
    false
}
317
/// Probe for the CUDA driver/runtime by attempting to load well-known
/// library names for the current platform.
#[cfg(feature = "cuda")]
fn cuda_check_runtime() -> bool {
    use libloading::Library;

    // Candidate library names per platform; absolute paths cover common
    // non-default install locations.
    let candidates: &[&str] = if cfg!(target_os = "windows") {
        &["nvcuda.dll", "cudart64_110.dll", "cudart64_12.dll"]
    } else if cfg!(target_os = "macos") {
        &["libcuda.dylib", "/usr/local/cuda/lib/libcuda.dylib"]
    } else {
        &[
            "libcuda.so",
            "libcuda.so.1",
            "/usr/lib/x86_64-linux-gnu/libcuda.so",
            "/usr/local/cuda/lib64/libcuda.so",
        ]
    };

    // SAFETY: loading the vendor driver library; nothing from it is called.
    match candidates
        .iter()
        .find(|name| unsafe { Library::new(**name) }.is_ok())
    {
        Some(name) => {
            debug!("Found CUDA runtime library: {}", name);
            true
        }
        None => {
            debug!("CUDA runtime library not found");
            false
        }
    }
}
347
/// Enumerate CUDA devices; returns an empty list when the `cuda` feature is
/// disabled.
fn enumerate_cuda_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(feature = "cuda")]
    {
        cuda_enumerate_devices_impl()
    }
    #[cfg(not(feature = "cuda"))]
    {
        Ok(Vec::new())
    }
}
359
/// Enumerate CUDA devices via the dynamically loaded driver API (`cuInit`,
/// `cuDeviceGet*`) so there is no link-time CUDA dependency.
///
/// # Errors
///
/// Returns [`InferenceError::GpuNotAvailable`] when the driver library or a
/// required symbol cannot be loaded, or when `cuInit` / `cuDeviceGetCount`
/// fails. Per-device query failures are logged and the device is skipped (or
/// the field defaulted) rather than aborting the whole enumeration.
#[cfg(feature = "cuda")]
fn cuda_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use libloading::{Library, Symbol};
    use std::ffi::{c_char, c_int};

    type CUdevice = c_int;
    type CUresult = c_int;

    // Attribute ids from cuda.h (CUdevice_attribute).
    const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: c_int = 75;
    const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: c_int = 76;

    let lib_names = if cfg!(target_os = "windows") {
        vec!["nvcuda.dll"]
    } else if cfg!(target_os = "macos") {
        vec!["libcuda.dylib"]
    } else {
        vec!["libcuda.so.1", "libcuda.so"]
    };

    // SAFETY: loading the vendor driver library; symbols below are resolved
    // with the signatures documented by the CUDA driver API.
    let lib = lib_names
        .iter()
        .find_map(|name| unsafe { Library::new(*name).ok() })
        .ok_or_else(|| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: "CUDA library not found".to_string(),
            })
        })?;

    let cu_init: Symbol<unsafe extern "C" fn(c_int) -> CUresult> = unsafe { lib.get(b"cuInit\0") }
        .map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuInit: {}", e),
            })
        })?;

    // cuInit(0): the flags argument must be 0 per the driver API.
    let result = unsafe { cu_init(0) };
    if result != 0 {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("CUDA initialization failed with code: {}", result),
        }));
    }

    let cu_device_get_count: Symbol<unsafe extern "C" fn(*mut c_int) -> CUresult> =
        unsafe { lib.get(b"cuDeviceGetCount\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuDeviceGetCount: {}", e),
            })
        })?;

    let mut count: c_int = 0;
    let result = unsafe { cu_device_get_count(&mut count) };
    if result != 0 {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to get CUDA device count: {}", result),
        }));
    }

    debug!("Found {} CUDA device(s)", count);

    let cu_device_get: Symbol<unsafe extern "C" fn(*mut CUdevice, c_int) -> CUresult> =
        unsafe { lib.get(b"cuDeviceGet\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuDeviceGet: {}", e),
            })
        })?;

    let cu_device_get_name: Symbol<unsafe extern "C" fn(*mut c_char, c_int, CUdevice) -> CUresult> =
        unsafe { lib.get(b"cuDeviceGetName\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuDeviceGetName: {}", e),
            })
        })?;

    // Prefer the v2 symbol; fall back to the legacy name for old drivers.
    let cu_device_total_mem: Symbol<unsafe extern "C" fn(*mut usize, CUdevice) -> CUresult> =
        unsafe { lib.get(b"cuDeviceTotalMem_v2\0") }
            .or_else(|_| unsafe { lib.get(b"cuDeviceTotalMem\0") })
            .map_err(|e| {
                MlError::Inference(InferenceError::GpuNotAvailable {
                    message: format!("Failed to load cuDeviceTotalMem: {}", e),
                })
            })?;

    let cu_device_get_attribute: Symbol<
        unsafe extern "C" fn(*mut c_int, c_int, CUdevice) -> CUresult,
    > = unsafe { lib.get(b"cuDeviceGetAttribute\0") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load cuDeviceGetAttribute: {}", e),
        })
    })?;

    let mut devices = Vec::new();
    for i in 0..count {
        let mut device: CUdevice = 0;
        let result = unsafe { cu_device_get(&mut device, i) };
        if result != 0 {
            debug!("Failed to get CUDA device {}: error code {}", i, result);
            continue;
        }

        // Portability fix: c_char is i8 on x86_64 but u8 on e.g. aarch64
        // Linux, so the buffer must be typed as c_char — the previous
        // `[0i8; 256]` fails to compile where c_char = u8.
        let mut name_buf = [0 as c_char; 256];
        let result = unsafe { cu_device_get_name(name_buf.as_mut_ptr(), 256, device) };
        let device_name = if result == 0 {
            // SAFETY: on success the driver wrote a NUL-terminated string
            // into name_buf (buffer length 256 was passed).
            unsafe {
                std::ffi::CStr::from_ptr(name_buf.as_ptr())
                    .to_string_lossy()
                    .into_owned()
            }
        } else {
            format!("NVIDIA CUDA Device {}", i)
        };

        let mut total_mem: usize = 0;
        let result = unsafe { cu_device_total_mem(&mut total_mem, device) };
        if result != 0 {
            debug!(
                "Failed to get total memory for device {}: error code {}",
                i, result
            );
            total_mem = 0;
        }

        // No context-free free-memory query exists in the driver API
        // (cuMemGetInfo requires a context), so estimate 95% of total.
        let free_mem = (total_mem as f64 * 0.95) as usize;

        let mut major: c_int = 0;
        let mut minor: c_int = 0;
        let result_major = unsafe {
            cu_device_get_attribute(
                &mut major,
                CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
                device,
            )
        };
        let result_minor = unsafe {
            cu_device_get_attribute(
                &mut minor,
                CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
                device,
            )
        };

        let compute_capability = if result_major == 0 && result_minor == 0 {
            format!("{}.{}", major, minor)
        } else {
            "Unknown".to_string()
        };

        devices.push(GpuDevice {
            id: i as usize,
            name: device_name,
            total_memory: total_mem,
            free_memory: free_mem,
            compute_capability,
            backend: GpuBackend::Cuda,
        });
    }

    Ok(devices)
}
535
/// Whether Metal is usable.
///
/// Always `false` off macOS or when the `metal` feature is disabled.
fn check_metal_available() -> bool {
    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        metal_check_available()
    }
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    false
}
549
/// `true` when at least one Metal device is visible to the system.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn metal_check_available() -> bool {
    !metal::Device::all().is_empty()
}
556
/// Enumerate Metal devices; returns an empty list off macOS or when the
/// `metal` feature is disabled.
fn enumerate_metal_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        metal_enumerate_devices_impl()
    }
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    {
        Ok(Vec::new())
    }
}
568
/// Enumerate Metal devices via the `metal` crate.
///
/// # Errors
///
/// Returns [`InferenceError::GpuNotAvailable`] when no Metal device exists.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn metal_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    let devices = metal::Device::all();

    if devices.is_empty() {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: "No Metal devices found".to_string(),
        }));
    }

    let mut gpu_devices = Vec::new();

    for (idx, device) in devices.iter().enumerate() {
        let name = device.name().to_string();
        let total_memory = device.recommended_max_working_set_size() as usize;
        // No free-memory query is performed: the full working-set size is
        // reported as free. NOTE(review): confirm this over-report is
        // acceptable for callers that budget by free_memory.
        let free_memory = total_memory;

        gpu_devices.push(GpuDevice {
            id: idx,
            name,
            total_memory,
            free_memory,
            compute_capability: "Metal".to_string(),
            backend: GpuBackend::Metal,
        });
    }

    debug!("Found {} Metal device(s)", gpu_devices.len());
    Ok(gpu_devices)
}
599
/// Whether Vulkan is usable.
///
/// Always `false` when the `vulkan` feature is disabled.
fn check_vulkan_available() -> bool {
    #[cfg(feature = "vulkan")]
    {
        vulkan_check_available()
    }
    #[cfg(not(feature = "vulkan"))]
    false
}
613
/// Probe Vulkan by loading the loader and creating — then immediately
/// destroying — a minimal instance.
#[cfg(feature = "vulkan")]
fn vulkan_check_available() -> bool {
    use ash::{Entry, vk};

    // SAFETY: Entry::load loads the system Vulkan loader library.
    if let Ok(entry) = unsafe { Entry::load() } {
        let app_info = vk::ApplicationInfo::default().api_version(vk::make_api_version(0, 1, 0, 0));

        let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);

        // SAFETY: create_info outlives the call; the instance is destroyed
        // right away so nothing is leaked on the success path.
        if let Ok(instance) = unsafe { entry.create_instance(&create_info, None) } {
            unsafe { instance.destroy_instance(None) };
            return true;
        }
    }

    false
}
633
/// Enumerate Vulkan devices; returns an empty list when the `vulkan` feature
/// is disabled.
fn enumerate_vulkan_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(feature = "vulkan")]
    {
        vulkan_enumerate_devices_impl()
    }
    #[cfg(not(feature = "vulkan"))]
    {
        Ok(Vec::new())
    }
}
645
/// Enumerate Vulkan physical devices via `ash`.
///
/// # Errors
///
/// Returns [`InferenceError::GpuNotAvailable`] when the loader cannot be
/// loaded, instance creation fails, or physical-device enumeration fails.
#[cfg(feature = "vulkan")]
fn vulkan_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use ash::{Entry, vk};

    // SAFETY: loads the system Vulkan loader.
    let entry = unsafe { Entry::load() }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load Vulkan: {}", e),
        })
    })?;

    let app_info = vk::ApplicationInfo::default().api_version(vk::make_api_version(0, 1, 0, 0));

    let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);

    // SAFETY: create_info is valid for the duration of the call.
    let instance = unsafe { entry.create_instance(&create_info, None) }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to create Vulkan instance: {}", e),
        })
    })?;

    // SAFETY: instance is live. On the error path the instance is destroyed
    // inside the closure before the error propagates, so it is not leaked.
    let physical_devices = unsafe { instance.enumerate_physical_devices() }.map_err(|e| {
        unsafe { instance.destroy_instance(None) };
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to enumerate Vulkan devices: {}", e),
        })
    })?;

    let mut devices = Vec::new();
    for (idx, physical_device) in physical_devices.iter().enumerate() {
        // SAFETY: handles come from the live instance above.
        let properties = unsafe { instance.get_physical_device_properties(*physical_device) };
        let memory_properties =
            unsafe { instance.get_physical_device_memory_properties(*physical_device) };

        // SAFETY: device_name is a NUL-terminated C string filled by the driver.
        let device_name = unsafe {
            std::ffi::CStr::from_ptr(properties.device_name.as_ptr())
                .to_string_lossy()
                .into_owned()
        };

        // Sum only DEVICE_LOCAL heaps, i.e. actual on-device memory, over the
        // first memory_heap_count entries of the fixed-size heap array.
        let total_memory = memory_properties
            .memory_heaps
            .iter()
            .take(memory_properties.memory_heap_count as usize)
            .filter(|heap| heap.flags.contains(vk::MemoryHeapFlags::DEVICE_LOCAL))
            .map(|heap| heap.size as usize)
            .sum();

        devices.push(GpuDevice {
            id: idx,
            name: device_name,
            total_memory,
            // No free-memory query is attempted; total is reported as free.
            free_memory: total_memory,
            compute_capability: format!(
                "Vulkan {}.{}",
                vk::api_version_major(properties.api_version),
                vk::api_version_minor(properties.api_version)
            ),
            backend: GpuBackend::Vulkan,
        });
    }

    unsafe { instance.destroy_instance(None) };
    debug!("Found {} Vulkan device(s)", devices.len());
    Ok(devices)
}
712
/// Whether OpenCL is usable.
///
/// Always `false` when the `opencl` feature is disabled.
fn check_opencl_available() -> bool {
    #[cfg(feature = "opencl")]
    {
        opencl_check_available()
    }
    #[cfg(not(feature = "opencl"))]
    false
}
726
/// OpenCL is considered available as soon as at least one platform is
/// registered with the ICD loader.
#[cfg(feature = "opencl")]
fn opencl_check_available() -> bool {
    use opencl3::platform::get_platforms;

    get_platforms().map_or(false, |platforms| !platforms.is_empty())
}
736
/// Enumerate OpenCL devices; returns an empty list when the `opencl` feature
/// is disabled.
fn enumerate_opencl_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(feature = "opencl")]
    {
        opencl_enumerate_devices_impl()
    }
    #[cfg(not(feature = "opencl"))]
    {
        Ok(Vec::new())
    }
}
748
/// Enumerate OpenCL GPU devices via the `opencl3` crate.
///
/// A device whose property query fails is skipped with a debug log rather
/// than aborting the whole enumeration — consistent with the CUDA and ROCm
/// enumerators, which also log-and-continue on per-device failures
/// (previously a single bad device turned the entire list into an error).
///
/// # Errors
///
/// Returns [`InferenceError::GpuNotAvailable`] when the initial device-id
/// query itself fails.
#[cfg(feature = "opencl")]
fn opencl_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use opencl3::device::{CL_DEVICE_TYPE_GPU, Device, get_all_devices};

    let device_ids = get_all_devices(CL_DEVICE_TYPE_GPU).map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to get OpenCL devices: {:?}", e),
        })
    })?;

    let mut devices = Vec::new();
    for (idx, device_id) in device_ids.iter().enumerate() {
        let device = Device::new(*device_id);

        let name = match device.name() {
            Ok(name) => name,
            Err(e) => {
                debug!("Skipping OpenCL device {}: failed to get name: {:?}", idx, e);
                continue;
            }
        };

        let total_memory = match device.global_mem_size() {
            Ok(bytes) => bytes as usize,
            Err(e) => {
                debug!("Skipping OpenCL device {}: failed to get memory: {:?}", idx, e);
                continue;
            }
        };

        devices.push(GpuDevice {
            id: idx,
            name,
            total_memory,
            // No free-memory query is performed; total is reported as free.
            free_memory: total_memory,
            compute_capability: "OpenCL".to_string(),
            backend: GpuBackend::OpenCl,
        });
    }

    debug!("Found {} OpenCL device(s)", devices.len());
    Ok(devices)
}
788
/// Whether the ROCm HIP runtime can be loaded.
///
/// Always `false` when the `rocm` feature is disabled.
fn check_rocm_available() -> bool {
    #[cfg(feature = "rocm")]
    {
        rocm_check_runtime()
    }
    #[cfg(not(feature = "rocm"))]
    false
}
802
/// Probe for the ROCm HIP runtime by attempting to load known library names.
///
/// Only Linux names are probed; on other targets the candidate list is empty
/// and the function reports unavailability.
#[cfg(feature = "rocm")]
fn rocm_check_runtime() -> bool {
    use libloading::Library;

    let candidates: &[&str] = if cfg!(target_os = "linux") {
        &[
            "libamdhip64.so",
            "libamdhip64.so.6",
            "/opt/rocm/lib/libamdhip64.so",
        ]
    } else {
        &[]
    };

    // SAFETY: loading the vendor runtime library; nothing from it is called.
    if let Some(name) = candidates
        .iter()
        .find(|name| unsafe { Library::new(**name) }.is_ok())
    {
        debug!("Found ROCm runtime library: {}", name);
        return true;
    }

    debug!("ROCm runtime library not found");
    false
}
828
/// Enumerate ROCm devices; returns an empty list when the `rocm` feature is
/// disabled.
fn enumerate_rocm_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(feature = "rocm")]
    {
        rocm_enumerate_devices_impl()
    }
    #[cfg(not(feature = "rocm"))]
    {
        Ok(Vec::new())
    }
}
840
/// Enumerate ROCm (HIP) devices via the dynamically loaded HIP runtime.
///
/// # Errors
///
/// Returns [`InferenceError::GpuNotAvailable`] when the HIP library or a
/// required symbol cannot be loaded, or when `hipGetDeviceCount` fails.
/// Per-device property failures are logged and the device skipped.
#[cfg(feature = "rocm")]
fn rocm_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use libloading::{Library, Symbol};
    use std::ffi::{c_char, c_int};

    type HipError = c_int;
    const HIP_SUCCESS: HipError = 0;

    // Mirrors a *prefix* of HIP's hipDeviceProp_t; only name, totalGlobalMem,
    // major/minor, and gcnArchName are actually read.
    // NOTE(review): verify field order/offsets against the hipDeviceProp_t of
    // the targeted ROCm version — the struct layout changed in ROCm 6.0.
    #[repr(C)]
    #[allow(non_snake_case)]
    struct HipDeviceProp {
        name: [c_char; 256],
        totalGlobalMem: usize,
        sharedMemPerBlock: usize,
        regsPerBlock: c_int,
        warpSize: c_int,
        memPitch: usize,
        maxThreadsPerBlock: c_int,
        maxThreadsDim: [c_int; 3],
        maxGridSize: [c_int; 3],
        clockRate: c_int,
        totalConstMem: usize,
        major: c_int,
        minor: c_int,
        textureAlignment: usize,
        deviceOverlap: c_int,
        multiProcessorCount: c_int,
        kernelExecTimeoutEnabled: c_int,
        integrated: c_int,
        canMapHostMemory: c_int,
        computeMode: c_int,
        maxTexture1D: c_int,
        maxTexture2D: [c_int; 2],
        maxTexture3D: [c_int; 3],
        gcnArchName: [c_char; 256],
        // Fix: hipGetDeviceProperties writes the *entire* real hipDeviceProp_t,
        // which is larger than the fields declared above, so a truncated struct
        // let the runtime write past the end of our allocation (undefined
        // behavior / stack corruption). Reserve generous tail space so the
        // full struct always fits.
        _reserved: [u8; 4096],
    }

    let lib_names = if cfg!(target_os = "linux") {
        vec![
            "libamdhip64.so",
            "libamdhip64.so.6",
            "/opt/rocm/lib/libamdhip64.so",
        ]
    } else {
        vec![]
    };

    // SAFETY: loading the vendor runtime library; symbols below are resolved
    // with the signatures documented by the HIP runtime API.
    let lib = lib_names
        .iter()
        .find_map(|name| unsafe { Library::new(*name).ok() })
        .ok_or_else(|| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: "ROCm HIP library not found".to_string(),
            })
        })?;

    let hip_get_device_count: Symbol<unsafe extern "C" fn(*mut c_int) -> HipError> =
        unsafe { lib.get(b"hipGetDeviceCount\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load hipGetDeviceCount: {}", e),
            })
        })?;

    let mut count: c_int = 0;
    let result = unsafe { hip_get_device_count(&mut count) };
    if result != HIP_SUCCESS {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("hipGetDeviceCount failed with error code: {}", result),
        }));
    }

    debug!("Found {} ROCm device(s)", count);

    let hip_get_device_properties: Symbol<
        unsafe extern "C" fn(*mut HipDeviceProp, c_int) -> HipError,
    > = unsafe { lib.get(b"hipGetDeviceProperties\0") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load hipGetDeviceProperties: {}", e),
        })
    })?;

    // Optional symbols: used for an accurate free-memory reading when present.
    let hip_mem_get_info: Option<Symbol<unsafe extern "C" fn(*mut usize, *mut usize) -> HipError>> =
        unsafe { lib.get(b"hipMemGetInfo\0").ok() };

    let hip_set_device: Option<Symbol<unsafe extern "C" fn(c_int) -> HipError>> =
        unsafe { lib.get(b"hipSetDevice\0").ok() };

    let mut devices = Vec::new();
    for i in 0..count {
        // SAFETY: HipDeviceProp is plain-old-data; all-zero bytes are valid.
        let mut props: HipDeviceProp = unsafe { std::mem::zeroed() };
        let result = unsafe { hip_get_device_properties(&mut props, i) };
        if result != HIP_SUCCESS {
            debug!(
                "Failed to get properties for ROCm device {}: error code {}",
                i, result
            );
            continue;
        }

        // SAFETY: props.name is a NUL-terminated C string filled by HIP.
        let device_name = unsafe {
            std::ffi::CStr::from_ptr(props.name.as_ptr())
                .to_string_lossy()
                .into_owned()
        };

        let total_memory = props.totalGlobalMem;

        // Prefer an exact hipMemGetInfo reading (requires selecting the
        // device first); fall back to a 95%-of-total estimate.
        let free_memory = if let (Some(set_dev), Some(get_info)) =
            (hip_set_device.as_ref(), hip_mem_get_info.as_ref())
        {
            let set_result = unsafe { set_dev(i) };
            if set_result == HIP_SUCCESS {
                let mut free: usize = 0;
                let mut total: usize = 0;
                let result = unsafe { get_info(&mut free, &mut total) };
                if result == HIP_SUCCESS {
                    free
                } else {
                    (total_memory as f64 * 0.95) as usize
                }
            } else {
                (total_memory as f64 * 0.95) as usize
            }
        } else {
            (total_memory as f64 * 0.95) as usize
        };

        // Prefer the GCN arch name (e.g. "gfx90a"); fall back to major.minor.
        let compute_capability = unsafe {
            let gcn_name = std::ffi::CStr::from_ptr(props.gcnArchName.as_ptr())
                .to_string_lossy()
                .into_owned();
            if gcn_name.is_empty() {
                format!("{}.{}", props.major, props.minor)
            } else {
                gcn_name
            }
        };

        devices.push(GpuDevice {
            id: i as usize,
            name: device_name,
            total_memory,
            free_memory,
            compute_capability,
            backend: GpuBackend::Rocm,
        });
    }

    Ok(devices)
}
1010
/// Whether DirectML is usable.
///
/// Always `false` off Windows or when the `directml` feature is disabled.
fn check_directml_available() -> bool {
    #[cfg(all(feature = "directml", target_os = "windows"))]
    {
        directml_check_runtime()
    }
    #[cfg(not(all(feature = "directml", target_os = "windows")))]
    false
}
1024
/// Check that DirectML can run: a sufficiently new Windows build and a
/// loadable DirectML.dll exporting the device-creation entry points.
#[cfg(all(feature = "directml", target_os = "windows"))]
fn directml_check_runtime() -> bool {
    use libloading::Library;

    // Windows version gate. (The inner #[cfg(windows)] is redundant with the
    // function-level target_os = "windows" cfg, but harmless.)
    #[cfg(windows)]
    {
        use std::mem;

        // Mirrors the Win32 OSVERSIONINFOEXW layout expected by RtlGetVersion.
        #[repr(C)]
        struct OsVersionInfoExW {
            dw_os_version_info_size: u32,
            dw_major_version: u32,
            dw_minor_version: u32,
            dw_build_number: u32,
            dw_platform_id: u32,
            sz_csd_version: [u16; 128],
            w_service_pack_major: u16,
            w_service_pack_minor: u16,
            w_suite_mask: u16,
            w_product_type: u8,
            w_reserved: u8,
        }

        // SAFETY: RtlGetVersion is a standard ntdll export with this signature.
        if let Ok(ntdll) = unsafe { Library::new("ntdll.dll") } {
            type RtlGetVersion = unsafe extern "system" fn(*mut OsVersionInfoExW) -> i32;
            if let Ok(rtl_get_version) = unsafe { ntdll.get::<RtlGetVersion>(b"RtlGetVersion\0") } {
                // SAFETY: the struct is plain-old-data; zeroed init is valid.
                let mut version_info: OsVersionInfoExW = unsafe { mem::zeroed() };
                version_info.dw_os_version_info_size = mem::size_of::<OsVersionInfoExW>() as u32;

                let status = unsafe { rtl_get_version(&mut version_info) };
                if status == 0 {
                    // Requires Windows 10 build 18362 or newer (per the
                    // error message below).
                    let is_compatible = version_info.dw_major_version > 10
                        || (version_info.dw_major_version == 10
                            && version_info.dw_build_number >= 18362);

                    if !is_compatible {
                        debug!(
                            "Windows version {}.{}.{} is too old for DirectML (requires 10.0.18362+)",
                            version_info.dw_major_version,
                            version_info.dw_minor_version,
                            version_info.dw_build_number
                        );
                        return false;
                    }
                }
            }
        }
    }

    let directml_names = vec!["DirectML.dll", "DirectML.Debug.dll"];

    for lib_name in directml_names {
        // SAFETY: probing the DLL and its exports without calling them.
        if let Ok(lib) = unsafe { Library::new(lib_name) } {
            let has_v1_0 = unsafe { lib.get::<*const ()>(b"DMLCreateDevice\0") }.is_ok();
            let has_v1_1 = unsafe { lib.get::<*const ()>(b"DMLCreateDevice1\0") }.is_ok();

            if has_v1_0 || has_v1_1 {
                debug!("Found DirectML library: {}", lib_name);
                return true;
            } else {
                debug!(
                    "DirectML library {} found but missing required exports",
                    lib_name
                );
            }
        }
    }

    debug!("DirectML library not found");
    false
}
1102
/// Enumerate DirectML-compatible adapters; returns an empty list off Windows
/// or when the `directml` feature is disabled.
fn enumerate_directml_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(all(feature = "directml", target_os = "windows"))]
    {
        directml_enumerate_devices_impl()
    }
    #[cfg(not(all(feature = "directml", target_os = "windows")))]
    {
        Ok(Vec::new())
    }
}
1114
/// Enumerate DirectML-capable adapters by walking DXGI through raw COM
/// vtables (hand-rolled, no windows-rs dependency).
///
/// # Errors
///
/// Returns [`InferenceError::GpuNotAvailable`] when dxgi.dll or
/// `CreateDXGIFactory1` cannot be loaded, or factory creation fails.
#[cfg(all(feature = "directml", target_os = "windows"))]
fn directml_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use libloading::{Library, Symbol};
    use std::ffi::c_void;

    // Win32 GUID layout.
    #[repr(C)]
    struct Guid {
        data1: u32,
        data2: u16,
        data3: u16,
        data4: [u8; 8],
    }

    // Mirrors DXGI_ADAPTER_DESC1.
    #[repr(C)]
    struct DxgiAdapterDesc1 {
        description: [u16; 128],
        vendor_id: u32,
        device_id: u32,
        sub_sys_id: u32,
        revision: u32,
        dedicated_video_memory: usize,
        dedicated_system_memory: usize,
        shared_system_memory: usize,
        adapter_luid: [u8; 8],
        flags: u32,
    }

    type HResult = i32;
    const S_OK: HResult = 0;
    const DXGI_ERROR_NOT_FOUND: HResult = -2005270526; // 0x887A0002
    // NOTE(review): currently unused — presumably intended for a future
    // feature-level check; confirm or remove.
    const D3D_FEATURE_LEVEL_11_0: u32 = 0xb000;

    // IID of IDXGIFactory1.
    const IID_IDXGI_FACTORY1: Guid = Guid {
        data1: 0x770aae78,
        data2: 0xf26f,
        data3: 0x4dba,
        data4: [0xa8, 0x29, 0x25, 0x3c, 0x83, 0xd1, 0xb3, 0x87],
    };

    // SAFETY: loading a system DLL and resolving a documented export.
    let dxgi_lib = unsafe { Library::new("dxgi.dll") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load dxgi.dll: {}", e),
        })
    })?;

    let create_factory: Symbol<
        unsafe extern "system" fn(*const Guid, *mut *mut c_void) -> HResult,
    > = unsafe { dxgi_lib.get(b"CreateDXGIFactory1\0") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load CreateDXGIFactory1: {}", e),
        })
    })?;

    let mut factory: *mut c_void = std::ptr::null_mut();
    // SAFETY: valid out-pointer; IID matches the requested interface.
    let hr = unsafe { create_factory(&IID_IDXGI_FACTORY1, &mut factory) };
    if hr != S_OK || factory.is_null() {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to create DXGI factory: HRESULT = 0x{:08X}", hr),
        }));
    }

    type EnumAdapters1Fn = unsafe extern "system" fn(*mut c_void, u32, *mut *mut c_void) -> HResult;

    // NOTE(review): vtable slot 7 is IDXGIFactory::EnumAdapters;
    // IDXGIFactory1::EnumAdapters1 sits at slot 12. Both share this call
    // shape, so this likely works in practice, but confirm the intended slot.
    let vtable = unsafe { *(factory as *const *const c_void) };
    let enum_adapters1: EnumAdapters1Fn =
        unsafe { std::mem::transmute(*((vtable as *const *const c_void).offset(7))) };

    let mut devices = Vec::new();
    let mut adapter_index = 0u32;

    loop {
        let mut adapter: *mut c_void = std::ptr::null_mut();
        // SAFETY: factory is a live COM object; called with its this-pointer.
        let hr = unsafe { enum_adapters1(factory, adapter_index, &mut adapter) };

        // DXGI_ERROR_NOT_FOUND marks the end of the adapter list.
        if hr == DXGI_ERROR_NOT_FOUND {
            break;
        }

        if hr != S_OK || adapter.is_null() {
            break;
        }

        type GetDesc1Fn = unsafe extern "system" fn(*mut c_void, *mut DxgiAdapterDesc1) -> HResult;

        // Slot 10 = IDXGIAdapter1::GetDesc1.
        let adapter_vtable = unsafe { *(adapter as *const *const c_void) };
        let get_desc1: GetDesc1Fn =
            unsafe { std::mem::transmute(*((adapter_vtable as *const *const c_void).offset(10))) };

        // SAFETY: DxgiAdapterDesc1 is plain-old-data; zeroed init is valid.
        let mut desc: DxgiAdapterDesc1 = unsafe { std::mem::zeroed() };
        let hr = unsafe { get_desc1(adapter, &mut desc) };

        if hr == S_OK {
            // Description is fixed-width UTF-16; strip trailing NULs.
            let description = String::from_utf16_lossy(&desc.description)
                .trim_end_matches('\0')
                .to_string();

            // Only adapters reporting dedicated VRAM are listed — presumably
            // to exclude software adapters; confirm intent.
            if desc.dedicated_video_memory > 0 {
                devices.push(GpuDevice {
                    id: adapter_index as usize,
                    name: description,
                    total_memory: desc.dedicated_video_memory,
                    // No free-memory query here; 95% of dedicated VRAM is
                    // used as an estimate.
                    free_memory: (desc.dedicated_video_memory as f64 * 0.95) as usize,
                    compute_capability: "DirectML".to_string(),
                    backend: GpuBackend::DirectMl,
                });
            }
        }

        type ReleaseFn = unsafe extern "system" fn(*mut c_void) -> u32;
        // Slot 2 = IUnknown::Release — balance the reference taken above.
        let release: ReleaseFn = unsafe {
            let adapter_vtable = *(adapter as *const *const c_void);
            std::mem::transmute(*((adapter_vtable as *const *const c_void).offset(2)))
        };
        unsafe { release(adapter) };

        adapter_index += 1;
    }

    type ReleaseFn = unsafe extern "system" fn(*mut c_void) -> u32;
    // Release the factory itself (IUnknown::Release, slot 2).
    let release: ReleaseFn =
        unsafe { std::mem::transmute(*((vtable as *const *const c_void).offset(2))) };
    unsafe { release(factory) };

    debug!("Found {} DirectML-compatible device(s)", devices.len());
    Ok(devices)
}
1257
/// Whether WebGPU is usable.
///
/// Only ever `true` on wasm32 with the `webgpu` feature enabled.
fn check_webgpu_available() -> bool {
    #[cfg(all(feature = "webgpu", target_arch = "wasm32"))]
    {
        webgpu_check_available()
    }
    #[cfg(not(all(feature = "webgpu", target_arch = "wasm32")))]
    false
}
1271
/// Feature-detect WebGPU in the browser: `navigator.gpu` must exist and
/// expose a `requestAdapter` member.
///
/// Probing goes through `js_sys::Reflect` rather than typed web-sys
/// bindings — presumably to avoid depending on unstable gpu APIs; confirm.
#[cfg(all(feature = "webgpu", target_arch = "wasm32"))]
fn webgpu_check_available() -> bool {
    use js_sys::Reflect;
    use wasm_bindgen::JsValue;
    use web_sys::window;

    // No window outside a browsing context (e.g. in a worker without one).
    let window = match window() {
        Some(w) => w,
        None => {
            debug!("WebGPU: window object not available");
            return false;
        }
    };

    let navigator = window.navigator();
    let navigator_val = JsValue::from(&navigator);

    // Step 1: does navigator have a "gpu" property at all?
    let gpu_key = JsValue::from_str("gpu");
    match Reflect::has(&navigator_val, &gpu_key) {
        Ok(has_gpu) => {
            if !has_gpu {
                debug!("WebGPU: navigator.gpu property not found");
                return false;
            }
        }
        Err(e) => {
            debug!("WebGPU: Failed to check for gpu property: {:?}", e);
            return false;
        }
    }

    // Step 2: fetch it and make sure it is a real object.
    let gpu = match Reflect::get(&navigator_val, &gpu_key) {
        Ok(g) => g,
        Err(e) => {
            debug!("WebGPU: Failed to get gpu object: {:?}", e);
            return false;
        }
    };

    if gpu.is_null() || gpu.is_undefined() {
        debug!("WebGPU: navigator.gpu is null or undefined");
        return false;
    }

    // Step 3: the adapter-request entry point must exist.
    let request_adapter_key = JsValue::from_str("requestAdapter");
    match Reflect::has(&gpu, &request_adapter_key) {
        Ok(has_request_adapter) => {
            if !has_request_adapter {
                debug!("WebGPU: requestAdapter method not found");
                return false;
            }
        }
        Err(e) => {
            debug!("WebGPU: Failed to check for requestAdapter: {:?}", e);
            return false;
        }
    }

    debug!("WebGPU is available");
    true
}
1339
/// Enumerate WebGPU "devices".
///
/// Real adapter enumeration (`navigator.gpu.requestAdapter`) is async and
/// cannot be awaited from this synchronous API, so when WebGPU is available
/// a single placeholder device with unknown (zero) memory is returned.
fn enumerate_webgpu_devices() -> Result<Vec<GpuDevice>> {
    #[cfg(all(feature = "webgpu", target_arch = "wasm32"))]
    {
        if webgpu_check_available() {
            debug!("WebGPU is available (async enumeration not supported in sync context)");
            Ok(vec![GpuDevice {
                id: 0,
                name: "WebGPU Device (async enumeration required)".to_string(),
                // Memory figures are unknowable without an adapter; zero is
                // used as an explicit "unknown" marker.
                total_memory: 0,
                free_memory: 0,
                compute_capability: "WebGPU".to_string(),
                backend: GpuBackend::WebGpu,
            }])
        } else {
            Ok(Vec::new())
        }
    }
    #[cfg(not(all(feature = "webgpu", target_arch = "wasm32")))]
    {
        Ok(Vec::new())
    }
}
1376
/// Running totals describing GPU memory allocator activity.
#[derive(Debug, Clone, Default)]
pub struct GpuMemoryStats {
    /// Bytes currently allocated.
    pub allocated: usize,
    /// High-water mark of allocated bytes.
    pub peak: usize,
    /// Total number of allocations performed.
    pub num_allocations: usize,
    /// Total number of deallocations performed.
    pub num_deallocations: usize,
}

impl GpuMemoryStats {
    /// Bytes currently in use.
    #[must_use]
    pub fn current_usage(&self) -> usize {
        self.allocated
    }

    /// Number of allocations that have not yet been freed.
    ///
    /// Clamps to zero if the counters are ever inconsistent
    /// (more deallocations recorded than allocations).
    #[must_use]
    pub fn active_allocations(&self) -> usize {
        self.num_allocations
            .checked_sub(self.num_deallocations)
            .unwrap_or(0)
    }
}
1403
#[cfg(test)]
mod tests {
    use super::*;

    // Every builder setter must be reflected in the built config.
    #[test]
    fn test_gpu_config_builder() {
        let config = GpuConfig::builder()
            .backend(GpuBackend::Cuda)
            .device_id(1)
            .mixed_precision(false)
            .tensor_cores(false)
            .memory_growth(false)
            .memory_fraction(0.8)
            .build();

        assert_eq!(config.backend, Some(GpuBackend::Cuda));
        assert_eq!(config.device_id, Some(1));
        assert!(!config.mixed_precision);
        assert!(!config.tensor_cores);
        assert!(!config.memory_growth);
        assert!((config.memory_fraction - 0.8).abs() < 1e-6);
    }

    // The builder is expected to clamp memory_fraction into [0.1, 1.0]
    // rather than let out-of-range values pass through.
    // NOTE(review): the clamp itself lives in the builder, outside this
    // chunk — confirm the bounds there match these expectations.
    #[test]
    fn test_memory_fraction_clamping() {
        let config1 = GpuConfig::builder().memory_fraction(1.5).build();
        assert!((config1.memory_fraction - 1.0).abs() < 1e-6);

        let config2 = GpuConfig::builder().memory_fraction(-0.5).build();
        assert!((config2.memory_fraction - 0.1).abs() < 1e-6);
    }

    // 8 GB total with 2 GB free => 6 GB used => 75% utilization.
    #[test]
    fn test_gpu_device_memory_utilization() {
        let device = GpuDevice {
            id: 0,
            name: "Test GPU".to_string(),
            total_memory: 8_000_000_000, free_memory: 2_000_000_000, compute_capability: "8.0".to_string(),
            backend: GpuBackend::Cuda,
        };

        let utilization = device.memory_utilization();
        // Loose tolerance: the exact value is 75.0, computed in f32.
        assert!((utilization - 75.0).abs() < 1.0);

        // 1 GB fits in the 2 GB free; 3 GB does not.
        assert!(device.has_sufficient_memory(1_000_000_000)); assert!(!device.has_sufficient_memory(3_000_000_000)); }

    // name() must return the canonical vendor spelling for each backend.
    #[test]
    fn test_backend_names() {
        assert_eq!(GpuBackend::Cuda.name(), "CUDA");
        assert_eq!(GpuBackend::Rocm.name(), "ROCm");
        assert_eq!(GpuBackend::DirectMl.name(), "DirectML");
        assert_eq!(GpuBackend::Vulkan.name(), "Vulkan");
        assert_eq!(GpuBackend::WebGpu.name(), "WebGPU");
        assert_eq!(GpuBackend::Metal.name(), "Metal");
        assert_eq!(GpuBackend::OpenCl.name(), "OpenCL");
    }

    // 10 allocations minus 3 deallocations => 7 still active.
    #[test]
    fn test_gpu_memory_stats() {
        let stats = GpuMemoryStats {
            allocated: 1000,
            peak: 1500,
            num_allocations: 10,
            num_deallocations: 3,
        };

        assert_eq!(stats.current_usage(), 1000);
        assert_eq!(stats.active_allocations(), 7);
    }

    // Smoke test: the availability probes must not panic on any host.
    // The results are ignored because they depend on the machine/CI
    // hardware and enabled cargo features.
    #[test]
    fn test_backend_availability() {
        let _cuda_available = GpuBackend::Cuda.is_available();
        let _metal_available = GpuBackend::Metal.is_available();
        let _vulkan_available = GpuBackend::Vulkan.is_available();
        let _opencl_available = GpuBackend::OpenCl.is_available();
        let _rocm_available = GpuBackend::Rocm.is_available();
        let _directml_available = GpuBackend::DirectMl.is_available();
        let _webgpu_available = GpuBackend::WebGpu.is_available();
    }

    // Enumeration must succeed (possibly with an empty list on GPU-less
    // hosts), and every reported device must carry a non-empty name.
    #[test]
    fn test_list_devices() {
        let result = list_devices();
        assert!(result.is_ok());
        let devices = result.ok().unwrap_or_default();

        for device in devices {
            assert!(!device.name.is_empty());
            let _ = device.total_memory;
        }
    }

    // With a default config on a host that may have no GPU at all, the
    // only requirement is that selection returns without panicking;
    // either Ok or Err is acceptable.
    #[test]
    fn test_select_device_no_gpu() {
        let config = GpuConfig::default();
        let _result = select_device(&config);
    }

    // When a backend's cargo feature is disabled, its enumerator must
    // still return Ok, with an empty device list.
    #[test]
    fn test_device_enumeration_without_features() {
        #[cfg(not(feature = "cuda"))]
        {
            let cuda_devices = enumerate_cuda_devices();
            assert!(cuda_devices.is_ok());
            assert!(cuda_devices.ok().is_none_or(|d| d.is_empty()));
        }

        #[cfg(not(feature = "metal"))]
        {
            let metal_devices = enumerate_metal_devices();
            assert!(metal_devices.is_ok());
            assert!(metal_devices.ok().is_none_or(|d| d.is_empty()));
        }

        #[cfg(not(feature = "vulkan"))]
        {
            let vulkan_devices = enumerate_vulkan_devices();
            assert!(vulkan_devices.is_ok());
            assert!(vulkan_devices.ok().is_none_or(|d| d.is_empty()));
        }

        #[cfg(not(feature = "opencl"))]
        {
            let opencl_devices = enumerate_opencl_devices();
            assert!(opencl_devices.is_ok());
            assert!(opencl_devices.ok().is_none_or(|d| d.is_empty()));
        }
    }

    // Every device reported by the Metal enumerator must be tagged with
    // the Metal backend and have a name. Only runs on macOS with the
    // feature enabled.
    #[test]
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn test_metal_enumeration() {
        let devices = enumerate_metal_devices();
        assert!(devices.is_ok());

        if let Ok(devs) = devices {
            if !devs.is_empty() {
                for device in devs {
                    assert!(!device.name.is_empty());
                    assert_eq!(device.backend, GpuBackend::Metal);
                }
            }
        }
    }

    // If CUDA is reported available, enumeration must succeed and every
    // device must be tagged with the Cuda backend.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_detection() {
        let available = check_cuda_available();

        if available {
            let devices = enumerate_cuda_devices();
            assert!(devices.is_ok());

            if let Ok(devs) = devices {
                for device in devs {
                    assert_eq!(device.backend, GpuBackend::Cuda);
                }
            }
        }
    }

    // Same contract for Vulkan: availability implies enumerable devices
    // with the correct backend tag.
    #[test]
    #[cfg(feature = "vulkan")]
    fn test_vulkan_detection() {
        let available = check_vulkan_available();

        if available {
            let devices = enumerate_vulkan_devices();
            assert!(devices.is_ok());

            if let Ok(devs) = devices {
                for device in devs {
                    assert_eq!(device.backend, GpuBackend::Vulkan);
                }
            }
        }
    }

    // Same contract for OpenCL.
    #[test]
    #[cfg(feature = "opencl")]
    fn test_opencl_detection() {
        let available = check_opencl_available();

        if available {
            let devices = enumerate_opencl_devices();
            assert!(devices.is_ok());

            if let Ok(devs) = devices {
                for device in devs {
                    assert_eq!(device.backend, GpuBackend::OpenCl);
                }
            }
        }
    }

    // Requesting a specific backend must never yield a device of a
    // different backend. Only meaningful when at least one device exists.
    #[test]
    fn test_device_selection_with_backend_filter() {
        let devices = list_devices().ok().unwrap_or_default();

        if !devices.is_empty() {
            let first_backend = devices[0].backend;
            let config = GpuConfig::builder().backend(first_backend).build();

            let result = select_device(&config);

            if result.is_ok() {
                let device = result.ok().unwrap_or_else(|| devices[0].clone());
                assert_eq!(device.backend, first_backend);
            }
        }
    }

    // Degenerate device: zero total memory must report 0% utilization
    // (no division by zero) and reject any memory request.
    #[test]
    fn test_zero_memory_device_utilization() {
        let device = GpuDevice {
            id: 0,
            name: "Zero Memory Device".to_string(),
            total_memory: 0,
            free_memory: 0,
            compute_capability: "0.0".to_string(),
            backend: GpuBackend::Cuda,
        };

        assert_eq!(device.memory_utilization(), 0.0);
        assert!(!device.has_sufficient_memory(1));
    }
}