// oxigdal_ml/gpu.rs

1//! GPU acceleration for ML inference
2//!
3//! This module provides GPU acceleration support across multiple backends
4//! including CUDA, ROCm, DirectML, Vulkan, and WebGPU.
5//!
6//! # Safety
7//!
8//! This module requires unsafe code for FFI operations with GPU libraries.
9//! All unsafe operations are carefully reviewed and documented.
10
11#![allow(unsafe_code)]
12
13use crate::error::{InferenceError, MlError, Result};
14use tracing::{debug, info};
15
/// GPU backend types
///
/// Identifies which GPU compute API a device is exposed through. Whether a
/// variant can actually be used depends on the platform and on the crate
/// features it was built with; see [`GpuBackend::is_available`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuBackend {
    /// NVIDIA CUDA
    Cuda,
    /// AMD ROCm
    Rocm,
    /// DirectML (Windows)
    DirectMl,
    /// Vulkan (cross-platform)
    Vulkan,
    /// WebGPU (web/browser)
    WebGpu,
    /// Metal (Apple)
    Metal,
    /// OpenCL (cross-platform)
    OpenCl,
}
34
35impl GpuBackend {
36    /// Returns the backend name
37    #[must_use]
38    pub fn name(&self) -> &'static str {
39        match self {
40            Self::Cuda => "CUDA",
41            Self::Rocm => "ROCm",
42            Self::DirectMl => "DirectML",
43            Self::Vulkan => "Vulkan",
44            Self::WebGpu => "WebGPU",
45            Self::Metal => "Metal",
46            Self::OpenCl => "OpenCL",
47        }
48    }
49
50    /// Checks if the backend is available on the current platform
51    #[must_use]
52    pub fn is_available(&self) -> bool {
53        match self {
54            Self::Cuda => check_cuda_available(),
55            Self::Rocm => check_rocm_available(),
56            Self::DirectMl => check_directml_available(),
57            Self::Vulkan => check_vulkan_available(),
58            Self::WebGpu => check_webgpu_available(),
59            Self::Metal => check_metal_available(),
60            Self::OpenCl => check_opencl_available(),
61        }
62    }
63}
64
/// GPU device information
///
/// A snapshot of one physical device as reported by its backend at
/// enumeration time. On backends that cannot query free memory without an
/// active context (CUDA, Vulkan, Metal, OpenCL) `free_memory` is an estimate.
#[derive(Debug, Clone)]
pub struct GpuDevice {
    /// Device ID (index within the backend's enumeration order)
    pub id: usize,
    /// Device name
    pub name: String,
    /// Total memory in bytes
    pub total_memory: usize,
    /// Free memory in bytes (may be estimated; see struct docs)
    pub free_memory: usize,
    /// Compute capability (CUDA) or equivalent
    pub compute_capability: String,
    /// Backend type
    pub backend: GpuBackend,
}
81
82impl GpuDevice {
83    /// Returns memory utilization percentage
84    #[must_use]
85    pub fn memory_utilization(&self) -> f32 {
86        if self.total_memory > 0 {
87            let used = self.total_memory.saturating_sub(self.free_memory);
88            (used as f32 / self.total_memory as f32) * 100.0
89        } else {
90            0.0
91        }
92    }
93
94    /// Checks if the device has sufficient free memory
95    #[must_use]
96    pub fn has_sufficient_memory(&self, required_bytes: usize) -> bool {
97        self.free_memory >= required_bytes
98    }
99}
100
/// GPU configuration
///
/// Controls backend/device selection and memory behaviour. The `Default`
/// implementation auto-selects backend and device, enables mixed precision
/// and tensor cores, and caps per-process memory at 90%.
#[derive(Debug, Clone)]
pub struct GpuConfig {
    /// Preferred backend (None for auto-select)
    pub backend: Option<GpuBackend>,
    /// Device ID (None for auto-select)
    pub device_id: Option<usize>,
    /// Enable mixed precision (FP16)
    pub mixed_precision: bool,
    /// Enable tensor cores (CUDA)
    pub tensor_cores: bool,
    /// Memory growth (allocate on demand)
    pub memory_growth: bool,
    /// Per-process GPU memory fraction (expected range 0.1–1.0)
    pub memory_fraction: f32,
}
117
118impl Default for GpuConfig {
119    fn default() -> Self {
120        Self {
121            backend: None,
122            device_id: None,
123            mixed_precision: true,
124            tensor_cores: true,
125            memory_growth: true,
126            memory_fraction: 0.9,
127        }
128    }
129}
130
131impl GpuConfig {
132    /// Creates a configuration builder
133    #[must_use]
134    pub fn builder() -> GpuConfigBuilder {
135        GpuConfigBuilder::default()
136    }
137}
138
/// Builder for GPU configuration
///
/// Each `None` field means "not set"; `build` fills unset options with the
/// same defaults as `GpuConfig::default()`.
#[derive(Debug, Default)]
pub struct GpuConfigBuilder {
    backend: Option<GpuBackend>,
    device_id: Option<usize>,
    mixed_precision: Option<bool>,
    tensor_cores: Option<bool>,
    memory_growth: Option<bool>,
    memory_fraction: Option<f32>,
}
149
150impl GpuConfigBuilder {
151    /// Sets the GPU backend
152    #[must_use]
153    pub fn backend(mut self, backend: GpuBackend) -> Self {
154        self.backend = Some(backend);
155        self
156    }
157
158    /// Sets the device ID
159    #[must_use]
160    pub fn device_id(mut self, id: usize) -> Self {
161        self.device_id = Some(id);
162        self
163    }
164
165    /// Enables mixed precision
166    #[must_use]
167    pub fn mixed_precision(mut self, enable: bool) -> Self {
168        self.mixed_precision = Some(enable);
169        self
170    }
171
172    /// Enables tensor cores
173    #[must_use]
174    pub fn tensor_cores(mut self, enable: bool) -> Self {
175        self.tensor_cores = Some(enable);
176        self
177    }
178
179    /// Enables memory growth
180    #[must_use]
181    pub fn memory_growth(mut self, enable: bool) -> Self {
182        self.memory_growth = Some(enable);
183        self
184    }
185
186    /// Sets memory fraction
187    #[must_use]
188    pub fn memory_fraction(mut self, fraction: f32) -> Self {
189        self.memory_fraction = Some(fraction.clamp(0.1, 1.0));
190        self
191    }
192
193    /// Builds the configuration
194    #[must_use]
195    pub fn build(self) -> GpuConfig {
196        GpuConfig {
197            backend: self.backend,
198            device_id: self.device_id,
199            mixed_precision: self.mixed_precision.unwrap_or(true),
200            tensor_cores: self.tensor_cores.unwrap_or(true),
201            memory_growth: self.memory_growth.unwrap_or(true),
202            memory_fraction: self.memory_fraction.unwrap_or(0.9),
203        }
204    }
205}
206
207/// Lists available GPU devices
208///
209/// # Errors
210/// Returns an error if device enumeration fails
211pub fn list_devices() -> Result<Vec<GpuDevice>> {
212    info!("Enumerating GPU devices");
213
214    let mut devices = Vec::new();
215
216    // Try each backend
217    for backend in &[
218        GpuBackend::Cuda,
219        GpuBackend::Rocm,
220        GpuBackend::Metal,
221        GpuBackend::Vulkan,
222        GpuBackend::DirectMl,
223    ] {
224        if backend.is_available() {
225            devices.extend(list_devices_for_backend(*backend)?);
226        }
227    }
228
229    info!("Found {} GPU device(s)", devices.len());
230    Ok(devices)
231}
232
233/// Lists devices for a specific backend
234fn list_devices_for_backend(backend: GpuBackend) -> Result<Vec<GpuDevice>> {
235    debug!("Enumerating devices for backend: {}", backend.name());
236
237    match backend {
238        GpuBackend::Cuda => enumerate_cuda_devices(),
239        GpuBackend::Metal => enumerate_metal_devices(),
240        GpuBackend::Vulkan => enumerate_vulkan_devices(),
241        GpuBackend::OpenCl => enumerate_opencl_devices(),
242        GpuBackend::Rocm => enumerate_rocm_devices(),
243        GpuBackend::DirectMl => enumerate_directml_devices(),
244        GpuBackend::WebGpu => enumerate_webgpu_devices(),
245    }
246}
247
248/// Selects the best available GPU device
249///
250/// # Errors
251/// Returns an error if no GPU is available
252pub fn select_device(config: &GpuConfig) -> Result<GpuDevice> {
253    let devices = list_devices()?;
254
255    if devices.is_empty() {
256        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
257            message: "No GPU devices found".to_string(),
258        }));
259    }
260
261    // Filter by backend if specified
262    let filtered: Vec<_> = if let Some(backend) = config.backend {
263        devices
264            .into_iter()
265            .filter(|d| d.backend == backend)
266            .collect()
267    } else {
268        devices
269    };
270
271    if filtered.is_empty() {
272        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
273            message: "No GPU devices match the specified backend".to_string(),
274        }));
275    }
276
277    // Select by device ID or pick the one with most free memory
278    let device = if let Some(id) = config.device_id {
279        filtered.into_iter().find(|d| d.id == id).ok_or_else(|| {
280            MlError::Inference(InferenceError::GpuNotAvailable {
281                message: format!("Device ID {} not found", id),
282            })
283        })?
284    } else {
285        filtered
286            .into_iter()
287            .max_by_key(|d| d.free_memory)
288            .ok_or_else(|| {
289                MlError::Inference(InferenceError::GpuNotAvailable {
290                    message: "Failed to select GPU device".to_string(),
291                })
292            })?
293    };
294
295    info!(
296        "Selected GPU: {} (free memory: {} MB)",
297        device.name,
298        device.free_memory / (1024 * 1024)
299    );
300
301    Ok(device)
302}
303
304// ============================================================================
305// CUDA Backend Implementation
306// ============================================================================
307
/// Checks if CUDA is available on the system
///
/// Always `false` when the crate is built without the `cuda` feature.
fn check_cuda_available() -> bool {
    #[cfg(feature = "cuda")]
    let available = cuda_check_runtime();
    #[cfg(not(feature = "cuda"))]
    let available = false;
    available
}
317
/// Checks for CUDA runtime library
///
/// Returns `true` if the NVIDIA driver library can be dynamically loaded.
/// This only proves the driver is installed, not that a usable device exists.
#[cfg(feature = "cuda")]
fn cuda_check_runtime() -> bool {
    use libloading::Library;

    // Try to load CUDA runtime library; candidate names/paths per platform,
    // probed in order.
    let lib_names = if cfg!(target_os = "windows") {
        vec!["nvcuda.dll", "cudart64_110.dll", "cudart64_12.dll"]
    } else if cfg!(target_os = "macos") {
        vec!["libcuda.dylib", "/usr/local/cuda/lib/libcuda.dylib"]
    } else {
        vec![
            "libcuda.so",
            "libcuda.so.1",
            "/usr/lib/x86_64-linux-gnu/libcuda.so",
            "/usr/local/cuda/lib64/libcuda.so",
        ]
    };

    for lib_name in lib_names {
        // SAFETY: loading a shared library runs its initialization routines;
        // we trust the system-installed CUDA driver. The handle is dropped
        // immediately after the check, unloading the library again.
        if unsafe { Library::new(lib_name) }.is_ok() {
            debug!("Found CUDA runtime library: {}", lib_name);
            return true;
        }
    }

    debug!("CUDA runtime library not found");
    false
}
347
348/// Enumerates CUDA devices
349fn enumerate_cuda_devices() -> Result<Vec<GpuDevice>> {
350    #[cfg(feature = "cuda")]
351    {
352        cuda_enumerate_devices_impl()
353    }
354    #[cfg(not(feature = "cuda"))]
355    {
356        Ok(Vec::new())
357    }
358}
359
/// Enumerates CUDA devices via the driver API loaded at runtime.
///
/// Dynamically loads `libcuda`, initializes the driver, then queries each
/// device's name, total memory, and compute capability.
///
/// # Errors
/// Returns `GpuNotAvailable` if the library cannot be loaded, a required
/// symbol is missing, or driver initialization fails.
#[cfg(feature = "cuda")]
fn cuda_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use libloading::{Library, Symbol};
    use std::ffi::{c_char, c_int};

    // CUDA driver API type definitions
    type CUdevice = c_int;
    type CUresult = c_int;

    // CUDA device attribute IDs (values from cuda.h: CUdevice_attribute)
    const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: c_int = 75;
    const CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: c_int = 76;

    // Try to load CUDA runtime
    let lib_names = if cfg!(target_os = "windows") {
        vec!["nvcuda.dll"]
    } else if cfg!(target_os = "macos") {
        vec!["libcuda.dylib"]
    } else {
        vec!["libcuda.so.1", "libcuda.so"]
    };

    // SAFETY: loading the system CUDA driver library; its initializers are
    // trusted. `lib` must outlive every `Symbol` obtained from it below.
    let lib = lib_names
        .iter()
        .find_map(|name| unsafe { Library::new(*name).ok() })
        .ok_or_else(|| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: "CUDA library not found".to_string(),
            })
        })?;

    // Load cuInit function
    // SAFETY: the declared signature matches the CUDA driver API.
    let cu_init: Symbol<unsafe extern "C" fn(c_int) -> CUresult> = unsafe { lib.get(b"cuInit\0") }
        .map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuInit: {}", e),
            })
        })?;

    // Initialize CUDA. SAFETY: cuInit(0) must be called before any other
    // driver API function; the flags argument must be 0 per the API contract.
    let result = unsafe { cu_init(0) };
    if result != 0 {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("CUDA initialization failed with code: {}", result),
        }));
    }

    // Load cuDeviceGetCount function
    let cu_device_get_count: Symbol<unsafe extern "C" fn(*mut c_int) -> CUresult> =
        unsafe { lib.get(b"cuDeviceGetCount\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuDeviceGetCount: {}", e),
            })
        })?;

    // Get device count. SAFETY: `count` is a valid, writable out-pointer.
    let mut count: c_int = 0;
    let result = unsafe { cu_device_get_count(&mut count) };
    if result != 0 {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to get CUDA device count: {}", result),
        }));
    }

    debug!("Found {} CUDA device(s)", count);

    // Load additional CUDA functions for device property queries
    let cu_device_get: Symbol<unsafe extern "C" fn(*mut CUdevice, c_int) -> CUresult> =
        unsafe { lib.get(b"cuDeviceGet\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuDeviceGet: {}", e),
            })
        })?;

    let cu_device_get_name: Symbol<unsafe extern "C" fn(*mut c_char, c_int, CUdevice) -> CUresult> =
        unsafe { lib.get(b"cuDeviceGetName\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load cuDeviceGetName: {}", e),
            })
        })?;

    // Prefer the _v2 symbol (64-bit size_t); fall back to the legacy name.
    let cu_device_total_mem: Symbol<unsafe extern "C" fn(*mut usize, CUdevice) -> CUresult> =
        unsafe { lib.get(b"cuDeviceTotalMem_v2\0") }
            .or_else(|_| unsafe { lib.get(b"cuDeviceTotalMem\0") })
            .map_err(|e| {
                MlError::Inference(InferenceError::GpuNotAvailable {
                    message: format!("Failed to load cuDeviceTotalMem: {}", e),
                })
            })?;

    let cu_device_get_attribute: Symbol<
        unsafe extern "C" fn(*mut c_int, c_int, CUdevice) -> CUresult,
    > = unsafe { lib.get(b"cuDeviceGetAttribute\0") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load cuDeviceGetAttribute: {}", e),
        })
    })?;

    // Enumerate devices and query their properties
    let mut devices = Vec::new();
    for i in 0..count {
        // Get device handle. SAFETY: `device` is a valid out-pointer.
        let mut device: CUdevice = 0;
        let result = unsafe { cu_device_get(&mut device, i) };
        if result != 0 {
            debug!("Failed to get CUDA device {}: error code {}", i, result);
            continue;
        }

        // Query device name (NUL-terminated C string). The buffer element
        // type must be `c_char`, not a hard-coded `i8`: `c_char` is `u8` on
        // aarch64 (Jetson/Grace), where `[0i8; 256].as_mut_ptr()` would not
        // match the `*mut c_char` FFI signature and fail to compile.
        let mut name_buf = [0 as c_char; 256];
        let result = unsafe { cu_device_get_name(name_buf.as_mut_ptr(), 256, device) };
        let device_name = if result == 0 {
            // SAFETY: on success the driver wrote a NUL-terminated string into
            // `name_buf` (we passed its length, 256), so the read is in-bounds.
            unsafe {
                std::ffi::CStr::from_ptr(name_buf.as_ptr())
                    .to_string_lossy()
                    .into_owned()
            }
        } else {
            format!("NVIDIA CUDA Device {}", i)
        };

        // Query total memory. SAFETY: `total_mem` is a valid out-pointer.
        let mut total_mem: usize = 0;
        let result = unsafe { cu_device_total_mem(&mut total_mem, device) };
        if result != 0 {
            debug!(
                "Failed to get total memory for device {}: error code {}",
                i, result
            );
            total_mem = 0;
        }

        // Note: cuMemGetInfo requires a CUDA context to be active, which is not
        // available during enumeration. We'll report total memory and assume
        // most is free for simplicity. Applications should check actual free
        // memory after context creation.
        let free_mem = (total_mem as f64 * 0.95) as usize; // Estimate 95% available

        // Query compute capability (major.minor).
        // SAFETY: `major`/`minor` are valid out-pointers for the calls.
        let mut major: c_int = 0;
        let mut minor: c_int = 0;
        let result_major = unsafe {
            cu_device_get_attribute(
                &mut major,
                CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
                device,
            )
        };
        let result_minor = unsafe {
            cu_device_get_attribute(
                &mut minor,
                CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
                device,
            )
        };

        let compute_capability = if result_major == 0 && result_minor == 0 {
            format!("{}.{}", major, minor)
        } else {
            "Unknown".to_string()
        };

        devices.push(GpuDevice {
            id: i as usize,
            name: device_name,
            total_memory: total_mem,
            free_memory: free_mem,
            compute_capability,
            backend: GpuBackend::Cuda,
        });
    }

    Ok(devices)
}
535
536// ============================================================================
537// Metal Backend Implementation (macOS)
538// ============================================================================
539
/// Checks if Metal is available on the system
///
/// Always `false` unless built with the `metal` feature on macOS.
fn check_metal_available() -> bool {
    #[cfg(all(feature = "metal", target_os = "macos"))]
    let available = metal_check_available();
    #[cfg(not(all(feature = "metal", target_os = "macos")))]
    let available = false;
    available
}
549
/// Reports whether the system exposes at least one Metal device.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn metal_check_available() -> bool {
    // Metal is always available on macOS 10.11+
    // We can verify by trying to create a device
    !metal::Device::all().is_empty()
}
556
557/// Enumerates Metal devices
558fn enumerate_metal_devices() -> Result<Vec<GpuDevice>> {
559    #[cfg(all(feature = "metal", target_os = "macos"))]
560    {
561        metal_enumerate_devices_impl()
562    }
563    #[cfg(not(all(feature = "metal", target_os = "macos")))]
564    {
565        Ok(Vec::new())
566    }
567}
568
/// Builds a [`GpuDevice`] entry for every Metal device the system reports.
///
/// # Errors
/// Returns `GpuNotAvailable` when no Metal device exists.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn metal_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    let all = metal::Device::all();

    if all.is_empty() {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: "No Metal devices found".to_string(),
        }));
    }

    let gpu_devices: Vec<GpuDevice> = all
        .iter()
        .enumerate()
        .map(|(idx, dev)| {
            let label = dev.name().to_string();
            let budget = dev.recommended_max_working_set_size() as usize;
            GpuDevice {
                id: idx,
                name: label,
                total_memory: budget,
                // Metal doesn't provide free memory directly; report the
                // whole recommended working-set budget.
                free_memory: budget,
                compute_capability: "Metal".to_string(),
                backend: GpuBackend::Metal,
            }
        })
        .collect();

    debug!("Found {} Metal device(s)", gpu_devices.len());
    Ok(gpu_devices)
}
599
600// ============================================================================
601// Vulkan Backend Implementation
602// ============================================================================
603
/// Checks if Vulkan is available on the system
///
/// Always `false` when the crate is built without the `vulkan` feature.
fn check_vulkan_available() -> bool {
    #[cfg(feature = "vulkan")]
    let available = vulkan_check_available();
    #[cfg(not(feature = "vulkan"))]
    let available = false;
    available
}
613
/// Probes Vulkan by loading the loader and creating a throwaway instance.
///
/// Instance creation (rather than mere library presence) is checked because
/// the loader may be installed without a usable driver behind it.
#[cfg(feature = "vulkan")]
fn vulkan_check_available() -> bool {
    use ash::{Entry, vk};

    // Try to load Vulkan
    // SAFETY: `Entry::load` dlopens the system Vulkan loader; we trust its
    // initialization routines.
    if let Ok(entry) = unsafe { Entry::load() } {
        // Try to create instance (minimal, targeting Vulkan 1.0)
        let app_info = vk::ApplicationInfo::default().api_version(vk::make_api_version(0, 1, 0, 0));

        let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);

        // SAFETY: `create_info` borrows `app_info`; both are live for the
        // call. The instance is destroyed before returning, so nothing leaks.
        if let Ok(instance) = unsafe { entry.create_instance(&create_info, None) } {
            unsafe { instance.destroy_instance(None) };
            return true;
        }
    }

    false
}
633
634/// Enumerates Vulkan devices
635fn enumerate_vulkan_devices() -> Result<Vec<GpuDevice>> {
636    #[cfg(feature = "vulkan")]
637    {
638        vulkan_enumerate_devices_impl()
639    }
640    #[cfg(not(feature = "vulkan"))]
641    {
642        Ok(Vec::new())
643    }
644}
645
/// Enumerates Vulkan physical devices and sums their device-local memory.
///
/// Creates a temporary Vulkan instance, queries every physical device, and
/// destroys the instance before returning on both success and error paths.
#[cfg(feature = "vulkan")]
fn vulkan_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use ash::{Entry, vk};

    // SAFETY: `Entry::load` dlopens the system Vulkan loader.
    let entry = unsafe { Entry::load() }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load Vulkan: {}", e),
        })
    })?;

    // Minimal application info targeting Vulkan 1.0 for widest compatibility.
    let app_info = vk::ApplicationInfo::default().api_version(vk::make_api_version(0, 1, 0, 0));

    let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);

    // SAFETY: `create_info` borrows `app_info`; both stay live for the call.
    let instance = unsafe { entry.create_instance(&create_info, None) }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to create Vulkan instance: {}", e),
        })
    })?;

    // SAFETY: `instance` is valid. On the error path the instance is
    // destroyed inside the closure so it does not leak before `?` returns.
    let physical_devices = unsafe { instance.enumerate_physical_devices() }.map_err(|e| {
        unsafe { instance.destroy_instance(None) };
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to enumerate Vulkan devices: {}", e),
        })
    })?;

    let mut devices = Vec::new();
    for (idx, physical_device) in physical_devices.iter().enumerate() {
        // SAFETY: `physical_device` was returned by this still-live instance.
        let properties = unsafe { instance.get_physical_device_properties(*physical_device) };
        let memory_properties =
            unsafe { instance.get_physical_device_memory_properties(*physical_device) };

        // SAFETY: `device_name` is a fixed-size, NUL-terminated C array
        // embedded in `properties`, so the pointer read stays in-bounds.
        let device_name = unsafe {
            std::ffi::CStr::from_ptr(properties.device_name.as_ptr())
                .to_string_lossy()
                .into_owned()
        };

        // Calculate total memory from memory heaps: only the first
        // `memory_heap_count` entries of the fixed array are valid, and only
        // DEVICE_LOCAL heaps count as GPU memory.
        let total_memory = memory_properties
            .memory_heaps
            .iter()
            .take(memory_properties.memory_heap_count as usize)
            .filter(|heap| heap.flags.contains(vk::MemoryHeapFlags::DEVICE_LOCAL))
            .map(|heap| heap.size as usize)
            .sum();

        devices.push(GpuDevice {
            id: idx,
            name: device_name,
            total_memory,
            free_memory: total_memory, // Vulkan doesn't provide free memory directly
            compute_capability: format!(
                "Vulkan {}.{}",
                vk::api_version_major(properties.api_version),
                vk::api_version_minor(properties.api_version)
            ),
            backend: GpuBackend::Vulkan,
        });
    }

    // SAFETY: all queries above are complete; the instance is no longer used.
    unsafe { instance.destroy_instance(None) };
    debug!("Found {} Vulkan device(s)", devices.len());
    Ok(devices)
}
712
713// ============================================================================
714// OpenCL Backend Implementation
715// ============================================================================
716
/// Checks if OpenCL is available on the system
///
/// Always `false` when the crate is built without the `opencl` feature.
fn check_opencl_available() -> bool {
    #[cfg(feature = "opencl")]
    let available = opencl_check_available();
    #[cfg(not(feature = "opencl"))]
    let available = false;
    available
}
726
/// Reports whether at least one OpenCL platform is installed.
#[cfg(feature = "opencl")]
fn opencl_check_available() -> bool {
    use opencl3::platform::get_platforms;

    // Treat a platform-query failure the same as "no platforms".
    get_platforms().map_or(false, |platforms| !platforms.is_empty())
}
736
737/// Enumerates OpenCL devices
738fn enumerate_opencl_devices() -> Result<Vec<GpuDevice>> {
739    #[cfg(feature = "opencl")]
740    {
741        opencl_enumerate_devices_impl()
742    }
743    #[cfg(not(feature = "opencl"))]
744    {
745        Ok(Vec::new())
746    }
747}
748
/// Builds a [`GpuDevice`] entry for every OpenCL GPU-class device.
///
/// # Errors
/// Returns `GpuNotAvailable` if device enumeration or any property query
/// fails (the first failure short-circuits).
#[cfg(feature = "opencl")]
fn opencl_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use opencl3::device::{CL_DEVICE_TYPE_GPU, Device, get_all_devices};

    let device_ids = get_all_devices(CL_DEVICE_TYPE_GPU).map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to get OpenCL devices: {:?}", e),
        })
    })?;

    let devices = device_ids
        .iter()
        .enumerate()
        .map(|(idx, raw_id)| {
            let dev = Device::new(*raw_id);

            let name = dev.name().map_err(|e| {
                MlError::Inference(InferenceError::GpuNotAvailable {
                    message: format!("Failed to get device name: {:?}", e),
                })
            })?;

            let total_memory = dev.global_mem_size().map_err(|e| {
                MlError::Inference(InferenceError::GpuNotAvailable {
                    message: format!("Failed to get device memory: {:?}", e),
                })
            })? as usize;

            Ok(GpuDevice {
                id: idx,
                name,
                total_memory,
                // OpenCL doesn't provide free memory directly; report total.
                free_memory: total_memory,
                compute_capability: "OpenCL".to_string(),
                backend: GpuBackend::OpenCl,
            })
        })
        .collect::<Result<Vec<_>>>()?;

    debug!("Found {} OpenCL device(s)", devices.len());
    Ok(devices)
}
788
789// ============================================================================
790// ROCm Backend Implementation
791// ============================================================================
792
/// Checks if ROCm is available on the system
///
/// Always `false` when the crate is built without the `rocm` feature.
fn check_rocm_available() -> bool {
    #[cfg(feature = "rocm")]
    let available = rocm_check_runtime();
    #[cfg(not(feature = "rocm"))]
    let available = false;
    available
}
802
/// Checks whether the ROCm HIP runtime library can be dynamically loaded.
///
/// Library presence only proves ROCm is installed, not that a device exists.
/// ROCm is Linux-only, so other platforms probe an empty list.
#[cfg(feature = "rocm")]
fn rocm_check_runtime() -> bool {
    use libloading::Library;

    // Try to load ROCm runtime library; candidate names/paths probed in order.
    let lib_names = if cfg!(target_os = "linux") {
        vec![
            "libamdhip64.so",
            "libamdhip64.so.6",
            "/opt/rocm/lib/libamdhip64.so",
        ]
    } else {
        vec![]
    };

    for lib_name in lib_names {
        // SAFETY: loading a shared library runs its initialization routines;
        // we trust the system-installed ROCm runtime. The handle is dropped
        // immediately after the check, unloading the library again.
        if unsafe { Library::new(lib_name) }.is_ok() {
            debug!("Found ROCm runtime library: {}", lib_name);
            return true;
        }
    }

    debug!("ROCm runtime library not found");
    false
}
828
829/// Enumerates ROCm devices
830fn enumerate_rocm_devices() -> Result<Vec<GpuDevice>> {
831    #[cfg(feature = "rocm")]
832    {
833        rocm_enumerate_devices_impl()
834    }
835    #[cfg(not(feature = "rocm"))]
836    {
837        Ok(Vec::new())
838    }
839}
840
/// Enumerates ROCm devices via the HIP runtime loaded at runtime.
///
/// Dynamically loads the HIP library, then queries each device's properties.
/// Free memory is queried via `hipMemGetInfo` when possible and otherwise
/// estimated at 95% of total.
#[cfg(feature = "rocm")]
fn rocm_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use libloading::{Library, Symbol};
    use std::ffi::{c_char, c_int};

    // HIP type definitions
    type HipError = c_int;
    const HIP_SUCCESS: HipError = 0;

    // Simplified hipDeviceProp_t structure (only fields we need)
    // NOTE(review): this is a partial mirror of the HIP struct; the actual
    // memory layout may vary by HIP version — verify against the installed
    // hip_runtime_api.h before relying on trailing fields like gcnArchName.
    #[repr(C)]
    #[allow(non_snake_case)]
    struct HipDeviceProp {
        name: [c_char; 256],
        totalGlobalMem: usize,
        sharedMemPerBlock: usize,
        regsPerBlock: c_int,
        warpSize: c_int,
        memPitch: usize,
        maxThreadsPerBlock: c_int,
        maxThreadsDim: [c_int; 3],
        maxGridSize: [c_int; 3],
        clockRate: c_int,
        totalConstMem: usize,
        major: c_int,
        minor: c_int,
        textureAlignment: usize,
        deviceOverlap: c_int,
        multiProcessorCount: c_int,
        kernelExecTimeoutEnabled: c_int,
        integrated: c_int,
        canMapHostMemory: c_int,
        computeMode: c_int,
        maxTexture1D: c_int,
        maxTexture2D: [c_int; 2],
        maxTexture3D: [c_int; 3],
        // Note: Full structure has more fields, but we only need these
        // The actual memory layout may vary by HIP version
        gcnArchName: [c_char; 256],
    }

    // Try to load ROCm HIP runtime library (Linux-only).
    let lib_names = if cfg!(target_os = "linux") {
        vec![
            "libamdhip64.so",
            "libamdhip64.so.6",
            "/opt/rocm/lib/libamdhip64.so",
        ]
    } else {
        vec![]
    };

    // SAFETY: loading the system HIP runtime library; its initializers are
    // trusted. `lib` must outlive every `Symbol` obtained from it below.
    let lib = lib_names
        .iter()
        .find_map(|name| unsafe { Library::new(*name).ok() })
        .ok_or_else(|| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: "ROCm HIP library not found".to_string(),
            })
        })?;

    // Load hipGetDeviceCount function
    // SAFETY: the declared signature matches the HIP runtime API.
    let hip_get_device_count: Symbol<unsafe extern "C" fn(*mut c_int) -> HipError> =
        unsafe { lib.get(b"hipGetDeviceCount\0") }.map_err(|e| {
            MlError::Inference(InferenceError::GpuNotAvailable {
                message: format!("Failed to load hipGetDeviceCount: {}", e),
            })
        })?;

    // Get device count. SAFETY: `count` is a valid, writable out-pointer.
    let mut count: c_int = 0;
    let result = unsafe { hip_get_device_count(&mut count) };
    if result != HIP_SUCCESS {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("hipGetDeviceCount failed with error code: {}", result),
        }));
    }

    debug!("Found {} ROCm device(s)", count);

    // Load hipGetDeviceProperties function
    let hip_get_device_properties: Symbol<
        unsafe extern "C" fn(*mut HipDeviceProp, c_int) -> HipError,
    > = unsafe { lib.get(b"hipGetDeviceProperties\0") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load hipGetDeviceProperties: {}", e),
        })
    })?;

    // Load hipMemGetInfo function (optional - may not be available without active context)
    let hip_mem_get_info: Option<Symbol<unsafe extern "C" fn(*mut usize, *mut usize) -> HipError>> =
        unsafe { lib.get(b"hipMemGetInfo\0").ok() };

    // Load hipSetDevice to set device context (optional)
    let hip_set_device: Option<Symbol<unsafe extern "C" fn(c_int) -> HipError>> =
        unsafe { lib.get(b"hipSetDevice\0").ok() };

    // Enumerate devices and query their properties
    let mut devices = Vec::new();
    for i in 0..count {
        // Query device properties.
        // SAFETY: HipDeviceProp is plain C data for which all-zero bytes are
        // a valid value, so `zeroed()` is sound; `props` is a valid
        // out-pointer for the call.
        let mut props: HipDeviceProp = unsafe { std::mem::zeroed() };
        let result = unsafe { hip_get_device_properties(&mut props, i) };
        if result != HIP_SUCCESS {
            debug!(
                "Failed to get properties for ROCm device {}: error code {}",
                i, result
            );
            continue;
        }

        // Extract device name.
        // SAFETY: `props.name` is a fixed-size array the runtime fills with a
        // NUL-terminated string, so the pointer read stays in-bounds.
        let device_name = unsafe {
            std::ffi::CStr::from_ptr(props.name.as_ptr())
                .to_string_lossy()
                .into_owned()
        };

        let total_memory = props.totalGlobalMem;

        // Try to get free memory if possible; otherwise fall back to a 95%
        // estimate of total (matching the CUDA enumeration path).
        let free_memory = if let (Some(set_dev), Some(get_info)) =
            (hip_set_device.as_ref(), hip_mem_get_info.as_ref())
        {
            // Try to set device context.
            // SAFETY: `i` is a valid device ordinal (< count).
            let set_result = unsafe { set_dev(i) };
            if set_result == HIP_SUCCESS {
                let mut free: usize = 0;
                let mut total: usize = 0;
                // SAFETY: `free`/`total` are valid, writable out-pointers.
                let result = unsafe { get_info(&mut free, &mut total) };
                if result == HIP_SUCCESS {
                    free
                } else {
                    // Estimate 95% available if query fails
                    (total_memory as f64 * 0.95) as usize
                }
            } else {
                // Estimate 95% available if context creation fails
                (total_memory as f64 * 0.95) as usize
            }
        } else {
            // Estimate 95% available if functions not available
            (total_memory as f64 * 0.95) as usize
        };

        // Extract GCN architecture name (compute capability equivalent);
        // fall back to major.minor when the arch string is empty.
        // SAFETY: `props.gcnArchName` is a fixed-size, NUL-terminated array.
        let compute_capability = unsafe {
            let gcn_name = std::ffi::CStr::from_ptr(props.gcnArchName.as_ptr())
                .to_string_lossy()
                .into_owned();
            if gcn_name.is_empty() {
                format!("{}.{}", props.major, props.minor)
            } else {
                gcn_name
            }
        };

        devices.push(GpuDevice {
            id: i as usize,
            name: device_name,
            total_memory,
            free_memory,
            compute_capability,
            backend: GpuBackend::Rocm,
        });
    }

    Ok(devices)
}
1010
1011// ============================================================================
1012// DirectML Backend Implementation (Windows)
1013// ============================================================================
1014
/// Reports whether DirectML can be used on the current system.
///
/// Always `false` unless the crate is built with the `directml` feature on
/// Windows; in that configuration the DirectML runtime is probed at call time.
fn check_directml_available() -> bool {
    #[cfg(not(all(feature = "directml", target_os = "windows")))]
    {
        false
    }
    #[cfg(all(feature = "directml", target_os = "windows"))]
    {
        directml_check_runtime()
    }
}
1024
/// Probes the system for a usable DirectML runtime.
///
/// Two checks are performed:
/// 1. The OS build must be Windows 10 1903+ (build 18362), queried through
///    `RtlGetVersion` — `GetVersionEx` is deprecated and may report shimmed
///    values. If the version cannot be queried, the check is skipped
///    (best-effort, same as before).
/// 2. `DirectML.dll` (or its debug variant) must load and export the
///    device-creation entry points.
#[cfg(all(feature = "directml", target_os = "windows"))]
fn directml_check_runtime() -> bool {
    use libloading::Library;

    // Check Windows version - DirectML requires Windows 10 1903+ (build 18362)
    // Using RtlGetVersion as GetVersionEx is deprecated and may return incorrect values.
    // Note: the outer cfg already guarantees target_os = "windows", so no inner
    // #[cfg(windows)] gate is needed; a plain scope keeps the FFI types local.
    {
        use std::mem;

        // Layout mirrors the Win32 OSVERSIONINFOEXW structure field-for-field.
        #[repr(C)]
        struct OsVersionInfoExW {
            dw_os_version_info_size: u32,
            dw_major_version: u32,
            dw_minor_version: u32,
            dw_build_number: u32,
            dw_platform_id: u32,
            sz_csd_version: [u16; 128],
            w_service_pack_major: u16,
            w_service_pack_minor: u16,
            w_suite_mask: u16,
            w_product_type: u8,
            w_reserved: u8,
        }

        // Try to load ntdll.dll and get version info
        // SAFETY: loading a well-known system DLL; symbol lookup is checked.
        if let Ok(ntdll) = unsafe { Library::new("ntdll.dll") } {
            type RtlGetVersion = unsafe extern "system" fn(*mut OsVersionInfoExW) -> i32;
            if let Ok(rtl_get_version) = unsafe { ntdll.get::<RtlGetVersion>(b"RtlGetVersion\0") } {
                // SAFETY: OSVERSIONINFOEXW is plain-old-data, so all-zero bytes are valid.
                let mut version_info: OsVersionInfoExW = unsafe { mem::zeroed() };
                version_info.dw_os_version_info_size = mem::size_of::<OsVersionInfoExW>() as u32;

                // SAFETY: the pointer is valid for writes and the size field is set.
                let status = unsafe { rtl_get_version(&mut version_info) };
                if status == 0 {
                    // Accept Windows 10 build 18362+ or any newer major version.
                    let is_compatible = version_info.dw_major_version > 10
                        || (version_info.dw_major_version == 10
                            && version_info.dw_build_number >= 18362);

                    if !is_compatible {
                        debug!(
                            "Windows version {}.{}.{} is too old for DirectML (requires 10.0.18362+)",
                            version_info.dw_major_version,
                            version_info.dw_minor_version,
                            version_info.dw_build_number
                        );
                        return false;
                    }
                }
            }
        }
    }

    // Try to load DirectML.dll (release runtime first, then the debug build).
    // A fixed-size array avoids the needless heap allocation of `vec!`.
    let directml_names = ["DirectML.dll", "DirectML.Debug.dll"];

    for lib_name in directml_names {
        // SAFETY: loading a DLL by name; failure is handled through the Result.
        if let Ok(lib) = unsafe { Library::new(lib_name) } {
            // Verify the library actually exports the device-creation entry
            // points (DirectML 1.0 and 1.1 respectively) before reporting success.
            let has_v1_0 = unsafe { lib.get::<*const ()>(b"DMLCreateDevice\0") }.is_ok();
            let has_v1_1 = unsafe { lib.get::<*const ()>(b"DMLCreateDevice1\0") }.is_ok();

            if has_v1_0 || has_v1_1 {
                debug!("Found DirectML library: {}", lib_name);
                return true;
            } else {
                debug!(
                    "DirectML library {} found but missing required exports",
                    lib_name
                );
            }
        }
    }

    debug!("DirectML library not found");
    false
}
1102
1103/// Enumerates DirectML devices
1104fn enumerate_directml_devices() -> Result<Vec<GpuDevice>> {
1105    #[cfg(all(feature = "directml", target_os = "windows"))]
1106    {
1107        directml_enumerate_devices_impl()
1108    }
1109    #[cfg(not(all(feature = "directml", target_os = "windows")))]
1110    {
1111        Ok(Vec::new())
1112    }
1113}
1114
/// Enumerates DirectML-capable adapters via raw DXGI COM calls.
///
/// Loads `dxgi.dll`, creates an `IDXGIFactory1`, walks the adapter list with
/// `EnumAdapters1`, and reports every adapter that has dedicated video memory
/// (i.e. discrete GPUs; integrated/software adapters are skipped).
///
/// Free memory is estimated as 95% of dedicated memory since DXGI does not
/// expose a free-memory query at this interface level.
///
/// # Errors
///
/// Returns `GpuNotAvailable` if `dxgi.dll` or `CreateDXGIFactory1` cannot be
/// loaded, or if factory creation fails.
#[cfg(all(feature = "directml", target_os = "windows"))]
fn directml_enumerate_devices_impl() -> Result<Vec<GpuDevice>> {
    use libloading::{Library, Symbol};
    use std::ffi::c_void;

    // COM GUID structure
    #[repr(C)]
    struct Guid {
        data1: u32,
        data2: u16,
        data3: u16,
        data4: [u8; 8],
    }

    // DXGI_ADAPTER_DESC1 structure (simplified; field order and sizes must
    // match the Windows SDK definition exactly, since it is filled by GetDesc1)
    #[repr(C)]
    struct DxgiAdapterDesc1 {
        description: [u16; 128],
        vendor_id: u32,
        device_id: u32,
        sub_sys_id: u32,
        revision: u32,
        dedicated_video_memory: usize,
        dedicated_system_memory: usize,
        shared_system_memory: usize,
        adapter_luid: [u8; 8],
        flags: u32,
    }

    // HRESULT type
    type HResult = i32;
    const S_OK: HResult = 0;
    const DXGI_ERROR_NOT_FOUND: HResult = -2005270526; // 0x887A0002

    // IID_IDXGIFactory1 {770AAE78-F26F-4DBA-A829-253C83D1B387}
    const IID_IDXGI_FACTORY1: Guid = Guid {
        data1: 0x770aae78,
        data2: 0xf26f,
        data3: 0x4dba,
        data4: [0xa8, 0x29, 0x25, 0x3c, 0x83, 0xd1, 0xb3, 0x87],
    };

    // Try to load dxgi.dll
    // SAFETY: loading the well-known system library dxgi.dll; failure is
    // surfaced as an error rather than a panic.
    let dxgi_lib = unsafe { Library::new("dxgi.dll") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load dxgi.dll: {}", e),
        })
    })?;

    // Load CreateDXGIFactory1
    // SAFETY: the symbol name and signature match the documented dxgi export.
    let create_factory: Symbol<
        unsafe extern "system" fn(*const Guid, *mut *mut c_void) -> HResult,
    > = unsafe { dxgi_lib.get(b"CreateDXGIFactory1\0") }.map_err(|e| {
        MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to load CreateDXGIFactory1: {}", e),
        })
    })?;

    // Create DXGI factory
    let mut factory: *mut c_void = std::ptr::null_mut();
    // SAFETY: both pointers are valid; the returned factory is checked below.
    let hr = unsafe { create_factory(&IID_IDXGI_FACTORY1, &mut factory) };
    if hr != S_OK || factory.is_null() {
        return Err(MlError::Inference(InferenceError::GpuNotAvailable {
            message: format!("Failed to create DXGI factory: HRESULT = 0x{:08X}", hr),
        }));
    }

    // IDXGIFactory1 vtable offsets (simplified)
    // Slots 0-2 are IUnknown (QueryInterface/AddRef/Release);
    // EnumAdapters1 is at offset 7 in the vtable
    type EnumAdapters1Fn = unsafe extern "system" fn(*mut c_void, u32, *mut *mut c_void) -> HResult;

    // SAFETY: a COM object pointer points to its vtable pointer; the slot we
    // transmute matches the documented IDXGIFactory1 vtable layout.
    let vtable = unsafe { *(factory as *const *const c_void) };
    let enum_adapters1: EnumAdapters1Fn =
        unsafe { std::mem::transmute(*((vtable as *const *const c_void).offset(7))) };

    let mut devices = Vec::new();
    let mut adapter_index = 0u32;

    loop {
        let mut adapter: *mut c_void = std::ptr::null_mut();
        // SAFETY: factory is a live IDXGIFactory1; the out-pointer is valid.
        let hr = unsafe { enum_adapters1(factory, adapter_index, &mut adapter) };

        // DXGI_ERROR_NOT_FOUND marks the end of the adapter list.
        if hr == DXGI_ERROR_NOT_FOUND {
            break;
        }

        if hr != S_OK || adapter.is_null() {
            break;
        }

        // IDXGIAdapter1::GetDesc1 is at offset 10 in the vtable
        type GetDesc1Fn = unsafe extern "system" fn(*mut c_void, *mut DxgiAdapterDesc1) -> HResult;

        // SAFETY: same COM vtable-walking pattern, applied to the adapter.
        let adapter_vtable = unsafe { *(adapter as *const *const c_void) };
        let get_desc1: GetDesc1Fn =
            unsafe { std::mem::transmute(*((adapter_vtable as *const *const c_void).offset(10))) };

        // SAFETY: DXGI_ADAPTER_DESC1 is plain-old-data, so zeroed bytes are valid.
        let mut desc: DxgiAdapterDesc1 = unsafe { std::mem::zeroed() };
        let hr = unsafe { get_desc1(adapter, &mut desc) };

        if hr == S_OK {
            // Convert the NUL-padded UTF-16 description to a String
            let description = String::from_utf16_lossy(&desc.description)
                .trim_end_matches('\0')
                .to_string();

            // Only include adapters with dedicated video memory (discrete GPUs);
            // this also filters out software adapters, which report zero.
            if desc.dedicated_video_memory > 0 {
                devices.push(GpuDevice {
                    id: adapter_index as usize,
                    name: description,
                    total_memory: desc.dedicated_video_memory,
                    free_memory: (desc.dedicated_video_memory as f64 * 0.95) as usize,
                    compute_capability: "DirectML".to_string(),
                    backend: GpuBackend::DirectMl,
                });
            }
        }

        // Release adapter (IUnknown::Release lives at vtable slot 2)
        type ReleaseFn = unsafe extern "system" fn(*mut c_void) -> u32;
        // SAFETY: releasing the reference handed out by EnumAdapters1.
        let release: ReleaseFn = unsafe {
            let adapter_vtable = *(adapter as *const *const c_void);
            std::mem::transmute(*((adapter_vtable as *const *const c_void).offset(2)))
        };
        unsafe { release(adapter) };

        adapter_index += 1;
    }

    // Release factory (IUnknown::Release lives at vtable slot 2)
    type ReleaseFn = unsafe extern "system" fn(*mut c_void) -> u32;
    // SAFETY: releasing the reference created by CreateDXGIFactory1.
    let release: ReleaseFn =
        unsafe { std::mem::transmute(*((vtable as *const *const c_void).offset(2))) };
    unsafe { release(factory) };

    debug!("Found {} DirectML-compatible device(s)", devices.len());
    Ok(devices)
}
1257
1258// ============================================================================
1259// WebGPU Backend Implementation
1260// ============================================================================
1261
/// Reports whether WebGPU can be used.
///
/// Always `false` unless compiled for wasm32 with the `webgpu` feature; in
/// that configuration the browser's `navigator.gpu` object is inspected.
fn check_webgpu_available() -> bool {
    #[cfg(not(all(feature = "webgpu", target_arch = "wasm32")))]
    {
        false
    }
    #[cfg(all(feature = "webgpu", target_arch = "wasm32"))]
    {
        webgpu_check_available()
    }
}
1271
/// Inspects the browser environment for a functional `navigator.gpu` object.
///
/// Succeeds only if a window exists, `navigator` has a non-null `gpu`
/// property, and that object exposes `requestAdapter`.
#[cfg(all(feature = "webgpu", target_arch = "wasm32"))]
fn webgpu_check_available() -> bool {
    use js_sys::Reflect;
    use wasm_bindgen::JsValue;
    use web_sys::window;

    // A window object is required to reach `navigator.gpu` at all.
    let Some(win) = window() else {
        debug!("WebGPU: window object not available");
        return false;
    };

    // Work on the navigator as a raw JsValue so we can reflect on it.
    let navigator = win.navigator();
    let navigator_val = JsValue::from(&navigator);

    // `navigator` must expose a `gpu` property.
    let gpu_key = JsValue::from_str("gpu");
    match Reflect::has(&navigator_val, &gpu_key) {
        Ok(true) => {}
        Ok(false) => {
            debug!("WebGPU: navigator.gpu property not found");
            return false;
        }
        Err(e) => {
            debug!("WebGPU: Failed to check for gpu property: {:?}", e);
            return false;
        }
    }

    // Fetch the `gpu` object itself.
    let gpu = match Reflect::get(&navigator_val, &gpu_key) {
        Ok(g) => g,
        Err(e) => {
            debug!("WebGPU: Failed to get gpu object: {:?}", e);
            return false;
        }
    };

    // The property may exist yet hold null/undefined.
    if gpu.is_null() || gpu.is_undefined() {
        debug!("WebGPU: navigator.gpu is null or undefined");
        return false;
    }

    // Finally, the adapter-request entry point must be present.
    match Reflect::has(&gpu, &JsValue::from_str("requestAdapter")) {
        Ok(true) => {}
        Ok(false) => {
            debug!("WebGPU: requestAdapter method not found");
            return false;
        }
        Err(e) => {
            debug!("WebGPU: Failed to check for requestAdapter: {:?}", e);
            return false;
        }
    }

    debug!("WebGPU is available");
    true
}
1339
1340/// Enumerates WebGPU devices
1341fn enumerate_webgpu_devices() -> Result<Vec<GpuDevice>> {
1342    #[cfg(all(feature = "webgpu", target_arch = "wasm32"))]
1343    {
1344        // WebGPU device enumeration requires async operations which cannot be
1345        // done in a synchronous context. Instead, we return a placeholder device
1346        // if WebGPU is available. Applications should use the async WebGPU APIs
1347        // directly for proper adapter enumeration.
1348        //
1349        // For a proper implementation, applications should:
1350        // 1. Use wasm-bindgen-futures to handle Promises
1351        // 2. Call navigator.gpu.requestAdapter() with different power preferences
1352        // 3. Query adapter.info for device information
1353        // 4. Query adapter.limits for capability information
1354        //
1355        // Since this function is synchronous and WebGPU is inherently async,
1356        // we provide a simplified implementation that checks availability.
1357        if webgpu_check_available() {
1358            debug!("WebGPU is available (async enumeration not supported in sync context)");
1359            Ok(vec![GpuDevice {
1360                id: 0,
1361                name: "WebGPU Device (async enumeration required)".to_string(),
1362                total_memory: 0, // Unknown without async query
1363                free_memory: 0,  // Unknown without async query
1364                compute_capability: "WebGPU".to_string(),
1365                backend: GpuBackend::WebGpu,
1366            }])
1367        } else {
1368            Ok(Vec::new())
1369        }
1370    }
1371    #[cfg(not(all(feature = "webgpu", target_arch = "wasm32")))]
1372    {
1373        Ok(Vec::new())
1374    }
1375}
1376
/// GPU memory statistics
///
/// A plain bag of counters (all byte counts except the allocation tallies);
/// this type performs no tracking of its own — the caller maintains it.
#[derive(Debug, Clone, Default)]
pub struct GpuMemoryStats {
    /// Total memory allocated in bytes
    pub allocated: usize,
    /// Peak memory usage in bytes
    pub peak: usize,
    /// Number of allocations
    pub num_allocations: usize,
    /// Number of deallocations
    pub num_deallocations: usize,
}

