1use crate::errors::{EtensorError, EtensorResult};
4
/// Marker identifying where a tensor's storage and computation live.
///
/// The GPU variants only exist when their corresponding Cargo feature is enabled, so a
/// build without `cuda-native` or `torch` can only ever produce [`Device::Cpu`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// The host CPU. It carries no GPU index.
    Cpu,

    /// A CUDA device available through the `cuda-native` feature; the payload is the GPU index.
    #[cfg(feature = "cuda-native")]
    CudaNative(usize),

    /// A CUDA device available through the `torch` feature; the payload is the GPU index.
    /// NOTE(review): presumably backed by the torch integration — confirm against that module.
    #[cfg(feature = "torch")]
    CudaTorch(usize),
}
21
22impl Device {
23 pub fn is_cpu(&self) -> bool {
25 matches!(self, Device::Cpu)
26 }
27
28 pub fn index(&self) -> EtensorResult<usize> {
33 match self {
34 Device::Cpu => Err(EtensorError::InternalError(
35 "Attempted to extract a GPU index from a CPU device marker.".to_string(),
36 )),
37 #[cfg(feature = "cuda-native")]
38 Device::CudaNative(idx) => Ok(*idx),
39 #[cfg(feature = "torch")]
40 Device::CudaTorch(idx) => Ok(*idx),
41 }
42 }
43}
44
45impl std::fmt::Display for Device {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Device::Cpu => write!(f, "cpu"),
49 #[cfg(feature = "cuda-native")]
50 Device::CudaNative(idx) => write!(f, "cuda_native:{}", idx),
51 #[cfg(feature = "torch")]
52 Device::CudaTorch(idx) => write!(f, "cuda_torch:{}", idx),
53 }
54 }
55}
56
57
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_cpu_device_properties() {
        let dev = Device::Cpu;

        assert!(dev.is_cpu());
        assert_eq!(dev.to_string(), "cpu");

        // A CPU has no GPU index, so `index` must fail with an internal error whose
        // message names the CPU marker. The `_` arm covers both `Ok` and any other
        // error variant without requiring `EtensorError: Debug`.
        match dev.index() {
            Err(EtensorError::InternalError(msg)) => {
                assert!(msg.contains("CPU device marker"), "unexpected message: {}", msg);
            }
            _ => panic!("Expected InternalError for CPU device index extraction!"),
        }
    }

    #[test]
    fn test_device_is_copy_eq_and_hashable() {
        let dev = Device::Cpu;
        // `Device: Copy`, so `dev` remains usable after being copied here.
        let copy = dev;
        assert_eq!(dev, copy);

        // Equal devices must hash identically, so a set keeps only one of them.
        let mut seen = HashSet::new();
        assert!(seen.insert(dev));
        assert!(!seen.insert(copy));
        assert_eq!(seen.len(), 1);
    }

    #[cfg(feature = "cuda-native")]
    #[test]
    fn test_cuda_native_properties() {
        let dev = Device::CudaNative(1);
        assert!(!dev.is_cpu());
        assert_ne!(dev, Device::Cpu);
        assert_eq!(dev.index().unwrap(), 1);
        assert_eq!(dev.to_string(), "cuda_native:1");
    }

    #[cfg(feature = "torch")]
    #[test]
    fn test_cuda_torch_properties() {
        let dev = Device::CudaTorch(0);
        assert!(!dev.is_cpu());
        assert_ne!(dev, Device::Cpu);
        assert_eq!(dev.index().unwrap(), 0);
        assert_eq!(dev.to_string(), "cuda_torch:0");
    }
}