entrenar/finetune/
device.rs1use std::fmt;
6
/// A compute backend that work can be dispatched to.
///
/// Rendered by `Display` as `CPU`, `CUDA:<id>` or `wgpu:<index>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputeDevice {
    /// The host CPU.
    Cpu,
    /// An NVIDIA GPU reached through CUDA.
    Cuda {
        /// Zero-based CUDA device index.
        device_id: usize,
    },
    /// A GPU adapter reached through wgpu.
    Wgpu {
        /// Zero-based wgpu adapter index.
        adapter_index: u32,
    },
}

18impl ComputeDevice {
19 #[must_use]
23 pub fn auto_detect() -> Self {
24 contract_pre_device_dispatch!();
25 if Self::cuda_available() {
26 if let Some(info) = DeviceInfo::cuda_info(0) {
27 if info.memory_gb >= 6.0 {
28 return Self::Cuda { device_id: 0 };
29 }
30 }
31 }
32 if Self::wgpu_available() {
33 return Self::Wgpu { adapter_index: 0 };
34 }
35 Self::Cpu
36 }
37
38 #[must_use]
40 pub fn cuda_available() -> bool {
41 if std::env::var("CUDA_VISIBLE_DEVICES").is_ok() {
43 return true;
44 }
45
46 std::process::Command::new("nvidia-smi")
48 .arg("--query-gpu=name")
49 .arg("--format=csv,noheader")
50 .output()
51 .map(|o| o.status.success())
52 .unwrap_or(false)
53 }
54
55 #[must_use]
57 pub fn wgpu_available() -> bool {
58 #[cfg(feature = "gpu")]
59 {
60 trueno::backends::gpu::GpuDevice::is_available()
61 }
62 #[cfg(not(feature = "gpu"))]
63 {
64 false
65 }
66 }
67
68 #[must_use]
70 pub const fn is_cuda(&self) -> bool {
71 matches!(self, Self::Cuda { .. })
72 }
73
74 #[must_use]
76 pub const fn is_cpu(&self) -> bool {
77 matches!(self, Self::Cpu)
78 }
79
80 #[must_use]
82 pub const fn is_wgpu(&self) -> bool {
83 matches!(self, Self::Wgpu { .. })
84 }
85
86 #[must_use]
88 pub const fn device_id(&self) -> Option<usize> {
89 match self {
90 Self::Cuda { device_id } => Some(*device_id),
91 Self::Cpu | Self::Wgpu { .. } => None,
92 }
93 }
94
95 #[must_use]
97 pub const fn adapter_index(&self) -> Option<u32> {
98 match self {
99 Self::Wgpu { adapter_index } => Some(*adapter_index),
100 Self::Cpu | Self::Cuda { .. } => None,
101 }
102 }
103
104 #[must_use]
111 pub fn detect_all_devices() -> Vec<Self> {
112 let mut devices = Vec::new();
113
114 if Self::cuda_available() {
116 let cuda_count = Self::cuda_device_count();
117 for i in 0..cuda_count {
118 if let Some(info) = DeviceInfo::cuda_info(i) {
119 if info.memory_gb >= 4.0 {
120 devices.push(Self::Cuda { device_id: i });
121 }
122 }
123 }
124 }
125
126 #[cfg(feature = "gpu")]
128 {
129 let wgpu_count = Self::wgpu_adapter_count();
130 for i in 0..wgpu_count {
131 devices.push(Self::Wgpu { adapter_index: i as u32 });
132 }
133 }
134
135 if devices.is_empty() {
137 devices.push(Self::Cpu);
138 }
139
140 devices
141 }
142
143 fn cuda_device_count() -> usize {
145 std::process::Command::new("nvidia-smi")
146 .args(["--query-gpu=name", "--format=csv,noheader"])
147 .output()
148 .ok()
149 .filter(|o| o.status.success())
150 .map_or(0, |o| String::from_utf8_lossy(&o.stdout).lines().count())
151 }
152
153 #[cfg(feature = "gpu")]
155 fn wgpu_adapter_count() -> usize {
156 0
159 }
160}
161
162impl Default for ComputeDevice {
163 fn default() -> Self {
164 Self::auto_detect()
165 }
166}
167
168impl fmt::Display for ComputeDevice {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 match self {
171 Self::Cpu => write!(f, "CPU"),
172 Self::Cuda { device_id } => write!(f, "CUDA:{device_id}"),
173 Self::Wgpu { adapter_index } => write!(f, "wgpu:{adapter_index}"),
174 }
175 }
176}
177
/// Descriptive information about a compute device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Human-readable device name.
    pub name: String,
    /// Total device (or system, for the CPU) memory in GiB.
    pub memory_gb: f64,
    /// `(major, minor)` compute capability, when known.
    pub compute_capability: Option<(u32, u32)>,
    /// Driver version string, when known.
    pub driver_version: Option<String>,
}