impl GpuMemoryStats {
    /// Returns the current memory usage in bytes.
    #[must_use]
    pub fn current_usage(&self) -> usize {
        self.allocated
    }

    /// Returns the number of allocations that have not yet been freed.
    ///
    /// Clamps to zero if more deallocations than allocations were recorded.
    #[must_use]
    pub fn active_allocations(&self) -> usize {
        self.num_allocations
            .checked_sub(self.num_deallocations)
            .unwrap_or(0)
    }
}
1403
#[cfg(test)]
mod tests {
    //! Unit tests for GPU backend detection, enumeration, and configuration.
    //!
    //! Hardware-dependent tests are feature-gated or written defensively so
    //! they pass on machines without any GPU.

    use super::*;

    // The builder should record every explicitly-set field.
    #[test]
    fn test_gpu_config_builder() {
        let config = GpuConfig::builder()
            .backend(GpuBackend::Cuda)
            .device_id(1)
            .mixed_precision(false)
            .tensor_cores(false)
            .memory_growth(false)
            .memory_fraction(0.8)
            .build();

        assert_eq!(config.backend, Some(GpuBackend::Cuda));
        assert_eq!(config.device_id, Some(1));
        assert!(!config.mixed_precision);
        assert!(!config.tensor_cores);
        assert!(!config.memory_growth);
        assert!((config.memory_fraction - 0.8).abs() < 1e-6);
    }

