Skip to main content

entrenar/train/
device.rs

1//! `Device` — selector for the training backend (`apr pretrain`).
2//!
3//! Contract binding: `contracts/entrenar/gpu-training-backend-v1.yaml`
4//! §device_dispatch.
5//!
6//! The string grammar accepted by `resolve_device` is fixed by
7//! INV-GPUTRAIN-001 / §device_dispatch.requested_device.grammar:
8//!
9//! ```text
10//! ^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$
11//! ```
12//!
13//! - `cpu`            — force the CPU (trueno SIMD) training path.
14//! - `cuda`           — alias for `cuda:0`.
15//! - `cuda:N` (0..=15)— explicit CUDA device index.
16//! - `auto`           — `cuda:0` if `cuda_training_available()`, else `cpu`.
17//!
18//! The `auto` resolution is NOT a silent fallback: callers are obliged
19//! by GATE-GPUTRAIN-002 to print the resolved `Device` before starting
20//! training so the operator sees which backend was actually selected.
21//!
22//! Explicit `cuda` / `cuda:N` on a host without a usable CUDA runtime
23//! MUST return `DeviceError::CudaNotAvailable`. FALSIFY-GPUTRAIN-002
24//! binds this invariant.
25
26use std::fmt;
27
28use crate::autograd::cuda_training_available;
29
/// Backend on which a training run executes.
///
/// Values are small and `Copy`; compare them with `==` or branch on
/// [`Device::is_cuda`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Device {
    /// CPU (trueno SIMD) — `TransformerTrainer`.
    Cpu,
    /// CUDA device `index` — `CudaTransformerTrainer`.
    Cuda { index: u8 },
}

impl Device {
    /// Short human-readable tag used in CLI banners and run-dir metadata.
    ///
    /// Returns `"cpu"` for [`Device::Cpu`] and `"cuda:<index>"` for
    /// [`Device::Cuda`]. The text is produced by the `Display` impl, so
    /// the two spellings are always identical.
    #[must_use]
    pub fn tag(&self) -> String {
        self.to_string()
    }

    /// Returns `true` for any CUDA device, whatever its index.
    #[must_use]
    pub fn is_cuda(&self) -> bool {
        matches!(self, Device::Cuda { .. })
    }
}

impl fmt::Display for Device {
    /// Writes the device tag (`cpu` / `cuda:<index>`) directly into the
    /// formatter, without building an intermediate `String`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Device::Cpu => f.write_str("cpu"),
            Device::Cuda { index } => write!(f, "cuda:{index}"),
        }
    }
}
61
/// Reasons `resolve_device` can refuse a `--device` string.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DeviceError {
    /// The input is outside the grammar
    /// `^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$`. The payload is the
    /// offending string, verbatim.
    InvalidSpec(String),
    /// CUDA was requested explicitly (`cuda` / `cuda:N`) but
    /// `cuda_training_available()` returned false. GATE-GPUTRAIN-002
    /// forbids a silent CPU fallback on explicit CUDA requests, so this
    /// variant IS the hard failure. `requested` is the spec as typed.
    CudaNotAvailable { requested: String },
}

impl fmt::Display for DeviceError {
    /// Renders the operator-facing message, naming the contract clause
    /// that was violated.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidSpec(spec) => write!(
                f,
                concat!(
                    "--device `{}` does not match grammar ",
                    "^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$ ",
                    "(contract gpu-training-backend-v1 INV-GPUTRAIN-001)",
                ),
                spec,
            ),
            Self::CudaNotAvailable { requested } => write!(
                f,
                concat!(
                    "--device `{}` requested but CUDA runtime is ",
                    "not available on this host ",
                    "(contract gpu-training-backend-v1 GATE-GPUTRAIN-002: ",
                    "no silent CPU fallback). Rebuild with `--features cuda` ",
                    "or pass `--device cpu` to opt in to the CPU path.",
                ),
                requested,
            ),
        }
    }
}

