oxify_connect_vision/
gpu.rs

1//! GPU detection and management utilities.
2//!
3//! This module provides runtime GPU detection for CUDA and CoreML,
4//! automatic fallback to CPU when GPU is unavailable, and GPU memory
5//! management helpers.
6
7use std::sync::OnceLock;
8
/// Snapshot of which GPU back-ends were detected on this machine.
///
/// Produced by [`GpuInfo::detect`], or shared process-wide via [`GpuInfo::cached`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GpuInfo {
    /// `true` when CUDA was detected.
    pub cuda_available: bool,
    /// `true` when CoreML was detected.
    pub coreml_available: bool,
    /// Total GPU memory in bytes, if it could be determined (probed for CUDA only).
    pub gpu_memory_bytes: Option<u64>,
}
19
20impl GpuInfo {
21    /// Detect GPU availability at runtime.
22    pub fn detect() -> Self {
23        Self {
24            cuda_available: Self::detect_cuda(),
25            coreml_available: Self::detect_coreml(),
26            gpu_memory_bytes: Self::detect_gpu_memory(),
27        }
28    }
29
30    /// Check if any GPU is available.
31    pub fn has_gpu(&self) -> bool {
32        self.cuda_available || self.coreml_available
33    }
34
35    /// Get recommended execution provider based on availability.
36    pub fn recommended_provider(&self) -> GpuProvider {
37        if self.cuda_available {
38            GpuProvider::Cuda
39        } else if self.coreml_available {
40            GpuProvider::CoreMl
41        } else {
42            GpuProvider::Cpu
43        }
44    }
45
46    /// Detect CUDA availability at runtime.
47    fn detect_cuda() -> bool {
48        #[cfg(feature = "cuda")]
49        {
50            // Try multiple detection methods
51
52            // Method 1: Check nvidia-smi command
53            if let Ok(output) = std::process::Command::new("nvidia-smi")
54                .arg("--query-gpu=name")
55                .arg("--format=csv,noheader")
56                .output()
57            {
58                if output.status.success() && !output.stdout.is_empty() {
59                    return true;
60                }
61            }
62
63            // Method 2: Check for CUDA library files
64            #[cfg(target_os = "linux")]
65            {
66                if std::path::Path::new("/usr/local/cuda/lib64/libcudart.so").exists()
67                    || std::path::Path::new("/usr/lib/x86_64-linux-gnu/libcudart.so").exists()
68                {
69                    return true;
70                }
71            }
72
73            #[cfg(target_os = "windows")]
74            {
75                // Check Windows CUDA paths
76                if let Ok(cuda_path) = std::env::var("CUDA_PATH") {
77                    let dll_path = std::path::Path::new(&cuda_path)
78                        .join("bin")
79                        .join("cudart64_110.dll");
80                    if dll_path.exists() {
81                        return true;
82                    }
83                }
84            }
85
86            // Method 3: Try to detect via environment variables
87            if std::env::var("CUDA_VISIBLE_DEVICES").is_ok() {
88                return true;
89            }
90
91            false
92        }
93        #[cfg(not(feature = "cuda"))]
94        {
95            false
96        }
97    }
98
99    /// Detect CoreML availability at runtime.
100    fn detect_coreml() -> bool {
101        #[cfg(feature = "coreml")]
102        {
103            #[cfg(target_os = "macos")]
104            {
105                // Check macOS version - CoreML requires 10.13+
106                if let Ok(output) = std::process::Command::new("sw_vers")
107                    .arg("-productVersion")
108                    .output()
109                {
110                    if let Ok(version_str) = String::from_utf8(output.stdout) {
111                        if let Some(major) = version_str.split('.').next() {
112                            if let Ok(major_num) = major.trim().parse::<u32>() {
113                                // macOS 10.13+ or macOS 11+ (Big Sur changed versioning)
114                                return major_num >= 11 || major_num == 10;
115                            }
116                        }
117                    }
118                }
119                // Assume available if we can't determine version
120                true
121            }
122            #[cfg(not(target_os = "macos"))]
123            {
124                false
125            }
126        }
127        #[cfg(not(feature = "coreml"))]
128        {
129            false
130        }
131    }
132
133    /// Detect available GPU memory (CUDA only).
134    fn detect_gpu_memory() -> Option<u64> {
135        #[cfg(feature = "cuda")]
136        {
137            if let Ok(output) = std::process::Command::new("nvidia-smi")
138                .arg("--query-gpu=memory.total")
139                .arg("--format=csv,noheader,nounits")
140                .output()
141            {
142                if output.status.success() {
143                    if let Ok(mem_str) = String::from_utf8(output.stdout) {
144                        if let Ok(mem_mb) = mem_str.trim().parse::<u64>() {
145                            return Some(mem_mb * 1024 * 1024); // Convert MB to bytes
146                        }
147                    }
148                }
149            }
150        }
151        None
152    }
153
154    /// Get cached GPU information (detected once per process).
155    pub fn cached() -> &'static GpuInfo {
156        static GPU_INFO: OnceLock<GpuInfo> = OnceLock::new();
157        GPU_INFO.get_or_init(GpuInfo::detect)
158    }
159}
160
/// Execution provider selected for ONNX inference.
///
/// `Cpu` is the universal fallback; `Cuda` and `CoreMl` are hardware-accelerated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GpuProvider {
    /// Plain CPU execution, used whenever no GPU provider can be used.
    Cpu,
    /// NVIDIA CUDA acceleration.
    Cuda,
    /// Apple CoreML acceleration (macOS only).
    CoreMl,
}
171
172impl GpuProvider {
173    /// Get the provider name as a string.
174    pub fn name(&self) -> &'static str {
175        match self {
176            GpuProvider::Cpu => "CPU",
177            GpuProvider::Cuda => "CUDA",
178            GpuProvider::CoreMl => "CoreML",
179        }
180    }
181
182    /// Check if this is a GPU provider (not CPU).
183    pub fn is_gpu(&self) -> bool {
184        matches!(self, GpuProvider::Cuda | GpuProvider::CoreMl)
185    }
186}
187
188/// GPU configuration for ONNX providers.
189#[derive(Debug, Clone)]
190pub struct GpuConfig {
191    /// Requested GPU provider (may fall back to CPU if unavailable)
192    pub requested: GpuProvider,
193    /// Actually used provider (after fallback check)
194    pub actual: GpuProvider,
195    /// GPU device ID (for multi-GPU systems)
196    pub device_id: u32,
197}
198
199impl GpuConfig {
200    /// Create a new GPU configuration with automatic fallback.
201    ///
202    /// If the requested GPU is not available, falls back to CPU and logs a warning.
203    pub fn new(use_gpu: bool) -> Self {
204        let gpu_info = GpuInfo::cached();
205
206        let requested = if use_gpu {
207            gpu_info.recommended_provider()
208        } else {
209            GpuProvider::Cpu
210        };
211
212        let actual = match requested {
213            GpuProvider::Cuda if !gpu_info.cuda_available => {
214                tracing::warn!("CUDA requested but not available, falling back to CPU");
215                GpuProvider::Cpu
216            }
217            GpuProvider::CoreMl if !gpu_info.coreml_available => {
218                tracing::warn!("CoreML requested but not available, falling back to CPU");
219                GpuProvider::Cpu
220            }
221            provider => provider,
222        };
223
224        Self {
225            requested,
226            actual,
227            device_id: 0,
228        }
229    }
230
231    /// Create a CPU-only configuration.
232    pub fn cpu() -> Self {
233        Self {
234            requested: GpuProvider::Cpu,
235            actual: GpuProvider::Cpu,
236            device_id: 0,
237        }
238    }
239
240    /// Create a CUDA configuration with fallback.
241    pub fn cuda(device_id: u32) -> Self {
242        let mut config = Self::new(true);
243        config.device_id = device_id;
244
245        // Force CUDA if requested
246        if config.actual == GpuProvider::Cpu && GpuInfo::cached().cuda_available {
247            config.actual = GpuProvider::Cuda;
248        }
249
250        config
251    }
252
253    /// Create a CoreML configuration with fallback.
254    pub fn coreml() -> Self {
255        let mut config = Self::new(true);
256
257        // Force CoreML if requested
258        if config.actual == GpuProvider::Cpu && GpuInfo::cached().coreml_available {
259            config.actual = GpuProvider::CoreMl;
260        }
261
262        config
263    }
264
265    /// Check if GPU was successfully configured (not fallen back to CPU).
266    pub fn is_using_gpu(&self) -> bool {
267        self.actual.is_gpu()
268    }
269
270    /// Get a description of the configuration.
271    pub fn description(&self) -> String {
272        if self.requested == self.actual {
273            format!("Using {}", self.actual.name())
274        } else {
275            format!(
276                "Requested {} but using {} (fallback)",
277                self.requested.name(),
278                self.actual.name()
279            )
280        }
281    }
282}
283
284impl Default for GpuConfig {
285    fn default() -> Self {
286        Self::new(false)
287    }
288}
289
290/// Check if sufficient GPU memory is available for a model.
291///
292/// Returns true if:
293/// - CPU is being used (no memory limit)
294/// - GPU has sufficient memory
295/// - Memory detection failed (assume sufficient)
296#[allow(dead_code)]
297pub fn check_gpu_memory(required_mb: u64) -> bool {
298    let gpu_info = GpuInfo::cached();
299
300    // CPU has no GPU memory limits
301    if !gpu_info.has_gpu() {
302        return true;
303    }
304
305    // If we can't detect memory, assume it's sufficient
306    let Some(total_bytes) = gpu_info.gpu_memory_bytes else {
307        return true;
308    };
309
310    let required_bytes = required_mb * 1024 * 1024;
311    let available_bytes = total_bytes * 8 / 10; // Use 80% as threshold
312
313    available_bytes >= required_bytes
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_info_detect() {
        let info = GpuInfo::detect();
        // Detection must not panic, and its derived answers must agree.
        assert_eq!(info.has_gpu(), info.recommended_provider().is_gpu());
    }

    #[test]
    fn test_recommended_provider_priority() {
        let both = GpuInfo {
            cuda_available: true,
            coreml_available: true,
            gpu_memory_bytes: None,
        };
        assert!(both.has_gpu());
        assert_eq!(both.recommended_provider(), GpuProvider::Cuda);

        let coreml_only = GpuInfo {
            cuda_available: false,
            ..both
        };
        assert_eq!(coreml_only.recommended_provider(), GpuProvider::CoreMl);

        let none = GpuInfo {
            coreml_available: false,
            ..coreml_only
        };
        assert!(!none.has_gpu());
        assert_eq!(none.recommended_provider(), GpuProvider::Cpu);
    }

    #[test]
    fn test_gpu_info_cached() {
        let info1 = GpuInfo::cached();
        let info2 = GpuInfo::cached();
        // Detection runs once: both calls must return the very same instance.
        assert!(std::ptr::eq(info1, info2));
    }

    #[test]
    fn test_gpu_provider_name() {
        assert_eq!(GpuProvider::Cpu.name(), "CPU");
        assert_eq!(GpuProvider::Cuda.name(), "CUDA");
        assert_eq!(GpuProvider::CoreMl.name(), "CoreML");
    }

    #[test]
    fn test_gpu_provider_is_gpu() {
        assert!(!GpuProvider::Cpu.is_gpu());
        assert!(GpuProvider::Cuda.is_gpu());
        assert!(GpuProvider::CoreMl.is_gpu());
    }

    #[test]
    fn test_gpu_config_cpu() {
        let config = GpuConfig::cpu();
        assert_eq!(config.requested, GpuProvider::Cpu);
        assert_eq!(config.actual, GpuProvider::Cpu);
        assert_eq!(config.device_id, 0);
        assert!(!config.is_using_gpu());
    }

    #[test]
    fn test_gpu_config_default() {
        let config = GpuConfig::default();
        assert_eq!(config.requested, GpuProvider::Cpu);
        assert_eq!(config.actual, GpuProvider::Cpu);
        assert!(!config.is_using_gpu());
    }

    #[test]
    fn test_gpu_config_new() {
        // CPU-only config
        let config = GpuConfig::new(false);
        assert_eq!(config.actual, GpuProvider::Cpu);

        // GPU config (may fall back to CPU)
        let config = GpuConfig::new(true);
        assert!(config.requested == config.actual || config.actual == GpuProvider::Cpu);
    }

    #[test]
    fn test_gpu_config_description() {
        assert_eq!(GpuConfig::cpu().description(), "Using CPU");

        let fallback = GpuConfig {
            requested: GpuProvider::Cuda,
            actual: GpuProvider::Cpu,
            device_id: 0,
        };
        assert_eq!(
            fallback.description(),
            "Requested CUDA but using CPU (fallback)"
        );
        assert!(!fallback.is_using_gpu());
    }

    #[test]
    fn test_check_gpu_memory() {
        // A zero-sized requirement always fits.
        assert!(check_gpu_memory(0));
        // Larger requirements must not panic; the answer depends on the machine.
        let _ = check_gpu_memory(1000);
        let _ = check_gpu_memory(100000);
    }

    #[test]
    fn test_gpu_config_cuda() {
        let config = GpuConfig::cuda(0);
        assert_eq!(config.device_id, 0);
        assert_eq!(GpuConfig::cuda(3).device_id, 3);

        // A GPU provider may only be in use if detection reported it as
        // available; otherwise the configuration runs on the CPU.
        let info = GpuInfo::cached();
        match config.actual {
            GpuProvider::Cuda => assert!(info.cuda_available),
            GpuProvider::CoreMl => assert!(info.coreml_available),
            GpuProvider::Cpu => {}
        }
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_gpu_config_coreml() {
        let config = GpuConfig::coreml();
        // Should either be CoreML or CPU (fallback)
        assert!(config.actual == GpuProvider::CoreMl || config.actual == GpuProvider::Cpu);
    }
}