burn_tch/
backend.rs

use crate::{QuantElement, TchQTensor};

use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Int, Tensor};
use std::marker::PhantomData;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
/// The device struct when using the `tch` backend.
///
/// Note that you need to provide the device index when using Cuda.
///
/// # Example
///
/// ```no_run
/// use burn_tch::LibTorchDevice;
///
/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU
/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU
/// let device_cpu = LibTorchDevice::Cpu; // CPU
/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders
/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan
/// ```
pub enum LibTorchDevice {
    /// CPU device.
    Cpu,

    /// Cuda device with the given index. The index is the index of the Cuda device in the list of
    /// all Cuda devices found on the system.
    Cuda(usize),

    /// Metal Performance Shaders device.
    Mps,

    /// Vulkan device.
    Vulkan,
}

impl From<LibTorchDevice> for tch::Device {
    fn from(device: LibTorchDevice) -> Self {
        match device {
            LibTorchDevice::Cpu => tch::Device::Cpu,
            LibTorchDevice::Cuda(num) => tch::Device::Cuda(num),
            LibTorchDevice::Mps => tch::Device::Mps,
            LibTorchDevice::Vulkan => tch::Device::Vulkan,
        }
    }
}

impl From<tch::Device> for LibTorchDevice {
    fn from(device: tch::Device) -> Self {
        match device {
            tch::Device::Cpu => LibTorchDevice::Cpu,
            tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
            tch::Device::Mps => LibTorchDevice::Mps,
            tch::Device::Vulkan => LibTorchDevice::Vulkan,
        }
    }
}

impl DeviceOps for LibTorchDevice {
    fn id(&self) -> burn_tensor::backend::DeviceId {
        match self {
            LibTorchDevice::Cpu => DeviceId::new(0, 0),
            LibTorchDevice::Cuda(index) => DeviceId::new(1, *index as u32),
            LibTorchDevice::Mps => DeviceId::new(2, 0),
            LibTorchDevice::Vulkan => DeviceId::new(3, 0),
        }
    }
}

impl Default for LibTorchDevice {
    fn default() -> Self {
        Self::Cpu
    }
}

/// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations.
///
/// This backend is compatible with a wide range of hardware, from CPUs to GPUs, but requires
/// `LibTorch` to be installed correctly. The CPU version can be downloaded automatically, and
/// the CUDA version as well by setting the `TORCH_CUDA_VERSION` environment variable. For more
/// complex configurations, check out the manual installation instructions for
/// [burn-tch](https://github.com/tracel-ai/burn/tree/main/burn-tch).
///
/// Refer to the [tch] crate for more information.
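///
/// # Example
///
/// A minimal usage sketch, assuming the `burn_tensor::Tensor` API; `LibTorch` defaults to
/// `f32` floats and `i8` quantized values.
///
/// ```no_run
/// use burn_tch::{LibTorch, LibTorchDevice};
/// use burn_tensor::Tensor;
///
/// // Create a 2x3 tensor of ones on the CPU with the default `LibTorch<f32>` backend.
/// let device = LibTorchDevice::Cpu;
/// let tensor = Tensor::<LibTorch, 2>::ones([2, 3], &device);
/// ```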
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32, Q = i8> {
    // Marker types only; the backend itself carries no runtime data.
    _e: PhantomData<E>,
    _q: PhantomData<Q>,
}

impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
    type Device = LibTorchDevice;

    type FloatTensorPrimitive = TchTensor;
    type FloatElem = E;

    type IntTensorPrimitive = TchTensor;
    type IntElem = i64;

    type BoolTensorPrimitive = TchTensor;
    type BoolElem = bool;

    type QuantizedTensorPrimitive = TchQTensor;
    type QuantizedEncoding = Q;

    fn seed(seed: u64) {
        tch::manual_seed(seed as i64);
    }

    fn ad_enabled() -> bool {
        false
    }

    fn name() -> String {
        "tch".to_string()
    }

    fn sync(device: &Self::Device) {
        match device {
            LibTorchDevice::Cpu => (),
            LibTorchDevice::Cuda(index) => {
                tch::Cuda::synchronize(*index as i64);
            }
            _ => {
                // When there is no explicit way to synchronize, we write and read one value to sync
                Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
                    [1].into(),
                    device,
                ))
                .into_data();
            }
        }
    }
}
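
// A minimal sketch of how the conversions and device ids above might be exercised;
// this test module is an illustrative assumption, not part of the original file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn device_converts_to_and_from_tch() {
        // Round-tripping through `tch::Device` should preserve the variant and index.
        let device = LibTorchDevice::Cuda(1);
        let converted: tch::Device = device.into();
        assert_eq!(LibTorchDevice::from(converted), device);
    }

    #[test]
    fn device_ids_are_distinct() {
        // Each device kind uses its own type id, so ids never collide across kinds.
        assert_ne!(LibTorchDevice::Cpu.id(), LibTorchDevice::Cuda(0).id());
    }
}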