    // Out-of-range memory fractions are clamped by the builder: values above
    // 1.0 become 1.0, values below the floor become 0.1.
    #[test]
    fn test_memory_fraction_clamping() {
        let config1 = GpuConfig::builder().memory_fraction(1.5).build();
        assert!((config1.memory_fraction - 1.0).abs() < 1e-6);

        let config2 = GpuConfig::builder().memory_fraction(-0.5).build();
        assert!((config2.memory_fraction - 0.1).abs() < 1e-6);
    }

    // memory_utilization() reports used memory as a percentage of total, and
    // has_sufficient_memory() compares against free memory.
    #[test]
    fn test_gpu_device_memory_utilization() {
        let device = GpuDevice {
            id: 0,
            name: "Test GPU".to_string(),
            total_memory: 8_000_000_000, // 8 GB
            free_memory: 2_000_000_000,  // 2 GB free
            compute_capability: "8.0".to_string(),
            backend: GpuBackend::Cuda,
        };

        // 6 GB used out of 8 GB = 75%
        let utilization = device.memory_utilization();
        assert!((utilization - 75.0).abs() < 1.0);

        assert!(device.has_sufficient_memory(1_000_000_000)); // 1 GB
        assert!(!device.has_sufficient_memory(3_000_000_000)); // 3 GB
    }

