// etensor_core/device.rs

//! Hardware abstraction and execution routing identifiers.

use crate::errors::{EtensorError, EtensorResult};

/// Hardware execution context in which a tensor's memory physically resides.
///
/// [`Device::Cpu`] is always available. Each accelerator variant is compiled in
/// only when its Cargo feature is enabled, so any `match` over this enum must
/// gate the corresponding arm with the same `#[cfg(feature = "...")]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    /// Standard host memory (RAM) executed via generic CPU backends (ndarray/Rayon).
    Cpu,

    /// Pure Rust PTX JIT/AOT execution via the `cudarc` driver.
    ///
    /// The inner value is the ordinal GPU index (e.g. `0` for "cuda:0").
    #[cfg(feature = "cuda-native")]
    CudaNative(usize),

    /// Enterprise C++ PyTorch bindings execution via `tch-rs`.
    ///
    /// The inner value is the ordinal GPU index (e.g. `0` for "cuda:0").
    #[cfg(feature = "torch")]
    CudaTorch(usize),
}

22impl Device {
23    /// Returns true if the device represents host CPU memory.
24    pub fn is_cpu(&self) -> bool {
25        matches!(self, Device::Cpu)
26    }
27
28    /// Safely extracts the hardware index for multi-GPU setups.
29    /// 
30    /// Returns an error if called on a CPU device, enforcing strict boundary checks
31    /// before passing layout configurations to hardware-specific allocators.
32    pub fn index(&self) -> EtensorResult<usize> {
33        match self {
34            Device::Cpu => Err(EtensorError::InternalError(
35                "Attempted to extract a GPU index from a CPU device marker.".to_string(),
36            )),
37            #[cfg(feature = "cuda-native")]
38            Device::CudaNative(idx) => Ok(*idx),
39            #[cfg(feature = "torch")]
40            Device::CudaTorch(idx) => Ok(*idx),
41        }
42    }
43}
44
45impl std::fmt::Display for Device {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Device::Cpu => write!(f, "cpu"),
49            #[cfg(feature = "cuda-native")]
50            Device::CudaNative(idx) => write!(f, "cuda_native:{}", idx),
51            #[cfg(feature = "torch")]
52            Device::CudaTorch(idx) => write!(f, "cuda_torch:{}", idx),
53        }
54    }
55}
56
57
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU device reports itself as CPU, prints as `cpu`, and refuses to
    /// yield a hardware index.
    #[test]
    fn test_cpu_device_properties() {
        let dev = Device::Cpu;

        assert!(dev.is_cpu());
        assert_eq!(dev.to_string(), "cpu");

        // A CPU device carries no ordinal, so `index()` must fail with an
        // `InternalError` whose message names the CPU marker.
        match dev.index() {
            Err(EtensorError::InternalError(msg)) => {
                assert!(msg.contains("CPU device marker"));
            }
            _ => panic!("Expected InternalError for CPU device index extraction!"),
        }
    }

    // The tests below are compiled only when the matching Cargo feature is
    // enabled (e.g. `cargo test --features cuda-native` or `--features torch`).
    // Without the feature they are absent from the test run entirely.

    /// A `CudaNative` device exposes its ordinal and prints as `cuda_native:<index>`.
    #[cfg(feature = "cuda-native")]
    #[test]
    fn test_cuda_native_properties() {
        let dev = Device::CudaNative(1); // Simulating GPU 1
        assert!(!dev.is_cpu());
        assert_eq!(dev.index().unwrap(), 1);
        assert_eq!(dev.to_string(), "cuda_native:1");
    }

    /// A `CudaTorch` device exposes its ordinal and prints as `cuda_torch:<index>`.
    #[cfg(feature = "torch")]
    #[test]
    fn test_cuda_torch_properties() {
        let dev = Device::CudaTorch(0); // Simulating GPU 0
        assert!(!dev.is_cpu());
        assert_eq!(dev.index().unwrap(), 0);
        assert_eq!(dev.to_string(), "cuda_torch:0");
    }
}