Skip to main content

burn_tch/
backend.rs

1use std::marker::PhantomData;
2
3use crate::IntoKind;
4
5use super::TchTensor;
6use super::element::TchElement;
7use burn_backend::backend::{Backend, BackendTypes, DeviceId, DeviceOps, ExecutionError};
8use burn_backend::ops::IntTensorOps;
9
10#[derive(Clone, Copy, Debug, PartialEq, Eq)]
11/// The device struct when using the `tch` backend.
12///
13/// Note that you need to provide the device index when using Cuda.
14///
15/// # Example
16///
17/// ```no_run
18/// use burn_tch::LibTorchDevice;
19///
20/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU
21/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU
22/// let device_cpu = LibTorchDevice::Cpu; // CPU
23/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders
24/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan
25/// ```
26#[derive(Default)]
27pub enum LibTorchDevice {
28    /// CPU device.
29    #[default]
30    Cpu,
31
32    /// Cuda device with the given index. The index is the index of the Cuda device in the list of
33    /// all Cuda devices found on the system.
34    Cuda(usize),
35
36    /// Metal Performance Shaders device.
37    Mps,
38
39    /// Vulkan device.
40    Vulkan,
41}
42
43impl From<LibTorchDevice> for tch::Device {
44    #[allow(
45        unreachable_code,
46        reason = "CUDA branch always panics if the library is missing"
47    )]
48    fn from(device: LibTorchDevice) -> Self {
49        match device {
50            LibTorchDevice::Cpu => tch::Device::Cpu,
51            LibTorchDevice::Cuda(_num) => {
52                include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
53                tch::Device::Cuda(_num)
54            }
55            LibTorchDevice::Mps => tch::Device::Mps,
56            LibTorchDevice::Vulkan => tch::Device::Vulkan,
57        }
58    }
59}
60
61impl From<tch::Device> for LibTorchDevice {
62    fn from(device: tch::Device) -> Self {
63        match device {
64            tch::Device::Cpu => LibTorchDevice::Cpu,
65            tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
66            tch::Device::Mps => LibTorchDevice::Mps,
67            tch::Device::Vulkan => LibTorchDevice::Vulkan,
68        }
69    }
70}
71
72impl burn_backend::Device for LibTorchDevice {
73    fn from_id(device_id: DeviceId) -> Self {
74        match device_id.type_id {
75            0 => Self::Cuda(device_id.index_id as usize),
76            1 => Self::Mps,
77            2 => Self::Cpu,
78            3 => Self::Vulkan,
79            _ => LibTorchDevice::Cpu,
80        }
81    }
82
83    fn to_id(&self) -> DeviceId {
84        match self {
85            LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u16),
86            LibTorchDevice::Mps => DeviceId::new(1, 0),
87            LibTorchDevice::Cpu => DeviceId::new(2, 0),
88            LibTorchDevice::Vulkan => DeviceId::new(3, 0),
89        }
90    }
91}
92
93impl DeviceOps for LibTorchDevice {}
94
95/// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations.
96///
97/// This backend is compatible with a wide range of hardwares ranging from CPUs to GPUs, but
98/// requires `LibTorch` to be installed correctly. The CPU version can be downloaded
99/// automatically and the CUDA version as well by setting the `TORCH_CUDA_VERSION` environment
100/// variable. For more complex configurations, check out the manual installation for
101/// [burn-tch](https://github.com/tracel-ai/burn/tree/main/crates/burn-tch).
102///
103/// Refer to the [tch] crate for more information.
104#[derive(Clone, Copy, Default, Debug)]
105pub struct LibTorch<E = f32> {
106    _e: PhantomData<E>,
107}
108
109impl<E: TchElement> BackendTypes for LibTorch<E> {
110    type Device = LibTorchDevice;
111
112    type FloatTensorPrimitive = TchTensor;
113    type FloatElem = E;
114
115    type IntTensorPrimitive = TchTensor;
116    type IntElem = i64;
117
118    type BoolTensorPrimitive = TchTensor;
119    type BoolElem = bool;
120
121    type QuantizedTensorPrimitive = TchTensor;
122}
123
124impl<E: TchElement> Backend for LibTorch<E> {
125    fn seed(_device: &Self::Device, seed: u64) {
126        tch::manual_seed(seed as i64);
127    }
128
129    fn ad_enabled(_device: &Self::Device) -> bool {
130        false
131    }
132
133    fn name(device: &Self::Device) -> String {
134        match device {
135            LibTorchDevice::Cpu => "libtorch<cpu>",
136            LibTorchDevice::Cuda(_) => "libtorch<cuda>",
137            LibTorchDevice::Mps => "libtorch<metal>",
138            LibTorchDevice::Vulkan => "libtorch<vulkan>",
139        }
140        .to_string()
141    }
142
143    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
144        match device {
145            LibTorchDevice::Cpu => (),
146            LibTorchDevice::Cuda(index) => {
147                tch::Cuda::synchronize(*index as i64);
148            }
149            _ => {
150                // When there is no explicit way to synchronize, we write and read one value to sync
151                burn_backend::read_sync(Self::int_into_data(Self::int_zeros(
152                    [1].into(),
153                    device,
154                    <Self::IntElem as burn_backend::Element>::dtype().into(),
155                )))
156                .unwrap();
157            }
158        };
159
160        Ok(())
161    }
162
163    fn dtype_usage(
164        _device: &Self::Device,
165        dtype: burn_backend::DType,
166    ) -> burn_backend::DTypeUsageSet {
167        if dtype.try_into_kind().is_ok() {
168            burn_backend::DTypeUsage::general()
169        } else {
170            burn_backend::DTypeUsageSet::empty()
171        }
172    }
173
174    fn device_count(_: u16) -> usize {
175        // tch only supports one device for each backend
176        1
177    }
178}