// burn_tch/backend.rs

use super::TchTensor;
use super::element::TchElement;
use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
use burn_tensor::ops::IntTensorOps;
use burn_tensor::{Int, Tensor};

/// The device struct when using the `tch` backend.
///
/// Note that you need to provide the device index when using Cuda.
///
/// # Example
///
/// ```no_run
/// use burn_tch::LibTorchDevice;
///
/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU
/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU
/// let device_cpu = LibTorchDevice::Cpu; // CPU
/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders
/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum LibTorchDevice {
    /// CPU device; this is the default.
    #[default]
    Cpu,

    /// Cuda device, identified by its index in the list of all Cuda devices
    /// found on the system.
    Cuda(usize),

    /// Metal Performance Shaders device.
    Mps,

    /// Vulkan device.
    Vulkan,
}
39
40impl From<LibTorchDevice> for tch::Device {
41    #[allow(
42        unreachable_code,
43        reason = "CUDA branch always panics if the library is missing"
44    )]
45    fn from(device: LibTorchDevice) -> Self {
46        match device {
47            LibTorchDevice::Cpu => tch::Device::Cpu,
48            LibTorchDevice::Cuda(_num) => {
49                include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
50                tch::Device::Cuda(_num)
51            }
52            LibTorchDevice::Mps => tch::Device::Mps,
53            LibTorchDevice::Vulkan => tch::Device::Vulkan,
54        }
55    }
56}
57
58impl From<tch::Device> for LibTorchDevice {
59    fn from(device: tch::Device) -> Self {
60        match device {
61            tch::Device::Cpu => LibTorchDevice::Cpu,
62            tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
63            tch::Device::Mps => LibTorchDevice::Mps,
64            tch::Device::Vulkan => LibTorchDevice::Vulkan,
65        }
66    }
67}
68
69impl burn_common::device::Device for LibTorchDevice {
70    fn from_id(device_id: DeviceId) -> Self {
71        match device_id.type_id {
72            0 => Self::Cuda(device_id.index_id as usize),
73            1 => Self::Mps,
74            2 => Self::Cpu,
75            3 => Self::Vulkan,
76            _ => LibTorchDevice::Cpu,
77        }
78    }
79
80    fn to_id(&self) -> DeviceId {
81        match self {
82            LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u32),
83            LibTorchDevice::Mps => DeviceId::new(1, 0),
84            LibTorchDevice::Cpu => DeviceId::new(2, 0),
85            LibTorchDevice::Vulkan => DeviceId::new(3, 0),
86        }
87    }
88
89    fn device_count(_type_id: u16) -> usize {
90        // TODO: Somehow find the info using the tch API.
91        1
92    }
93}
94
95impl DeviceOps for LibTorchDevice {}
96
/// Tensor backend executing operations through `LibTorch` via the [tch] crate.
///
/// Compatible with a wide range of hardware, from CPUs to GPUs, but requires
/// `LibTorch` to be installed correctly. The CPU distribution can be fetched
/// automatically, and so can the CUDA one when the `TORCH_CUDA_VERSION`
/// environment variable is set. For more complex configurations, check out the
/// manual installation for
/// [burn-tch](https://github.com/tracel-ai/burn/tree/main/crates/burn-tch).
///
/// Refer to the [tch] crate for more information.
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32> {
    // NOTE(review): this field stores an actual `E` value only to tie the
    // element type to the backend; `std::marker::PhantomData<E>` would express
    // that without requiring `E: Default` for `LibTorch::default()`. Left
    // unchanged — crate-internal constructors may build the struct literally.
    _e: E,
}
110
111impl<E: TchElement> Backend for LibTorch<E> {
112    type Device = LibTorchDevice;
113
114    type FloatTensorPrimitive = TchTensor;
115    type FloatElem = E;
116
117    type IntTensorPrimitive = TchTensor;
118    type IntElem = i64;
119
120    type BoolTensorPrimitive = TchTensor;
121    type BoolElem = bool;
122
123    type QuantizedTensorPrimitive = TchTensor;
124
125    fn seed(_device: &Self::Device, seed: u64) {
126        tch::manual_seed(seed as i64);
127    }
128
129    fn ad_enabled() -> bool {
130        false
131    }
132
133    fn name(device: &Self::Device) -> String {
134        match device {
135            LibTorchDevice::Cpu => "libtorch<cpu>",
136            LibTorchDevice::Cuda(_) => "libtorch<cuda>",
137            LibTorchDevice::Mps => "libtorch<metal>",
138            LibTorchDevice::Vulkan => "libtorch<vulkan>",
139        }
140        .to_string()
141    }
142
143    fn sync(device: &Self::Device) {
144        match device {
145            LibTorchDevice::Cpu => (),
146            LibTorchDevice::Cuda(index) => {
147                tch::Cuda::synchronize(*index as i64);
148            }
149            _ => {
150                // When there is no explicit way to synchronize, we write and read one value to sync
151                Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
152                    [1].into(),
153                    device,
154                    E::dtype().into(),
155                ))
156                .into_data();
157            }
158        }
159    }
160}