    // Display names must stay stable — they may appear in logs and UIs.
    #[test]
    fn test_backend_names() {
        assert_eq!(GpuBackend::Cuda.name(), "CUDA");
        assert_eq!(GpuBackend::Rocm.name(), "ROCm");
        assert_eq!(GpuBackend::DirectMl.name(), "DirectML");
        assert_eq!(GpuBackend::Vulkan.name(), "Vulkan");
        assert_eq!(GpuBackend::WebGpu.name(), "WebGPU");
        assert_eq!(GpuBackend::Metal.name(), "Metal");
        assert_eq!(GpuBackend::OpenCl.name(), "OpenCL");
    }

    // Derived accessors on the stats struct: current usage echoes `allocated`,
    // active allocations = allocations - deallocations.
    #[test]
    fn test_gpu_memory_stats() {
        let stats = GpuMemoryStats {
            allocated: 1000,
            peak: 1500,
            num_allocations: 10,
            num_deallocations: 3,
        };

        assert_eq!(stats.current_usage(), 1000);
        assert_eq!(stats.active_allocations(), 7);
    }

    #[test]
    fn test_backend_availability() {
        // Test that backend availability checks don't panic
        // (results are platform-dependent, so only smoke-test them)
        let _cuda_available = GpuBackend::Cuda.is_available();
        let _metal_available = GpuBackend::Metal.is_available();
        let _vulkan_available = GpuBackend::Vulkan.is_available();
        let _opencl_available = GpuBackend::OpenCl.is_available();
        let _rocm_available = GpuBackend::Rocm.is_available();
        let _directml_available = GpuBackend::DirectMl.is_available();
        let _webgpu_available = GpuBackend::WebGpu.is_available();
    }

