1use std::fmt;
27
28use crate::autograd::cuda_training_available;
29
/// A resolved compute target: either the CPU or a single CUDA device.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Device {
    /// Run on the CPU.
    Cpu,
    /// Run on the CUDA device with the given index.
    Cuda { index: u8 },
}

impl Device {
    /// Returns the textual tag of this device: `"cpu"` or `"cuda:<index>"`.
    ///
    /// The tag is exactly the text produced by the [`fmt::Display`] impl, so the
    /// two can never drift apart.
    #[must_use]
    pub fn tag(&self) -> String {
        self.to_string()
    }

    /// Returns `true` for every [`Device::Cuda`] value, whatever its index.
    #[must_use]
    pub fn is_cuda(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Device::Cuda { .. } => true,
            Device::Cpu => false,
        }
    }
}

impl fmt::Display for Device {
    /// Writes the device tag directly into the formatter.
    ///
    /// No intermediate `String` is allocated. Width / fill / alignment flags on
    /// the formatter are not applied; the tag is always written verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Device::Cpu => f.write_str("cpu"),
            Device::Cuda { index } => write!(f, "cuda:{index}"),
        }
    }
}
61
/// Reasons a `--device` spec can fail to resolve.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DeviceError {
    /// The spec text did not match the accepted grammar; carries the offending text.
    InvalidSpec(String),
    /// A CUDA device was requested but the CUDA runtime is not available;
    /// `requested` carries the spec text that asked for it.
    CudaNotAvailable { requested: String },
}

impl fmt::Display for DeviceError {
    /// Renders the user-facing message, including the contract identifier
    /// associated with each failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidSpec(spec) => f.write_fmt(format_args!(
                "--device `{spec}` does not match grammar \
                 ^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$ \
                 (contract gpu-training-backend-v1 INV-GPUTRAIN-001)"
            )),
            Self::CudaNotAvailable { requested: spec } => f.write_fmt(format_args!(
                "--device `{spec}` requested but CUDA runtime is \
                 not available on this host \
                 (contract gpu-training-backend-v1 GATE-GPUTRAIN-002: \
                 no silent CPU fallback). Rebuild with `--features cuda` \
                 or pass `--device cpu` to opt in to the CPU path."
            )),
        }
    }
}

// No underlying source error; the default `Error` methods are sufficient.
impl std::error::Error for DeviceError {}
97
98pub fn resolve_device(spec: &str) -> Result<Device, DeviceError> {
111 let parsed =
112 parse_device_spec(spec).ok_or_else(|| DeviceError::InvalidSpec(spec.to_string()))?;
113
114 match parsed {
115 ParsedSpec::Cpu => Ok(Device::Cpu),
116 ParsedSpec::Cuda(index) => {
117 if cuda_training_available() {
118 Ok(Device::Cuda { index })
119 } else {
120 Err(DeviceError::CudaNotAvailable { requested: spec.to_string() })
121 }
122 }
123 ParsedSpec::Auto => {
124 if cuda_training_available() {
125 Ok(Device::Cuda { index: 0 })
126 } else {
127 Ok(Device::Cpu)
128 }
129 }
130 }
131}
132
/// Purely syntactic result of parsing a device spec, before any check of what
/// the host actually supports.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ParsedSpec {
    /// The literal `cpu`.
    Cpu,
    /// `cuda` (index 0) or `cuda:N` with `N` in `0..=15`.
    Cuda(u8),
    /// The literal `auto`.
    Auto,
}

/// Parses `spec` against the grammar `^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$`.
///
/// Returns `None` for anything else: different case, surrounding whitespace,
/// sign prefixes, leading zeros, indices above 15, or an empty index.
fn parse_device_spec(spec: &str) -> Option<ParsedSpec> {
    if spec == "cpu" {
        return Some(ParsedSpec::Cpu);
    }
    if spec == "auto" {
        return Some(ParsedSpec::Auto);
    }
    if spec == "cuda" {
        return Some(ParsedSpec::Cuda(0));
    }

    let digits = spec.strip_prefix("cuda:")?.as_bytes();
    // Match the raw bytes so only the two shapes the grammar allows get through:
    // one ASCII digit, or `1` followed by `0`..=`5`.
    let index = match *digits {
        [ones @ b'0'..=b'9'] => ones - b'0',
        [b'1', ones @ b'0'..=b'5'] => 10 + (ones - b'0'),
        _ => return None,
    };
    Some(ParsedSpec::Cuda(index))
}
171
#[cfg(test)]
mod tests {
    //! Falsification tests for the device-spec grammar (INV-GPUTRAIN-001) and
    //! the no-silent-fallback gate (GATE-GPUTRAIN-002), plus checks on the
    //! `Device` / `DeviceError` helpers.

    use super::*;

    #[test]
    fn falsify_gputrain_001_accepts_cpu() {
        assert_eq!(parse_device_spec("cpu"), Some(ParsedSpec::Cpu));
    }