impl std::error::Error for DeviceError {}
97
98/// Resolve a CLI `--device` string into a concrete `Device`.
99///
100/// Contract: this function is THE single binding point for
101/// INV-GPUTRAIN-001 (grammar) and GATE-GPUTRAIN-002 (no silent CPU
102/// fallback on explicit CUDA request).
103///
104/// # Errors
105/// - [`DeviceError::InvalidSpec`] — `spec` is not one of `cpu`,
106///   `cuda`, `cuda:N` (0..=15), or `auto`.
107/// - [`DeviceError::CudaNotAvailable`] — `spec` explicitly asked for
108///   CUDA (or `auto` chose CUDA) but `cuda_training_available()`
109///   returned `false`.
110pub fn resolve_device(spec: &str) -> Result<Device, DeviceError> {
111    let parsed =
112        parse_device_spec(spec).ok_or_else(|| DeviceError::InvalidSpec(spec.to_string()))?;
113
114    match parsed {
115        ParsedSpec::Cpu => Ok(Device::Cpu),
116        ParsedSpec::Cuda(index) => {
117            if cuda_training_available() {
118                Ok(Device::Cuda { index })
119            } else {
120                Err(DeviceError::CudaNotAvailable { requested: spec.to_string() })
121            }
122        }
123        ParsedSpec::Auto => {
124            if cuda_training_available() {
125                Ok(Device::Cuda { index: 0 })
126            } else {
127                Ok(Device::Cpu)
128            }
129        }
130    }
131}
132
/// Result of the purely syntactic stage of `--device` handling.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ParsedSpec {
    /// The literal `cpu`.
    Cpu,
    /// `cuda` (index 0) or `cuda:N`; the payload is the device index.
    Cuda(u8),
    /// The literal `auto`.
    Auto,
}

/// Pure-function parser: string → `ParsedSpec`, or `None` when `spec`
/// is outside `^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$`.
///
/// Kept separate from the availability probe so FALSIFY-GPUTRAIN-001
/// (grammar) can be exercised deterministically regardless of whether
/// the host has CUDA.
fn parse_device_spec(spec: &str) -> Option<ParsedSpec> {
    let parsed = match spec {
        "cpu" => ParsedSpec::Cpu,
        "auto" => ParsedSpec::Auto,
        "cuda" => ParsedSpec::Cuda(0),
        _ => {
            let digits = spec.strip_prefix("cuda:")?.as_bytes();
            // The index is matched byte-for-byte against the grammar
            // rather than through an integer parse, so spellings such as
            // "01" or "+1" can never be accepted.
            let index = match digits {
                // `:[0-9]` — exactly one ASCII digit.
                [ones @ b'0'..=b'9'] => *ones - b'0',
                // `:1[0-5]` — the two-character forms 10 through 15.
                [b'1', ones @ b'0'..=b'5'] => 10 + (*ones - b'0'),
                _ => return None,
            };
            ParsedSpec::Cuda(index)
        }
    };
    Some(parsed)
}
171
#[cfg(test)]
mod tests {
    use super::*;

    // ─── FALSIFY-GPUTRAIN-001: grammar ──────────────────────────────────
    //
    // Binds contract `gpu-training-backend-v1` INV-GPUTRAIN-001. Any
    // string that does NOT match
    // `^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$` MUST be rejected with
    // `DeviceError::InvalidSpec`; any string that DOES match MUST parse.

    #[test]
    fn falsify_gputrain_001_accepts_cpu() {
        assert_eq!(parse_device_spec("cpu"), Some(ParsedSpec::Cpu));
    }

