1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//! Hardware abstraction and execution routing identifiers.
use crate::errors::{EtensorError, EtensorResult};
/// Identifies the hardware execution context where a tensor's memory physically resides.
///
/// Only [`Device::Cpu`] exists unconditionally. The GPU variants are compiled in only when
/// the corresponding Cargo feature (`cuda-native` or `torch`) is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    /// Standard host memory (RAM) executed via generic CPU backends (ndarray/Rayon).
    Cpu,
    /// Pure Rust PTX JIT/AOT execution via the `cudarc` driver.
    ///
    /// The inner value is the ordinal GPU index (e.g. `0` for the first GPU, which this
    /// type displays as `cuda_native:0`).
    #[cfg(feature = "cuda-native")]
    CudaNative(usize),
    /// Enterprise C++ PyTorch bindings execution via `tch-rs`.
    ///
    /// The inner value is the ordinal GPU index (e.g. `0` for the first GPU, which this
    /// type displays as `cuda_torch:0`).
    #[cfg(feature = "torch")]
    CudaTorch(usize),
}
impl Device {
    /// Returns `true` if the device represents host CPU memory.
    #[must_use]
    pub const fn is_cpu(&self) -> bool {
        matches!(self, Device::Cpu)
    }

    /// Extracts the ordinal GPU index for multi-GPU setups.
    ///
    /// Acts as a strict boundary check before layout configurations are passed to
    /// hardware-specific allocators: only GPU variants carry an index.
    ///
    /// # Errors
    ///
    /// Returns [`EtensorError::InternalError`] when called on [`Device::Cpu`], because a
    /// CPU device marker has no GPU index.
    pub fn index(&self) -> EtensorResult<usize> {
        match self {
            Device::Cpu => Err(EtensorError::InternalError(
                "Attempted to extract a GPU index from a CPU device marker.".to_string(),
            )),
            #[cfg(feature = "cuda-native")]
            Device::CudaNative(idx) => Ok(*idx),
            #[cfg(feature = "torch")]
            Device::CudaTorch(idx) => Ok(*idx),
        }
    }
}
impl std::fmt::Display for Device {
    /// Writes the device as a lowercase backend tag, suffixed with `:<ordinal>` for GPU
    /// variants (`cpu`, `cuda_native:N`, `cuda_torch:N`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Split each variant into its textual backend name and optional GPU ordinal so the
        // actual writing happens in one place.
        let (backend, ordinal): (&str, Option<usize>) = match *self {
            Device::Cpu => ("cpu", None),
            #[cfg(feature = "cuda-native")]
            Device::CudaNative(gpu) => ("cuda_native", Some(gpu)),
            #[cfg(feature = "torch")]
            Device::CudaTorch(gpu) => ("cuda_torch", Some(gpu)),
        };

        match ordinal {
            Some(gpu) => write!(f, "{}:{}", backend, gpu),
            None => f.write_str(backend),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device_properties() {
        let dev = Device::Cpu;
        assert!(dev.is_cpu());
        assert_eq!(dev.to_string(), "cpu");

        // Extracting an index from a CPU marker must be rejected by the boundary check.
        let idx_result = dev.index();
        assert!(idx_result.is_err());
        match idx_result {
            Err(EtensorError::InternalError(msg)) => {
                assert!(msg.contains("CPU device marker"));
            }
            _ => panic!("Expected InternalError for CPU device index extraction!"),
        }
    }

    #[test]
    fn test_device_is_copy_eq_and_hashable() {
        let original = Device::Cpu;
        let copied = original; // `Device: Copy`, so `original` remains usable.
        assert_eq!(original, copied);

        let mut seen = std::collections::HashSet::new();
        seen.insert(original);
        seen.insert(copied);
        assert_eq!(seen.len(), 1);
    }

    // The GPU tests below are compiled only when their feature is enabled, e.g.
    // `cargo test --features cuda-native` or `cargo test --features torch`. Without the
    // feature they are excluded at compile time and do not appear in the test report at all.
    // They only construct device markers; no GPU is accessed.
    #[cfg(feature = "cuda-native")]
    #[test]
    fn test_cuda_native_properties() {
        let dev = Device::CudaNative(1); // Marker for GPU ordinal 1
        assert!(!dev.is_cpu());
        assert_eq!(dev.index().unwrap(), 1);
        assert_eq!(dev.to_string(), "cuda_native:1");
    }

    #[cfg(feature = "torch")]
    #[test]
    fn test_cuda_torch_properties() {
        let dev = Device::CudaTorch(0); // Marker for GPU ordinal 0
        assert!(!dev.is_cpu());
        assert_eq!(dev.index().unwrap(), 0);
        assert_eq!(dev.to_string(), "cuda_torch:0");
    }
}