1use std::marker::PhantomData;
2
3use crate::IntoKind;
4
5use super::TchTensor;
6use super::element::TchElement;
7use burn_backend::backend::{Backend, BackendTypes, DeviceId, DeviceOps, ExecutionError};
8use burn_backend::ops::IntTensorOps;
9
/// The device to use when running tensor operations with the LibTorch (`tch`) backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum LibTorchDevice {
    /// CPU device. This is the default.
    #[default]
    Cpu,

    /// CUDA device at the given index.
    Cuda(usize),

    /// Metal Performance Shaders device (Apple GPUs).
    Mps,

    /// Vulkan device.
    Vulkan,
}
42
/// Converts the backend device into the corresponding `tch::Device`.
impl From<LibTorchDevice> for tch::Device {
    #[allow(
        unreachable_code,
        reason = "CUDA branch always panics if the library is missing"
    )]
    fn from(device: LibTorchDevice) -> Self {
        match device {
            LibTorchDevice::Cpu => tch::Device::Cpu,
            LibTorchDevice::Cuda(_num) => {
                // Build-script-generated guard (written to OUT_DIR). Per the
                // `allow` reason above, it panics when the CUDA library is
                // missing, making the line below unreachable in that case.
                include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
                tch::Device::Cuda(_num)
            }
            LibTorchDevice::Mps => tch::Device::Mps,
            LibTorchDevice::Vulkan => tch::Device::Vulkan,
        }
    }
}
60
61impl From<tch::Device> for LibTorchDevice {
62 fn from(device: tch::Device) -> Self {
63 match device {
64 tch::Device::Cpu => LibTorchDevice::Cpu,
65 tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
66 tch::Device::Mps => LibTorchDevice::Mps,
67 tch::Device::Vulkan => LibTorchDevice::Vulkan,
68 }
69 }
70}
71
72impl burn_backend::Device for LibTorchDevice {
73 fn from_id(device_id: DeviceId) -> Self {
74 match device_id.type_id {
75 0 => Self::Cuda(device_id.index_id as usize),
76 1 => Self::Mps,
77 2 => Self::Cpu,
78 3 => Self::Vulkan,
79 _ => LibTorchDevice::Cpu,
80 }
81 }
82
83 fn to_id(&self) -> DeviceId {
84 match self {
85 LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u16),
86 LibTorchDevice::Mps => DeviceId::new(1, 0),
87 LibTorchDevice::Cpu => DeviceId::new(2, 0),
88 LibTorchDevice::Vulkan => DeviceId::new(3, 0),
89 }
90 }
91}
92
// Marker impl: `LibTorchDevice` uses the default `DeviceOps` behavior.
impl DeviceOps for LibTorchDevice {}
94
/// Tensor backend that executes operations through LibTorch via the `tch` crate.
///
/// `E` is the float element type used for float tensors (defaults to `f32`).
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32> {
    // Zero-sized marker tying the backend to its float element type `E`;
    // no runtime data is stored.
    _e: PhantomData<E>,
}
108
/// Associated type bindings for the LibTorch backend.
impl<E: TchElement> BackendTypes for LibTorch<E> {
    type Device = LibTorchDevice;

    // All tensor kinds share the same primitive: `TchTensor` wraps the
    // underlying LibTorch tensor regardless of element type.
    type FloatTensorPrimitive = TchTensor;
    type FloatElem = E;

    type IntTensorPrimitive = TchTensor;
    // Integer tensors always use 64-bit elements.
    type IntElem = i64;

    type BoolTensorPrimitive = TchTensor;
    type BoolElem = bool;

    type QuantizedTensorPrimitive = TchTensor;
}
123
impl<E: TchElement> Backend for LibTorch<E> {
    /// Seeds LibTorch's random number generator.
    ///
    /// The device argument is ignored: `tch::manual_seed` is global.
    /// Note the `u64 -> i64` cast wraps for seeds above `i64::MAX`.
    fn seed(_device: &Self::Device, seed: u64) {
        tch::manual_seed(seed as i64);
    }

    /// This backend never records autodiff operations.
    fn ad_enabled(_device: &Self::Device) -> bool {
        false
    }

    /// Human-readable backend name including the device kind.
    fn name(device: &Self::Device) -> String {
        match device {
            LibTorchDevice::Cpu => "libtorch<cpu>",
            LibTorchDevice::Cuda(_) => "libtorch<cuda>",
            LibTorchDevice::Mps => "libtorch<metal>",
            LibTorchDevice::Vulkan => "libtorch<vulkan>",
        }
        .to_string()
    }

    /// Blocks until all queued work on `device` has completed.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        match device {
            // CPU work is already synchronous; nothing to wait for.
            LibTorchDevice::Cpu => (),
            LibTorchDevice::Cuda(index) => {
                tch::Cuda::synchronize(*index as i64);
            }
            _ => {
                // No direct synchronize API for these devices: force a
                // blocking read of a 1-element zero tensor, which completes
                // only once prior work on the device has finished.
                burn_backend::read_sync(Self::int_into_data(Self::int_zeros(
                    [1].into(),
                    device,
                    <Self::IntElem as burn_backend::Element>::dtype().into(),
                )))
                .unwrap();
            }
        };

        Ok(())
    }

    /// A dtype is generally usable iff it converts into a tensor kind;
    /// anything else is reported as unsupported.
    fn dtype_usage(
        _device: &Self::Device,
        dtype: burn_backend::DType,
    ) -> burn_backend::DTypeUsageSet {
        if dtype.try_into_kind().is_ok() {
            burn_backend::DTypeUsage::general()
        } else {
            burn_backend::DTypeUsageSet::empty()
        }
    }

    // NOTE(review): always reports a single device for every device type,
    // including CUDA — confirm this is intended rather than querying
    // the actual device count.
    fn device_count(_: u16) -> usize {
        1
    }
}