191impl DeviceInfo {
192 #[must_use]
194 pub fn cpu_info() -> Self {
195 let num_cores =
196 std::thread::available_parallelism().map(std::num::NonZero::get).unwrap_or(1);
197
198 Self {
199 name: format!("CPU ({num_cores} cores)"),
200 memory_gb: Self::system_memory_gb(),
201 compute_capability: None,
202 driver_version: None,
203 }
204 }
205
206 #[must_use]
208 pub fn cuda_info(device_id: usize) -> Option<Self> {
209 let output = std::process::Command::new("nvidia-smi")
211 .args([
212 "--query-gpu=name,memory.total,driver_version",
213 "--format=csv,noheader,nounits",
214 &format!("--id={device_id}"),
215 ])
216 .output()
217 .ok()?;
218
219 if !output.status.success() {
220 return None;
221 }
222
223 let stdout = String::from_utf8_lossy(&output.stdout);
224 let parts: Vec<&str> = stdout.trim().split(", ").collect();
225
226 if parts.len() >= 3 {
227 let name = parts[0].to_string();
228 let memory_mb: f64 = parts[1].parse().unwrap_or(0.0);
229 let driver = parts[2].to_string();
230
231 Some(Self {
232 name,
233 memory_gb: memory_mb / 1024.0,
234 compute_capability: None, driver_version: Some(driver),
236 })
237 } else {
238 None
239 }
240 }
241
242 fn system_memory_gb() -> f64 {
244 if let Ok(content) = std::fs::read_to_string("/proc/meminfo") {
246 for line in content.lines() {
247 if line.starts_with("MemTotal:") {
248 let parts: Vec<&str> = line.split_whitespace().collect();
249 if parts.len() >= 2 {
250 if let Ok(kb) = parts[1].parse::<f64>() {
251 return kb / 1024.0 / 1024.0;
252 }
253 }
254 }
255 }
256 }
257 16.0 }
259
260 #[must_use]
262 pub fn sufficient_for_qlora(&self) -> bool {
263 self.memory_gb >= 6.0
264 }
265
266 #[must_use]
268 pub fn sufficient_for_lora(&self) -> bool {
269 self.memory_gb >= 12.0
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_device_cpu() {
        let device = ComputeDevice::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_cuda());
        assert!(!device.is_wgpu());
        assert_eq!(device.device_id(), None);
        assert_eq!(device.adapter_index(), None);
        assert_eq!(device.to_string(), "CPU");
    }

    #[test]
    fn test_compute_device_cuda() {
        let device = ComputeDevice::Cuda { device_id: 0 };
        assert!(device.is_cuda());
        assert!(!device.is_cpu());
        assert!(!device.is_wgpu());
        assert_eq!(device.device_id(), Some(0));
        assert_eq!(device.adapter_index(), None);
        assert_eq!(device.to_string(), "CUDA:0");
    }

    #[test]
    fn test_compute_device_wgpu() {
        let device = ComputeDevice::Wgpu { adapter_index: 1 };
        assert!(device.is_wgpu());
        assert!(!device.is_cpu());
        assert!(!device.is_cuda());
        assert_eq!(device.adapter_index(), Some(1));
        assert_eq!(device.device_id(), None);
        assert_eq!(device.to_string(), "wgpu:1");
    }

    #[test]
    fn test_auto_detect_returns_valid_device() {
        // auto_detect only ever selects the CPU, CUDA device 0 or wgpu adapter 0;
        // the wgpu case is reachable when built with the `gpu` feature.
        match ComputeDevice::auto_detect() {
            ComputeDevice::Cpu => {}
            ComputeDevice::Cuda { device_id } => assert_eq!(device_id, 0),
            ComputeDevice::Wgpu { adapter_index } => assert_eq!(adapter_index, 0),
        }
    }

    #[test]
    fn test_device_info_cpu() {
        let info = DeviceInfo::cpu_info();
        assert!(info.name.contains("CPU"));
        assert!(info.memory_gb > 0.0);
        assert!(info.compute_capability.is_none());
        assert!(info.driver_version.is_none());
    }

    #[test]
    fn test_device_default() {
        let device = ComputeDevice::default();
        assert!(device.is_cpu() || device.is_cuda() || device.is_wgpu());
    }

    #[test]
    fn test_detect_all_devices() {
        let devices = ComputeDevice::detect_all_devices();
        assert!(!devices.is_empty(), "must detect at least one device");
        // The CPU is only a fallback, so it never appears alongside a GPU.
        if devices.contains(&ComputeDevice::Cpu) {
            assert_eq!(devices, vec![ComputeDevice::Cpu]);
        }
    }

    #[test]
    fn test_sufficient_memory_checks() {
        let small = DeviceInfo {
            name: "Small GPU".into(),
            memory_gb: 4.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(!small.sufficient_for_qlora());
        assert!(!small.sufficient_for_lora());

        let medium = DeviceInfo {
            name: "Medium GPU".into(),
            memory_gb: 8.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(medium.sufficient_for_qlora());
        assert!(!medium.sufficient_for_lora());

        let large = DeviceInfo {
            name: "Large GPU".into(),
            memory_gb: 16.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(large.sufficient_for_qlora());
        assert!(large.sufficient_for_lora());
    }

    #[test]
    fn test_sufficient_memory_boundaries_are_inclusive() {
        let at_qlora = DeviceInfo {
            name: "Boundary GPU".into(),
            memory_gb: 6.0,
            compute_capability: None,
            driver_version: None,
        };
        assert!(at_qlora.sufficient_for_qlora());
        assert!(!at_qlora.sufficient_for_lora());

        let at_lora = DeviceInfo { memory_gb: 12.0, ..at_qlora.clone() };
        assert!(at_lora.sufficient_for_qlora());
        assert!(at_lora.sufficient_for_lora());
    }
}