    #[test]
    fn falsify_gputrain_001_accepts_auto() {
        assert_eq!(parse_device_spec("auto"), Some(ParsedSpec::Auto));
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_alias() {
        assert_eq!(parse_device_spec("cuda"), Some(ParsedSpec::Cuda(0)));
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_single_digit() {
        for i in 0..=9u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(
                parse_device_spec(&spec),
                Some(ParsedSpec::Cuda(i)),
                "grammar must accept {spec}",
            );
        }
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_10_through_15() {
        for i in 10..=15u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(
                parse_device_spec(&spec),
                Some(ParsedSpec::Cuda(i)),
                "grammar must accept {spec}",
            );
        }
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_16() {
        assert_eq!(parse_device_spec("cuda:16"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_16_through_19() {
        // These share the `1` prefix with the valid two-character forms,
        // so they are the inputs most likely to slip past a prefix check.
        for i in 16..=19u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(parse_device_spec(&spec), None, "grammar must reject {spec}");
        }
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_99() {
        assert_eq!(parse_device_spec("cuda:99"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_at_and_beyond_u8_limit() {
        // 255 fits in a `u8`, 256 does not; neither is in the grammar.
        assert_eq!(parse_device_spec("cuda:255"), None);
        assert_eq!(parse_device_spec("cuda:256"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_leading_zero() {
        // Grammar allows one digit [0-9] or two chars 1[0-5]; "01"
        // matches neither.
        assert_eq!(parse_device_spec("cuda:01"), None);
        assert_eq!(parse_device_spec("cuda:00"), None);
        assert_eq!(parse_device_spec("cuda:001"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_empty_index() {
        assert_eq!(parse_device_spec("cuda:"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_negative_index() {
        assert_eq!(parse_device_spec("cuda:-1"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_plus_signed_index() {
        // `u8::from_str` accepts an optional leading `+`; the grammar
        // does not, so "+1" must not be treated as index 1.
        assert_eq!(parse_device_spec("cuda:+1"), None);
        assert_eq!(parse_device_spec("cuda:+5"), None);
        assert_eq!(parse_device_spec("cuda:+"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_trailing_or_extra_characters() {
        // The grammar is anchored with `^…$`: nothing may follow a
        // valid spec and nothing may precede it.
        assert_eq!(parse_device_spec("cuda:1 "), None);
        assert_eq!(parse_device_spec("cuda: 1"), None);
        assert_eq!(parse_device_spec("cuda:1a"), None);
        assert_eq!(parse_device_spec("cuda:0:0"), None);
        assert_eq!(parse_device_spec("cpu "), None);
        assert_eq!(parse_device_spec("auto\n"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_typo() {
        assert_eq!(parse_device_spec("gpu"), None);
        assert_eq!(parse_device_spec("CUDA"), None);
        assert_eq!(parse_device_spec("cudaa"), None);
        assert_eq!(parse_device_spec(""), None);
        assert_eq!(parse_device_spec(" cpu"), None);
    }

    #[test]
    fn falsify_gputrain_001_resolve_wraps_invalid_as_device_error() {
        let err = resolve_device("gpu").unwrap_err();
        assert!(matches!(err, DeviceError::InvalidSpec(ref s) if s == "gpu"));
    }

    // ─── FALSIFY-GPUTRAIN-002: no silent CPU fallback ──────────────────
    //
    // Binds contract `gpu-training-backend-v1` INV-GPUTRAIN-002 /
    // GATE-GPUTRAIN-002. Explicit `--device cuda` / `cuda:N` MUST hard-
    // fail when the host has no CUDA runtime. `auto` is the ONLY spec
    // allowed to fall back.

    #[test]
    fn falsify_gputrain_002_explicit_cuda_without_runtime_errors() {
        if cuda_training_available() {
            // On a CUDA host this branch is a positive assertion:
            // explicit `cuda:0` must resolve successfully, and `auto`
            // must choose CUDA (not silently downgrade).
            assert_eq!(resolve_device("cuda:0"), Ok(Device::Cuda { index: 0 }));
            assert_eq!(resolve_device("auto"), Ok(Device::Cuda { index: 0 }));
        } else {
            // On a CPU-only host:
            // - explicit `cuda:0` MUST hard-fail (no silent fallback)
            // - explicit `cuda` MUST hard-fail (alias for `cuda:0`)
            // - `auto` MAY fall back to CPU (this is the documented
            //   safe-default escape hatch)
            let err = resolve_device("cuda:0").unwrap_err();
            assert!(matches!(err, DeviceError::CudaNotAvailable { .. }));
            let err = resolve_device("cuda").unwrap_err();
            assert!(matches!(err, DeviceError::CudaNotAvailable { .. }));
            assert_eq!(resolve_device("auto"), Ok(Device::Cpu));
        }
    }

    #[test]
    fn falsify_gputrain_002_cpu_always_resolves() {
        // `--device cpu` must always return `Device::Cpu`, regardless of
        // whether CUDA is available — it is an explicit opt-in to the
        // CPU path (for falsification parity runs, reproducibility, or
        // hosts without a usable GPU).
        assert_eq!(resolve_device("cpu"), Ok(Device::Cpu));
    }

    // ─── `Device` / `DeviceError` presentation ─────────────────────────

    #[test]
    fn device_tag_round_trips() {
        assert_eq!(Device::Cpu.tag(), "cpu");
        assert_eq!(Device::Cuda { index: 0 }.tag(), "cuda:0");
        assert_eq!(Device::Cuda { index: 7 }.tag(), "cuda:7");
        assert_eq!(Device::Cuda { index: 15 }.tag(), "cuda:15");
    }

    #[test]
    fn device_is_cuda_discriminator() {
        assert!(!Device::Cpu.is_cuda());
        assert!(Device::Cuda { index: 0 }.is_cuda());
    }

    #[test]
    fn device_error_display_mentions_contract() {
        let invalid = DeviceError::InvalidSpec("bogus".into()).to_string();
        assert!(invalid.contains("INV-GPUTRAIN-001"));
        assert!(invalid.contains("bogus"));
        let unavail = DeviceError::CudaNotAvailable { requested: "cuda:0".into() }.to_string();
        assert!(unavail.contains("GATE-GPUTRAIN-002"));
        assert!(unavail.contains("cuda:0"));
    }
}