Skip to main content

entrenar/finetune/
device.rs

//! Compute device detection and management
//!
//! Provides CUDA and wgpu GPU detection with automatic fallback to CPU.
4
5use std::fmt;
6
/// Compute device for training
///
/// The type is small and `Copy`, so it is passed by value. It also derives
/// `Hash` so it can be used directly as a `HashMap`/`HashSet` key (for
/// example when tracking per-device state).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// CPU-only execution
    Cpu,
    /// CUDA GPU with device ID
    Cuda {
        /// Index of the CUDA device
        device_id: usize,
    },
    /// wgpu GPU with adapter index (Vulkan/Metal/DX12)
    Wgpu {
        /// Index of the wgpu adapter
        adapter_index: u32,
    },
}
17
18impl ComputeDevice {
19    /// Auto-detect best available device
20    ///
21    /// Priority: CUDA (if ≥6GB) > wgpu (if available) > CPU
22    #[must_use]
23    pub fn auto_detect() -> Self {
24        contract_pre_device_dispatch!();
25        if Self::cuda_available() {
26            if let Some(info) = DeviceInfo::cuda_info(0) {
27                if info.memory_gb >= 6.0 {
28                    return Self::Cuda { device_id: 0 };
29                }
30            }
31        }
32        if Self::wgpu_available() {
33            return Self::Wgpu { adapter_index: 0 };
34        }
35        Self::Cpu
36    }
37
38    /// Check if CUDA is available
39    #[must_use]
40    pub fn cuda_available() -> bool {
41        // Check for CUDA via environment and nvidia-smi
42        if std::env::var("CUDA_VISIBLE_DEVICES").is_ok() {
43            return true;
44        }
45
46        // Try nvidia-smi
47        std::process::Command::new("nvidia-smi")
48            .arg("--query-gpu=name")
49            .arg("--format=csv,noheader")
50            .output()
51            .map(|o| o.status.success())
52            .unwrap_or(false)
53    }
54
55    /// Check if wgpu GPU is available
56    #[must_use]
57    pub fn wgpu_available() -> bool {
58        #[cfg(feature = "gpu")]
59        {
60            trueno::backends::gpu::GpuDevice::is_available()
61        }
62        #[cfg(not(feature = "gpu"))]
63        {
64            false
65        }
66    }
67
68    /// Check if this device is CUDA
69    #[must_use]
70    pub const fn is_cuda(&self) -> bool {
71        matches!(self, Self::Cuda { .. })
72    }
73
74    /// Check if this device is CPU
75    #[must_use]
76    pub const fn is_cpu(&self) -> bool {
77        matches!(self, Self::Cpu)
78    }
79
80    /// Check if this device is wgpu
81    #[must_use]
82    pub const fn is_wgpu(&self) -> bool {
83        matches!(self, Self::Wgpu { .. })
84    }
85
86    /// Get device ID for CUDA devices
87    #[must_use]
88    pub const fn device_id(&self) -> Option<usize> {
89        match self {
90            Self::Cuda { device_id } => Some(*device_id),
91            Self::Cpu | Self::Wgpu { .. } => None,
92        }
93    }
94
95    /// Get adapter index for wgpu devices
96    #[must_use]
97    pub const fn adapter_index(&self) -> Option<u32> {
98        match self {
99            Self::Wgpu { adapter_index } => Some(*adapter_index),
100            Self::Cpu | Self::Cuda { .. } => None,
101        }
102    }
103
104    /// Detect all available compute devices on this machine.
105    ///
106    /// Enumerates all CUDA GPUs and wgpu adapters. Returns at least
107    /// one device (CPU fallback if no GPUs found).
108    ///
109    /// Used by distributed training to discover multi-GPU configurations.
110    #[must_use]
111    pub fn detect_all_devices() -> Vec<Self> {
112        let mut devices = Vec::new();
113
114        // Enumerate CUDA devices
115        if Self::cuda_available() {
116            let cuda_count = Self::cuda_device_count();
117            for i in 0..cuda_count {
118                if let Some(info) = DeviceInfo::cuda_info(i) {
119                    if info.memory_gb >= 4.0 {
120                        devices.push(Self::Cuda { device_id: i });
121                    }
122                }
123            }
124        }
125
126        // Enumerate wgpu adapters
127        #[cfg(feature = "gpu")]
128        {
129            let wgpu_count = Self::wgpu_adapter_count();
130            for i in 0..wgpu_count {
131                devices.push(Self::Wgpu { adapter_index: i as u32 });
132            }
133        }
134
135        // CPU fallback
136        if devices.is_empty() {
137            devices.push(Self::Cpu);
138        }
139
140        devices
141    }
142
143    /// Count CUDA devices via nvidia-smi.
144    fn cuda_device_count() -> usize {
145        std::process::Command::new("nvidia-smi")
146            .args(["--query-gpu=name", "--format=csv,noheader"])
147            .output()
148            .ok()
149            .filter(|o| o.status.success())
150            .map_or(0, |o| String::from_utf8_lossy(&o.stdout).lines().count())
151    }
152
153    /// Count wgpu adapters.
154    #[cfg(feature = "gpu")]
155    fn wgpu_adapter_count() -> usize {
156        // wgpu adapter enumeration would go here
157        // For now, return 0 as wgpu multi-adapter is not yet implemented
158        0
159    }
160}
161
162impl Default for ComputeDevice {
163    fn default() -> Self {
164        Self::auto_detect()
165    }
166}
167
168impl fmt::Display for ComputeDevice {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        match self {
171            Self::Cpu => write!(f, "CPU"),
172            Self::Cuda { device_id } => write!(f, "CUDA:{device_id}"),
173            Self::Wgpu { adapter_index } => write!(f, "wgpu:{adapter_index}"),
174        }
175    }
176}
177
/// Device information
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device name
    pub name: String,
    /// Total memory in GB
    pub memory_gb: f64,
    /// CUDA compute capability (major.minor)
    pub compute_capability: Option<(u32, u32)>,
    /// Driver version
    pub driver_version: Option<String>,
}