    #[test]
    fn test_list_devices() {
        // Test that list_devices doesn't panic (may return empty on systems without GPUs)
        let result = list_devices();
        assert!(result.is_ok());
        let devices = result.ok().unwrap_or_default();

        // If devices are found, verify their properties
        for device in devices {
            assert!(!device.name.is_empty());
            // total_memory is unsigned, so any value is acceptable; just
            // confirm the field is readable.
            let _ = device.total_memory;
        }
    }

    // Smoke test only: selection with the default config.
    #[test]
    fn test_select_device_no_gpu() {
        let config = GpuConfig::default();
        // This may fail on systems without GPU, which is expected
        let _result = select_device(&config);
    }

    #[test]
    fn test_device_enumeration_without_features() {
        // Without features enabled, these should return empty vectors
        // (each check is compiled only when its feature is absent)
        #[cfg(not(feature = "cuda"))]
        {
            let cuda_devices = enumerate_cuda_devices();
            assert!(cuda_devices.is_ok());
            assert!(cuda_devices.ok().is_none_or(|d| d.is_empty()));
        }

        #[cfg(not(feature = "metal"))]
        {
            let metal_devices = enumerate_metal_devices();
            assert!(metal_devices.is_ok());
            assert!(metal_devices.ok().is_none_or(|d| d.is_empty()));
        }

        #[cfg(not(feature = "vulkan"))]
        {
            let vulkan_devices = enumerate_vulkan_devices();
            assert!(vulkan_devices.is_ok());
            assert!(vulkan_devices.ok().is_none_or(|d| d.is_empty()));
        }

        #[cfg(not(feature = "opencl"))]
        {
            let opencl_devices = enumerate_opencl_devices();
            assert!(opencl_devices.is_ok());
            assert!(opencl_devices.ok().is_none_or(|d| d.is_empty()));
        }
    }