    #[test]
    fn falsify_gputrain_001_accepts_auto() {
        assert_eq!(parse_device_spec("auto"), Some(ParsedSpec::Auto));
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_alias() {
        assert_eq!(parse_device_spec("cuda"), Some(ParsedSpec::Cuda(0)));
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_single_digit() {
        for i in 0..=9u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(
                parse_device_spec(&spec),
                Some(ParsedSpec::Cuda(i)),
                "grammar must accept {spec}",
            );
        }
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_10_through_15() {
        for i in 10..=15u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(
                parse_device_spec(&spec),
                Some(ParsedSpec::Cuda(i)),
                "grammar must accept {spec}",
            );
        }
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_16() {
        assert_eq!(parse_device_spec("cuda:16"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_99() {
        assert_eq!(parse_device_spec("cuda:99"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_leading_zero() {
        assert_eq!(parse_device_spec("cuda:01"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_empty_index() {
        assert_eq!(parse_device_spec("cuda:"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_negative_index() {
        assert_eq!(parse_device_spec("cuda:-1"), None);
    }

    /// `u8::from_str` on its own accepts a leading `+`; the grammar does not.
    #[test]
    fn falsify_gputrain_001_rejects_plus_sign() {
        for spec in ["cuda:+", "cuda:+0", "cuda:+1", "cuda:+15"] {
            assert_eq!(parse_device_spec(spec), None, "grammar must reject {spec:?}");
        }
    }

    /// Zero padding that still parses to a valid number must not sneak through.
    #[test]
    fn falsify_gputrain_001_rejects_zero_padding() {
        for spec in ["cuda:00", "cuda:007", "cuda:015"] {
            assert_eq!(parse_device_spec(spec), None, "grammar must reject {spec:?}");
        }
    }

    /// Values past the grammar's range, including ones that overflow `u8`.
    #[test]
    fn falsify_gputrain_001_rejects_out_of_range_and_overflow() {
        for spec in ["cuda:19", "cuda:20", "cuda:100", "cuda:255", "cuda:256", "cuda:99999"] {
            assert_eq!(parse_device_spec(spec), None, "grammar must reject {spec:?}");
        }
    }

    #[test]
    fn falsify_gputrain_001_rejects_trailing_garbage() {
        for spec in ["cuda:1x", "cuda:1 ", "cuda: 1", "cuda:0:0", "cpu ", "auto\n"] {
            assert_eq!(parse_device_spec(spec), None, "grammar must reject {spec:?}");
        }
    }

    #[test]
    fn falsify_gputrain_001_rejects_typo() {
        assert_eq!(parse_device_spec("gpu"), None);
        assert_eq!(parse_device_spec("CUDA"), None);
        assert_eq!(parse_device_spec("cudaa"), None);
        assert_eq!(parse_device_spec(""), None);
        assert_eq!(parse_device_spec(" cpu"), None);
    }

    #[test]
    fn falsify_gputrain_001_resolve_wraps_invalid_as_device_error() {
        assert_eq!(
            resolve_device("gpu"),
            Err(DeviceError::InvalidSpec("gpu".to_string())),
        );
    }

    /// Exercises whichever branch the host supports; both are valid outcomes.
    #[test]
    fn falsify_gputrain_002_explicit_cuda_without_runtime_errors() {
        if cuda_training_available() {
            assert_eq!(resolve_device("cuda:0"), Ok(Device::Cuda { index: 0 }));
            assert_eq!(resolve_device("auto"), Ok(Device::Cuda { index: 0 }));
        } else {
            // The error must carry the spec exactly as the caller wrote it.
            for spec in ["cuda:0", "cuda", "cuda:15"] {
                assert_eq!(
                    resolve_device(spec),
                    Err(DeviceError::CudaNotAvailable { requested: spec.to_string() }),
                    "explicit {spec} must not fall back to CPU",
                );
            }
            assert_eq!(resolve_device("auto"), Ok(Device::Cpu));
        }
    }

    #[test]
    fn falsify_gputrain_002_cpu_always_resolves() {
        assert_eq!(resolve_device("cpu"), Ok(Device::Cpu));
    }

    #[test]
    fn device_tag_round_trips() {
        assert_eq!(Device::Cpu.tag(), "cpu");
        assert_eq!(Device::Cuda { index: 0 }.tag(), "cuda:0");
        assert_eq!(Device::Cuda { index: 7 }.tag(), "cuda:7");
        assert_eq!(Device::Cuda { index: 15 }.tag(), "cuda:15");

        // Every tag in the grammar's range parses back to the same device.
        assert_eq!(parse_device_spec(&Device::Cpu.tag()), Some(ParsedSpec::Cpu));
        for index in 0..=15u8 {
            let tag = Device::Cuda { index }.tag();
            assert_eq!(parse_device_spec(&tag), Some(ParsedSpec::Cuda(index)), "{tag}");
        }
    }

    #[test]
    fn device_display_matches_tag() {
        for dev in [Device::Cpu, Device::Cuda { index: 0 }, Device::Cuda { index: 15 }] {
            assert_eq!(dev.to_string(), dev.tag());
        }
    }

    #[test]
    fn device_is_cuda_discriminator() {
        assert!(!Device::Cpu.is_cuda());
        assert!(Device::Cuda { index: 0 }.is_cuda());
    }

    #[test]
    fn device_error_display_mentions_contract() {
        let invalid = DeviceError::InvalidSpec("bogus".into()).to_string();
        assert!(invalid.contains("INV-GPUTRAIN-001"));
        assert!(invalid.contains("bogus"));
        let unavail = DeviceError::CudaNotAvailable { requested: "cuda:0".into() }.to_string();
        assert!(unavail.contains("GATE-GPUTRAIN-002"));
        assert!(unavail.contains("cuda:0"));
    }
}