1use crate::{LibTorch, TchTensor, element::TchElement};
2use burn_backend::{
3 TensorMetadata,
4 ops::{
5 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
6 InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward,
7 MaxPool2dWithIndices, ModuleOps,
8 },
9};
10
/// Neural-network module operations for the LibTorch backend.
///
/// Each method is a thin wrapper that unpacks Burn's options/usize arguments
/// into the positional `i64` arguments the `tch` (libtorch) bindings expect,
/// then wraps the result back into a `TchTensor`. No computation happens here;
/// correctness hinges entirely on matching tch's positional argument order.
impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
    fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
        // MPS fallback: run the lookup on CPU, then move the result back to MPS.
        // NOTE(review): presumably the embedding kernel is unsupported or broken
        // on Apple MPS in the bound libtorch version — confirm before removing.
        if matches!(weights.tensor.device(), tch::Device::Mps) {
            let cpu_weights = weights.tensor.to(tch::Device::Cpu);
            let cpu_indices = indices.tensor.to(tch::Device::Cpu);
            // Trailing args: padding_idx = -1 (disabled),
            // scale_grad_by_freq = false, sparse = false.
            let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)
                .to(tch::Device::Mps);
            return TchTensor::new(result);
        }

        let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
        TchTensor::new(tensor)
    }

    fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
        // Only the vocabulary size is needed from `weights`; the gradient is
        // scattered from `output` back into an `[n_embedding, d_model]` tensor.
        let [n_embedding, _d_model] = weights.shape().dims();

        // Same MPS workaround as in `embedding` above: compute on CPU,
        // transfer the gradient back to MPS.
        if matches!(output.tensor.device(), tch::Device::Mps) {
            let cpu_output = output.tensor.to(tch::Device::Cpu);
            let cpu_indices = indices.tensor.to(tch::Device::Cpu);
            let result = tch::Tensor::embedding_backward(
                &cpu_output,
                &cpu_indices,
                n_embedding as i64,
                -1,    // padding_idx: disabled
                false, // scale_grad_by_freq
                false, // sparse
            )
            .to(tch::Device::Mps);
            return TchTensor::new(result);
        }

        let tensor = tch::Tensor::embedding_backward(
            &output.tensor,
            &indices.tensor,
            n_embedding as i64,
            -1,    // padding_idx: disabled
            false, // scale_grad_by_freq
            false, // sparse
        );

        TchTensor::new(tensor)
    }

    fn conv1d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<1>,
    ) -> TchTensor {
        // Burn stores conv params as usize arrays; tch wants i64 slices.
        let tensor = tch::Tensor::conv1d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    fn conv2d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<2>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv2d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    fn conv3d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<3>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv3d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    // Deformable convolution is not exposed by the tch bindings (it lives in
    // torchvision, not core libtorch), so both directions panic by design.
    fn deform_conv2d(
        _x: TchTensor,
        _offset: TchTensor,
        _weight: TchTensor,
        _mask: Option<TchTensor>,
        _bias: Option<TchTensor>,
        _options: DeformConvOptions<2>,
    ) -> TchTensor {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }

    fn deform_conv2d_backward(
        _x: TchTensor,
        _offset: TchTensor,
        _weight: TchTensor,
        _mask: Option<TchTensor>,
        _bias: Option<TchTensor>,
        _out_grad: TchTensor,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }

    fn conv_transpose1d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<1>,
    ) -> TchTensor {
        // tch's transpose-conv argument order differs from the forward conv:
        // (stride, padding, output_padding, groups, dilation).
        let tensor = tch::Tensor::conv_transpose1d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64),
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    fn conv_transpose2d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<2>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose2d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64),
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    fn conv_transpose3d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<3>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose3d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64),
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    fn avg_pool1d(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> TchTensor {
        // Careful: tch takes (ceil_mode, count_include_pad), the reverse of
        // this method's parameter order — the swap below is intentional.
        let tensor = tch::Tensor::avg_pool1d(
            &x.tensor,
            [kernel_size as i64],
            [stride as i64],
            [padding as i64],
            ceil_mode,
            count_include_pad,
        );

        TchTensor::new(tensor)
    }
    fn avg_pool2d(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> TchTensor {
        // Same (ceil_mode, count_include_pad) swap as avg_pool1d. The trailing
        // `None` is tch's divisor_override (use the default pooling divisor).
        let tensor = tch::Tensor::avg_pool2d(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            ceil_mode,
            count_include_pad,
            None,
        );

        TchTensor::new(tensor)
    }

    fn avg_pool2d_backward(
        x: TchTensor,
        grad: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool2d_backward(
            &x.tensor,
            &grad.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            ceil_mode,
            count_include_pad,
            None, // divisor_override: default
        );

        TchTensor::new(tensor)
    }

    fn max_pool1d(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::max_pool1d(
            &x.tensor,
            kernel_size as i64,
            stride as i64,
            padding as i64,
            dilation as i64,
            ceil_mode,
        );

        TchTensor::new(tensor)
    }

    fn max_pool1d_with_indices(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<LibTorch<E>> {
        // Returns both the pooled values and the argmax indices; the indices
        // are what the backward pass consumes.
        let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
            &x.tensor,
            kernel_size as i64,
            stride as i64,
            padding as i64,
            dilation as i64,
            ceil_mode,
        );

        MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
    }

    fn max_pool2d(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::max_pool2d(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            ceil_mode,
        );

        TchTensor::new(tensor)
    }

    fn max_pool2d_with_indices(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<LibTorch<E>> {
        let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            ceil_mode,
        );

        MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
    }

    fn max_pool2d_with_indices_backward(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: TchTensor,
        indices: TchTensor,
    ) -> MaxPool2dBackward<LibTorch<E>> {
        // Routes each output gradient back to the input position recorded in
        // `indices` by the forward pass.
        let grad = tch::Tensor::max_pool2d_with_indices_backward(
            &x.tensor,
            &output_grad.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            ceil_mode,
            &indices.tensor,
        );

        MaxPool2dBackward::new(TchTensor::new(grad))
    }

    fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
        let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));

        TchTensor::new(tensor)
    }

    fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
        // `internal_` prefix: tch's binding of the private ATen backward op
        // `_adaptive_avg_pool2d_backward` — there is no public equivalent.
        let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);

        TchTensor::new(tensor)
    }

    fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
        let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);

        TchTensor::new(tensor)
    }

    fn interpolate(
        x: TchTensor,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> TchTensor {
        let output_size = output_size.map(|e| e as i64);

        // The `None, None` pairs are tch's optional scales_h / scales_w; when
        // absent, scaling is derived from output_size. The `true` passed for
        // bilinear/bicubic is align_corners — TODO confirm that matches Burn's
        // interpolation semantics for the other backends.
        let tensor = match options.mode {
            InterpolateMode::Nearest => {
                tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
            }
            InterpolateMode::Bilinear => {
                tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, true, None, None)
            }
            InterpolateMode::Bicubic => {
                tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, true, None, None)
            }
        };

        TchTensor::new(tensor)
    }

    fn interpolate_backward(
        x: TchTensor,
        grad: TchTensor,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> TchTensor {
        let output_size = output_size.map(|e| e as i64);
        // The backward ops need the full NCHW input shape to size the result;
        // only the shape of `x` is used, not its values.
        let [n, c, h_in, w_in] = x.shape().dims();
        let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];

        // Must mirror the forward call exactly (same align_corners = true for
        // bilinear/bicubic) or gradients would be inconsistent.
        let tensor = match options.mode {
            InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                None,
                None,
            ),
            InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                true,
                None,
                None,
            ),
            InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                true,
                None,
                None,
            ),
        };

        TchTensor::new(tensor)
    }
}