Skip to main content

dynamo_memory/
torch.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
/// The device a torch tensor lives on.
///
/// CUDA devices are singled out because they carry an ordinal (see
/// `TorchDevice::cuda_device_index`). Every other device kind is kept as an
/// opaque string.
///
/// `Hash` is derived alongside `Eq` so a device can be used as a key in
/// hashed collections (e.g. grouping tensors per device).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TorchDevice {
    /// A CUDA device, identified by its device index.
    Cuda(usize),
    /// Any non-CUDA device, described by a string.
    /// NOTE(review): presumably the torch device type (e.g. "cpu") — confirm against producers.
    Other(String),
}
9
10impl TorchDevice {
11    pub fn is_cuda(&self) -> bool {
12        matches!(self, TorchDevice::Cuda(_))
13    }
14
15    pub fn cuda_device_index(&self) -> Option<usize> {
16        match self {
17            TorchDevice::Cuda(index) => Some(*index),
18            TorchDevice::Other(_) => None,
19        }
20    }
21}
22
23pub trait TorchTensor: std::fmt::Debug + Send + Sync {
24    fn device(&self) -> TorchDevice;
25    fn data_ptr(&self) -> u64;
26    fn size_bytes(&self) -> usize;
27    fn shape(&self) -> Vec<usize>;
28    fn stride(&self) -> Vec<usize>;
29}