    #[test]
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn test_metal_enumeration() {
        // On macOS with Metal feature, we should be able to enumerate devices
        let devices = enumerate_metal_devices();
        assert!(devices.is_ok());

        if let Ok(devs) = devices {
            if !devs.is_empty() {
                for device in devs {
                    assert!(!device.name.is_empty());
                    assert_eq!(device.backend, GpuBackend::Metal);
                }
            }
        }
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_cuda_detection() {
        // Test CUDA detection without panicking
        let available = check_cuda_available();

        // If CUDA is available, try to enumerate devices
        if available {
            let devices = enumerate_cuda_devices();
            assert!(devices.is_ok());

            if let Ok(devs) = devices {
                for device in devs {
                    assert_eq!(device.backend, GpuBackend::Cuda);
                }
            }
        }
    }

    #[test]
    #[cfg(feature = "vulkan")]
    fn test_vulkan_detection() {
        // Test Vulkan detection without panicking
        let available = check_vulkan_available();

        // If Vulkan is available, try to enumerate devices
        if available {
            let devices = enumerate_vulkan_devices();
            assert!(devices.is_ok());

            if let Ok(devs) = devices {
                for device in devs {
                    assert_eq!(device.backend, GpuBackend::Vulkan);
                }
            }
        }
    }

    #[test]
    #[cfg(feature = "opencl")]
    fn test_opencl_detection() {
        // Test OpenCL detection without panicking
        let available = check_opencl_available();

        // If OpenCL is available, try to enumerate devices
        if available {
            let devices = enumerate_opencl_devices();
            assert!(devices.is_ok());

            if let Ok(devs) = devices {
                for device in devs {
                    assert_eq!(device.backend, GpuBackend::OpenCl);
                }
            }
        }
    }

    // Selection constrained to a backend should yield a device of that backend.
    #[test]
    fn test_device_selection_with_backend_filter() {
        let devices = list_devices().ok().unwrap_or_default();

        if !devices.is_empty() {
            let first_backend = devices[0].backend;
            let config = GpuConfig::builder().backend(first_backend).build();

            let result = select_device(&config);

            if result.is_ok() {
                let device = result.ok().unwrap_or_else(|| devices[0].clone());
                assert_eq!(device.backend, first_backend);
            }
        }
    }

    // Edge case: a device reporting zero total memory must not divide by zero
    // and can never satisfy a memory request.
    #[test]
    fn test_zero_memory_device_utilization() {
        let device = GpuDevice {
            id: 0,
            name: "Zero Memory Device".to_string(),
            total_memory: 0,
            free_memory: 0,
            compute_capability: "0.0".to_string(),
            backend: GpuBackend::Cuda,
        };

        assert_eq!(device.memory_utilization(), 0.0);
        assert!(!device.has_sufficient_memory(1));
    }
}