/// The device a tensor is reported to live on (see `TorchTensor::device`).
///
/// CUDA devices are distinguished by their numeric index. Every other kind of
/// device is carried as an opaque string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TorchDevice {
    /// A CUDA device, identified by its device index.
    Cuda(usize),
    /// Any non-CUDA device, identified by an opaque string.
    /// NOTE(review): presumably a PyTorch device string such as `"cpu"`; confirm with producers.
    Other(String),
}

10impl TorchDevice {
11 pub fn is_cuda(&self) -> bool {
12 matches!(self, TorchDevice::Cuda(_))
13 }
14
15 pub fn cuda_device_index(&self) -> Option<usize> {
16 match self {
17 TorchDevice::Cuda(index) => Some(*index),
18 TorchDevice::Other(_) => None,
19 }
20 }
21}
22
23pub trait TorchTensor: std::fmt::Debug + Send + Sync {
24 fn device(&self) -> TorchDevice;
25 fn data_ptr(&self) -> u64;
26 fn size_bytes(&self) -> usize;
27 fn shape(&self) -> Vec<usize>;
28 fn stride(&self) -> Vec<usize>;
29}