use std::fmt;
use crate::autograd::cuda_training_available;
/// Compute device a training run is placed on.
///
/// Values are normally produced by [`resolve_device`]; the canonical textual
/// form is available through [`Device::tag`] or `Display`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Device {
    /// The host CPU.
    Cpu,
    /// A CUDA device, identified by its ordinal (rendered as `cuda:<index>`).
    Cuda { index: u8 },
}

impl Device {
    /// Canonical spec string for this device: `"cpu"` or `"cuda:<index>"`.
    #[must_use]
    pub fn tag(&self) -> String {
        match self {
            Device::Cpu => "cpu".to_string(),
            Device::Cuda { index } => format!("cuda:{index}"),
        }
    }

    /// Returns `true` if this is a CUDA device, `false` for the CPU.
    #[must_use]
    pub fn is_cuda(&self) -> bool {
        matches!(self, Device::Cuda { .. })
    }
}

impl fmt::Display for Device {
    /// Writes [`Device::tag`], honouring width / fill / alignment flags
    /// (e.g. `{:<8}`) so devices line up in tabular output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) applies the caller's formatting flags;
        // with a plain `{}` the output is identical.
        f.pad(&self.tag())
    }
}
/// Reasons [`resolve_device`] can refuse a `--device` spec.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DeviceError {
    /// The spec (kept verbatim) is outside the `--device` grammar.
    InvalidSpec(String),
    /// A CUDA device was asked for explicitly but the CUDA runtime is not
    /// usable; `requested` holds the spec exactly as the user wrote it.
    CudaNotAvailable { requested: String },
}

impl fmt::Display for DeviceError {
    /// Renders a user-facing message that names the violated contract clause.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidSpec(s) => {
                write!(
                    f,
                    "--device `{s}` does not match grammar \
                     ^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$ \
                     (contract gpu-training-backend-v1 INV-GPUTRAIN-001)"
                )
            }
            Self::CudaNotAvailable { requested } => {
                write!(
                    f,
                    "--device `{requested}` requested but CUDA runtime is \
                     not available on this host \
                     (contract gpu-training-backend-v1 GATE-GPUTRAIN-002: \
                     no silent CPU fallback). Rebuild with `--features cuda` \
                     or pass `--device cpu` to opt in to the CPU path."
                )
            }
        }
    }
}

/// Marker impl so `DeviceError` composes with `?` and `Box<dyn Error>`.
impl std::error::Error for DeviceError {}
/// Turns a `--device` spec string into a concrete [`Device`].
///
/// * `cpu` always yields [`Device::Cpu`].
/// * `cuda` / `cuda:N` yield [`Device::Cuda`] only when
///   `cuda_training_available()` is true; there is no silent CPU fallback.
/// * `auto` yields `cuda:0` when CUDA is available and the CPU otherwise.
///
/// # Errors
///
/// * [`DeviceError::InvalidSpec`] when `spec` is outside the grammar.
/// * [`DeviceError::CudaNotAvailable`] when CUDA was requested explicitly
///   but is not available; the error carries `spec` verbatim.
pub fn resolve_device(spec: &str) -> Result<Device, DeviceError> {
    let parsed = match parse_device_spec(spec) {
        Some(parsed) => parsed,
        None => return Err(DeviceError::InvalidSpec(spec.to_owned())),
    };

    // The CUDA probe is consulted only for specs that could land on a GPU,
    // and at most once per call (a failed guard falls to the next arm).
    match parsed {
        ParsedSpec::Cpu => Ok(Device::Cpu),
        ParsedSpec::Cuda(index) if cuda_training_available() => Ok(Device::Cuda { index }),
        ParsedSpec::Cuda(_) => Err(DeviceError::CudaNotAvailable {
            requested: spec.to_owned(),
        }),
        ParsedSpec::Auto if cuda_training_available() => Ok(Device::Cuda { index: 0 }),
        ParsedSpec::Auto => Ok(Device::Cpu),
    }
}
/// Purely syntactic reading of a `--device` spec, before CUDA availability
/// is taken into account.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ParsedSpec {
    /// `cpu`
    Cpu,
    /// `cuda` (alias for ordinal 0) or `cuda:N` with `N` in `0..=15`.
    Cuda(u8),
    /// `auto`
    Auto,
}