impl DeviceInfo {
    /// Memory size (GB) reported when system RAM cannot be determined.
    const FALLBACK_SYSTEM_MEMORY_GB: f64 = 16.0;

    /// Minimum device memory (GB) considered sufficient for QLoRA.
    const QLORA_MIN_MEMORY_GB: f64 = 6.0;

    /// Minimum device memory (GB) considered sufficient for LoRA (fp16).
    const LORA_MIN_MEMORY_GB: f64 = 12.0;

    /// Get CPU info
    ///
    /// The name embeds the available parallelism (1 if it cannot be queried);
    /// memory is the system RAM as returned by the `/proc/meminfo` lookup.
    #[must_use]
    pub fn cpu_info() -> Self {
        let num_cores =
            std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get);

        Self {
            name: format!("CPU ({num_cores} cores)"),
            memory_gb: Self::system_memory_gb(),
            compute_capability: None,
            driver_version: None,
        }
    }

    /// Get CUDA device info
    ///
    /// Runs `nvidia-smi` for the given device index and parses its CSV output.
    /// Returns `None` when the command cannot be started, exits with a failure
    /// status, or prints fewer than three fields.
    #[must_use]
    pub fn cuda_info(device_id: usize) -> Option<Self> {
        let id_arg = format!("--id={device_id}");
        let output = std::process::Command::new("nvidia-smi")
            .args([
                "--query-gpu=name,memory.total,driver_version",
                "--format=csv,noheader,nounits",
                id_arg.as_str(),
            ])
            .output()
            .ok()?;

        if !output.status.success() {
            return None;
        }

        Self::parse_nvidia_smi_row(&String::from_utf8_lossy(&output.stdout))
    }

    /// Parse one `name, memory.total, driver_version` CSV row from nvidia-smi.
    ///
    /// Only the first non-empty line is used. Fields are split from the right
    /// so a device name that itself contains a comma stays intact, and each
    /// field is trimmed so the delimiter may be `,` or `, `. The memory field
    /// is interpreted as megabytes and converted to GB; an unparsable memory
    /// value is reported as `0.0` rather than failing the whole row.
    fn parse_nvidia_smi_row(stdout: &str) -> Option<Self> {
        let row = stdout.lines().map(str::trim).find(|line| !line.is_empty())?;

        // `rsplitn` yields fields right-to-left: driver, memory, then the name.
        let mut fields = row.rsplitn(3, ',').map(str::trim);
        let driver = fields.next()?;
        let memory_mb = fields.next()?;
        let name = fields.next()?;

        Some(Self {
            name: name.to_string(),
            memory_gb: memory_mb.parse::<f64>().unwrap_or(0.0) / 1024.0,
            compute_capability: None, // Would need CUDA runtime to get this
            driver_version: Some(driver.to_string()),
        })
    }

    /// Get system RAM in GB
    ///
    /// Reads `/proc/meminfo` (Linux). Falls back to 16 GB when the file is
    /// unavailable or has no parsable `MemTotal:` entry.
    fn system_memory_gb() -> f64 {
        std::fs::read_to_string("/proc/meminfo")
            .ok()
            .and_then(|content| Self::mem_total_gb(&content))
            .unwrap_or(Self::FALLBACK_SYSTEM_MEMORY_GB)
    }

    /// Extract `MemTotal` from `/proc/meminfo`-formatted text, converted from kB to GB.
    fn mem_total_gb(meminfo: &str) -> Option<f64> {
        meminfo
            .lines()
            .filter_map(|line| line.strip_prefix("MemTotal:"))
            .find_map(|rest| rest.split_whitespace().next()?.parse::<f64>().ok())
            .map(|kb| kb / 1024.0 / 1024.0)
    }

    /// Check if device has sufficient memory for QLoRA
    #[must_use]
    pub fn sufficient_for_qlora(&self) -> bool {
        self.memory_gb >= Self::QLORA_MIN_MEMORY_GB
    }

    /// Check if device has sufficient memory for LoRA (fp16)
    #[must_use]
    pub fn sufficient_for_lora(&self) -> bool {
        self.memory_gb >= Self::LORA_MIN_MEMORY_GB
    }
}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `device` is one of the values `auto_detect` can produce:
    /// CPU, CUDA device 0, or wgpu adapter 0.
    fn assert_auto_detected(device: ComputeDevice) {
        match device {
            ComputeDevice::Cpu => {
                assert_eq!(device.device_id(), None);
                assert_eq!(device.adapter_index(), None);
            }
            ComputeDevice::Cuda { device_id } => assert_eq!(device_id, 0),
            ComputeDevice::Wgpu { adapter_index } => assert_eq!(adapter_index, 0),
        }
    }

    #[test]
    fn test_compute_device_cpu() {
        let device = ComputeDevice::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_cuda());
        assert!(!device.is_wgpu());
        assert_eq!(device.device_id(), None);
        assert_eq!(device.adapter_index(), None);
        assert_eq!(device.to_string(), "CPU");
    }

    #[test]
    fn test_compute_device_cuda() {
        let device = ComputeDevice::Cuda { device_id: 0 };
        assert!(device.is_cuda());
        assert!(!device.is_cpu());
        assert!(!device.is_wgpu());
        assert_eq!(device.device_id(), Some(0));
        assert_eq!(device.adapter_index(), None);
        assert_eq!(device.to_string(), "CUDA:0");
    }

    #[test]
    fn test_compute_device_wgpu() {
        let device = ComputeDevice::Wgpu { adapter_index: 1 };
        assert!(device.is_wgpu());
        assert!(!device.is_cpu());
        assert!(!device.is_cuda());
        assert_eq!(device.adapter_index(), Some(1));
        assert_eq!(device.device_id(), None);
        assert_eq!(device.to_string(), "wgpu:1");
    }

    #[test]
    fn test_auto_detect_returns_valid_device() {
        // auto_detect may legitimately pick wgpu (not just CPU or CUDA), so
        // every variant is accepted, each with the index auto_detect uses.
        assert_auto_detected(ComputeDevice::auto_detect());
    }

    #[test]
    fn test_device_info_cpu() {
        let info = DeviceInfo::cpu_info();
        assert!(info.name.contains("CPU"));
        assert!(info.memory_gb > 0.0);
        assert!(info.compute_capability.is_none());
        assert!(info.driver_version.is_none());
    }

    #[test]
    fn test_device_default() {
        // Default delegates to auto_detect, so the same outcomes are valid.
        assert_auto_detected(ComputeDevice::default());
    }

    #[test]
    fn test_detect_all_devices() {
        let devices = ComputeDevice::detect_all_devices();
        assert!(!devices.is_empty(), "must detect at least one device");
        // CPU is only added as a fallback when no GPU was found, so it can
        // never appear alongside other devices.
        if devices.contains(&ComputeDevice::Cpu) {
            assert_eq!(devices.len(), 1, "CPU must be the sole fallback device");
        }
    }

    #[test]
    fn test_sufficient_memory_checks() {
        let small = DeviceInfo {
            name: "Small GPU".into(),
            memory_gb: 4.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(!small.sufficient_for_qlora());
        assert!(!small.sufficient_for_lora());

        let medium = DeviceInfo {
            name: "Medium GPU".into(),
            memory_gb: 8.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(medium.sufficient_for_qlora());
        assert!(!medium.sufficient_for_lora());

        let large = DeviceInfo {
            name: "Large GPU".into(),
            memory_gb: 16.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(large.sufficient_for_qlora());
        assert!(large.sufficient_for_lora());
    }
}