1use crate::{QuantElement, TchQTensor};
2
3use super::element::TchElement;
4use super::TchTensor;
5use burn_tensor::backend::{Backend, DeviceId, DeviceOps};
6use burn_tensor::ops::IntTensorOps;
7use burn_tensor::{Int, Tensor};
8
/// A device on which the LibTorch (`tch`) backend can place tensors.
///
/// Every variant converts one-to-one to and from a `tch::Device` variant of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LibTorchDevice {
    /// The host CPU.
    Cpu,

    /// A CUDA GPU, selected by its device index.
    Cuda(usize),

    /// The Apple Metal Performance Shaders device (`tch::Device::Mps`).
    Mps,

    /// The Vulkan device (`tch::Device::Vulkan`).
    Vulkan,
}
39
40impl From<LibTorchDevice> for tch::Device {
41 fn from(device: LibTorchDevice) -> Self {
42 match device {
43 LibTorchDevice::Cpu => tch::Device::Cpu,
44 LibTorchDevice::Cuda(num) => tch::Device::Cuda(num),
45 LibTorchDevice::Mps => tch::Device::Mps,
46 LibTorchDevice::Vulkan => tch::Device::Vulkan,
47 }
48 }
49}
50
51impl From<tch::Device> for LibTorchDevice {
52 fn from(device: tch::Device) -> Self {
53 match device {
54 tch::Device::Cpu => LibTorchDevice::Cpu,
55 tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),
56 tch::Device::Mps => LibTorchDevice::Mps,
57 tch::Device::Vulkan => LibTorchDevice::Vulkan,
58 }
59 }
60}
61
62impl DeviceOps for LibTorchDevice {
63 fn id(&self) -> burn_tensor::backend::DeviceId {
64 match self {
65 LibTorchDevice::Cpu => DeviceId::new(0, 0),
66 LibTorchDevice::Cuda(index) => DeviceId::new(1, *index as u32),
67 LibTorchDevice::Mps => DeviceId::new(2, 0),
68 LibTorchDevice::Vulkan => DeviceId::new(3, 0),
69 }
70 }
71}
72
73impl Default for LibTorchDevice {
74 fn default() -> Self {
75 Self::Cpu
76 }
77}
78
/// Tensor backend that executes operations through LibTorch via the `tch` crate.
///
/// # Type parameters
///
/// * `E` - the float element type exposed as `Backend::FloatElem`; defaults to `f32`.
/// * `Q` - the quantized encoding type exposed as `Backend::QuantizedEncoding`; defaults to `i8`.
///
/// The type parameters only select behaviour at the type level (the fields are never read),
/// so they are held in [`PhantomData`](core::marker::PhantomData) markers. This keeps the
/// backend type zero-sized instead of carrying an unused `E` and `Q` value.
#[derive(Clone, Copy, Default, Debug)]
pub struct LibTorch<E = f32, Q = i8> {
    _e: core::marker::PhantomData<E>,
    _q: core::marker::PhantomData<Q>,
}
93
94impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
95 type Device = LibTorchDevice;
96
97 type FloatTensorPrimitive = TchTensor;
98 type FloatElem = E;
99
100 type IntTensorPrimitive = TchTensor;
101 type IntElem = i64;
102
103 type BoolTensorPrimitive = TchTensor;
104 type BoolElem = bool;
105
106 type QuantizedTensorPrimitive = TchQTensor;
107 type QuantizedEncoding = Q;
108
109 fn seed(seed: u64) {
110 tch::manual_seed(seed as i64);
111 }
112
113 fn ad_enabled() -> bool {
114 false
115 }
116
117 fn name() -> String {
118 "tch".to_string()
119 }
120
121 fn sync(device: &Self::Device) {
122 match device {
123 LibTorchDevice::Cpu => (),
124 LibTorchDevice::Cuda(index) => {
125 tch::Cuda::synchronize(*index as i64);
126 }
127 _ => {
128 Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
130 [1].into(),
131 device,
132 ))
133 .into_data();
134 }
135 }
136 }
137}