/// Parses `spec` against `^(cpu|cuda(:[0-9]|:1[0-5])?|auto)$`.
///
/// Matching is exact and case-sensitive: no surrounding whitespace, no sign,
/// no leading zeros. Returns `None` for anything outside the grammar.
fn parse_device_spec(spec: &str) -> Option<ParsedSpec> {
    let ordinal = match spec {
        "cpu" => return Some(ParsedSpec::Cpu),
        "auto" => return Some(ParsedSpec::Auto),
        "cuda" => 0,
        _ => {
            // Match the index digits byte-wise: exactly `0`..`9` or `10`..`15`.
            // (`u8::from_str` would also admit a leading `+`.)
            match spec.strip_prefix("cuda:")?.as_bytes() {
                [d @ b'0'..=b'9'] => *d - b'0',
                [b'1', d @ b'0'..=b'5'] => 10 + (*d - b'0'),
                _ => return None,
            }
        }
    };
    Some(ParsedSpec::Cuda(ordinal))
}
/// Falsification tests for contract `gpu-training-backend-v1`
/// (INV-GPUTRAIN-001: spec grammar; GATE-GPUTRAIN-002: no silent fallback).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn falsify_gputrain_001_accepts_cpu() {
        assert_eq!(parse_device_spec("cpu"), Some(ParsedSpec::Cpu));
    }

    #[test]
    fn falsify_gputrain_001_accepts_auto() {
        assert_eq!(parse_device_spec("auto"), Some(ParsedSpec::Auto));
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_alias() {
        assert_eq!(parse_device_spec("cuda"), Some(ParsedSpec::Cuda(0)));
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_single_digit() {
        for i in 0..=9u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(
                parse_device_spec(&spec),
                Some(ParsedSpec::Cuda(i)),
                "grammar must accept {spec}",
            );
        }
    }

    #[test]
    fn falsify_gputrain_001_accepts_cuda_10_through_15() {
        for i in 10..=15u8 {
            let spec = format!("cuda:{i}");
            assert_eq!(
                parse_device_spec(&spec),
                Some(ParsedSpec::Cuda(i)),
                "grammar must accept {spec}",
            );
        }
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_16() {
        assert_eq!(parse_device_spec("cuda:16"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_index_99() {
        assert_eq!(parse_device_spec("cuda:99"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_leading_zero() {
        assert_eq!(parse_device_spec("cuda:01"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_empty_index() {
        assert_eq!(parse_device_spec("cuda:"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_negative_index() {
        assert_eq!(parse_device_spec("cuda:-1"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_plus_sign() {
        // `u8::from_str` accepts a leading `+`; the grammar must not.
        assert_eq!(parse_device_spec("cuda:+1"), None);
        assert_eq!(parse_device_spec("cuda:+9"), None);
        assert_eq!(parse_device_spec("cuda:+15"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_surrounding_whitespace() {
        assert_eq!(parse_device_spec("cpu "), None);
        assert_eq!(parse_device_spec("cuda:1 "), None);
        assert_eq!(parse_device_spec("cuda: 1"), None);
        assert_eq!(parse_device_spec("cuda :1"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_wide_indices() {
        // Three digits, including one that overflows `u8`.
        assert_eq!(parse_device_spec("cuda:100"), None);
        assert_eq!(parse_device_spec("cuda:256"), None);
    }

    #[test]
    fn falsify_gputrain_001_rejects_typo() {
        assert_eq!(parse_device_spec("gpu"), None);
        assert_eq!(parse_device_spec("CUDA"), None);
        assert_eq!(parse_device_spec("cudaa"), None);
        assert_eq!(parse_device_spec(""), None);
        assert_eq!(parse_device_spec(" cpu"), None);
    }

    #[test]
    fn falsify_gputrain_001_resolve_wraps_invalid_as_device_error() {
        let err = resolve_device("gpu").unwrap_err();
        assert!(matches!(err, DeviceError::InvalidSpec(ref s) if s == "gpu"));
    }

    #[test]
    fn falsify_gputrain_002_explicit_cuda_without_runtime_errors() {
        if cuda_training_available() {
            assert_eq!(resolve_device("cuda:0"), Ok(Device::Cuda { index: 0 }));
            assert_eq!(resolve_device("auto"), Ok(Device::Cuda { index: 0 }));
        } else {
            let err = resolve_device("cuda:0").unwrap_err();
            assert!(matches!(err, DeviceError::CudaNotAvailable { .. }));
            let err = resolve_device("cuda").unwrap_err();
            assert!(matches!(err, DeviceError::CudaNotAvailable { .. }));
            assert_eq!(resolve_device("auto"), Ok(Device::Cpu));
        }
    }

    #[test]
    fn falsify_gputrain_002_unavailable_error_echoes_original_spec() {
        // Only meaningful on hosts without CUDA; on CUDA hosts the explicit
        // spec resolves successfully and is covered by the test above.
        if !cuda_training_available() {
            assert_eq!(
                resolve_device("cuda:3"),
                Err(DeviceError::CudaNotAvailable { requested: "cuda:3".into() }),
            );
            assert_eq!(
                resolve_device("cuda"),
                Err(DeviceError::CudaNotAvailable { requested: "cuda".into() }),
            );
        }
    }

    #[test]
    fn falsify_gputrain_002_cpu_always_resolves() {
        assert_eq!(resolve_device("cpu"), Ok(Device::Cpu));
    }

    #[test]
    fn device_tag_round_trips() {
        assert_eq!(Device::Cpu.tag(), "cpu");
        assert_eq!(Device::Cuda { index: 0 }.tag(), "cuda:0");
        assert_eq!(Device::Cuda { index: 7 }.tag(), "cuda:7");
        assert_eq!(Device::Cuda { index: 15 }.tag(), "cuda:15");
        // Every tag the grammar can express must parse back to the same index.
        assert_eq!(parse_device_spec(&Device::Cpu.tag()), Some(ParsedSpec::Cpu));
        for i in 0..=15u8 {
            let tag = Device::Cuda { index: i }.tag();
            assert_eq!(parse_device_spec(&tag), Some(ParsedSpec::Cuda(i)), "{tag}");
        }
    }

    #[test]
    fn device_is_cuda_discriminator() {
        assert!(!Device::Cpu.is_cuda());
        assert!(Device::Cuda { index: 0 }.is_cuda());
    }

    #[test]
    fn device_error_display_mentions_contract() {
        let invalid = DeviceError::InvalidSpec("bogus".into()).to_string();
        assert!(invalid.contains("INV-GPUTRAIN-001"));
        assert!(invalid.contains("bogus"));
        let unavail = DeviceError::CudaNotAvailable { requested: "cuda:0".into() }.to_string();
        assert!(unavail.contains("GATE-GPUTRAIN-002"));
        assert!(unavail.contains("cuda:0"));
    }
}