1use super::TchTensor;
2use super::element::TchElement;
3use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
4use burn_tensor::ops::IntTensorOps;
5use burn_tensor::{Int, Tensor};
6
/// A device that the LibTorch backend can place tensors on.
///
/// The default device is [`LibTorchDevice::Cpu`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LibTorchDevice {
    /// The host CPU.
    #[default]
    Cpu,

    /// A CUDA GPU, identified by its device ordinal (index).
    Cuda(usize),

    /// Apple's Metal Performance Shaders (MPS) device.
    Mps,

    /// A Vulkan device.
    Vulkan,
}
39
/// Converts a [`LibTorchDevice`] into the corresponding `tch` device.
///
/// # Panics
///
/// Converting a `Cuda` device may panic. The CUDA arm splices in a check generated at build
/// time (`tch_gpu_check.rs` in `OUT_DIR`). According to the `reason` on the `allow` attribute
/// below, that check always panics when the CUDA library is missing.
impl From<LibTorchDevice> for tch::Device {
    #[allow(
        unreachable_code,
        reason = "CUDA branch always panics if the library is missing"
    )]
    fn from(device: LibTorchDevice) -> Self {
        match device {
            LibTorchDevice::Cpu => tch::Device::Cpu,
            LibTorchDevice::Cuda(_num) => {
                // Build-time generated code (presumably written by this crate's build script).
                // When it expands to a panic, the following line is unreachable. That is why
                // the `unreachable_code` lint is allowed on this function.
                include!(concat!(env!("OUT_DIR"), "/tch_gpu_check.rs"));
                tch::Device::Cuda(_num)
            }
            LibTorchDevice::Mps => tch::Device::Mps,
            LibTorchDevice::Vulkan => tch::Device::Vulkan,
        }
    }
}
57
58impl From<tch::Device> for LibTorchDevice {
59 fn from(device: tch::Device) -> Self {
60 match device {
61 tch::Device::Cpu => LibTorchDevice::Cpu,
62 tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
63 tch::Device::Mps => LibTorchDevice::Mps,
64 tch::Device::Vulkan => LibTorchDevice::Vulkan,
65 }
66 }
67}
68
69impl burn_common::device::Device for LibTorchDevice {
70 fn from_id(device_id: DeviceId) -> Self {
71 match device_id.type_id {
72 0 => Self::Cuda(device_id.index_id as usize),
73 1 => Self::Mps,
74 2 => Self::Cpu,
75 3 => Self::Vulkan,
76 _ => LibTorchDevice::Cpu,
77 }
78 }
79
80 fn to_id(&self) -> DeviceId {
81 match self {
82 LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u32),
83 LibTorchDevice::Mps => DeviceId::new(1, 0),
84 LibTorchDevice::Cpu => DeviceId::new(2, 0),
85 LibTorchDevice::Vulkan => DeviceId::new(3, 0),
86 }
87 }
88
89 fn device_count(_type_id: u16) -> usize {
90 1
92 }
93}
94
/// Opts [`LibTorchDevice`] into [`DeviceOps`]. No trait methods are overridden here.
impl DeviceOps for LibTorchDevice {}
96
/// Tensor backend that runs on LibTorch through the `tch` bindings.
///
/// `E` is the float element type and defaults to `f32`.
#[derive(Debug, Default, Clone, Copy)]
pub struct LibTorch<E = f32> {
    // Never read. It only ties the element type `E` to the backend type.
    _e: E,
}
110
111impl<E: TchElement> Backend for LibTorch<E> {
112 type Device = LibTorchDevice;
113
114 type FloatTensorPrimitive = TchTensor;
115 type FloatElem = E;
116
117 type IntTensorPrimitive = TchTensor;
118 type IntElem = i64;
119
120 type BoolTensorPrimitive = TchTensor;
121 type BoolElem = bool;
122
123 type QuantizedTensorPrimitive = TchTensor;
124
125 fn seed(_device: &Self::Device, seed: u64) {
126 tch::manual_seed(seed as i64);
127 }
128
129 fn ad_enabled() -> bool {
130 false
131 }
132
133 fn name(device: &Self::Device) -> String {
134 match device {
135 LibTorchDevice::Cpu => "libtorch<cpu>",
136 LibTorchDevice::Cuda(_) => "libtorch<cuda>",
137 LibTorchDevice::Mps => "libtorch<metal>",
138 LibTorchDevice::Vulkan => "libtorch<vulkan>",
139 }
140 .to_string()
141 }
142
143 fn sync(device: &Self::Device) {
144 match device {
145 LibTorchDevice::Cpu => (),
146 LibTorchDevice::Cuda(index) => {
147 tch::Cuda::synchronize(*index as i64);
148 }
149 _ => {
150 Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
152 [1].into(),
153 device,
154 E::dtype().into(),
155 ))
156 .into_data();
157 }
158 }
159 }
160}