1use crate::{LibTorch, TchTensor, element::TchElement};
2use burn_backend::{
3 TensorMetadata,
4 ops::{
5 AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
6 DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
7 MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, attention::attention_fallback,
8 },
9 tensor::{FloatTensor, IntTensor},
10};
11
12impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
13 fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
14 if matches!(weights.tensor.device(), tch::Device::Mps) {
20 let cpu_weights = weights.tensor.to(tch::Device::Cpu);
21 let cpu_indices = indices.tensor.to(tch::Device::Cpu);
22 let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)
23 .to(tch::Device::Mps);
24 return TchTensor::new(result);
25 }
26
27 let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
28 TchTensor::new(tensor)
29 }
30
31 fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
32 let [n_embedding, _d_model] = weights.shape().dims();
33
34 if matches!(output.tensor.device(), tch::Device::Mps) {
37 let cpu_output = output.tensor.to(tch::Device::Cpu);
38 let cpu_indices = indices.tensor.to(tch::Device::Cpu);
39 let result = tch::Tensor::embedding_backward(
40 &cpu_output,
41 &cpu_indices,
42 n_embedding as i64,
43 -1,
44 false,
45 false,
46 )
47 .to(tch::Device::Mps);
48 return TchTensor::new(result);
49 }
50
51 let tensor = tch::Tensor::embedding_backward(
52 &output.tensor,
53 &indices.tensor,
54 n_embedding as i64,
55 -1,
56 false,
57 false,
58 );
59
60 TchTensor::new(tensor)
61 }
62
63 fn conv1d(
64 x: TchTensor,
65 weight: TchTensor,
66 bias: Option<TchTensor>,
67 options: ConvOptions<1>,
68 ) -> TchTensor {
69 let tensor = tch::Tensor::conv1d(
70 &x.tensor,
71 &weight.tensor,
72 bias.map(|t| t.tensor),
73 options.stride.map(|i| i as i64),
74 options.padding.map(|i| i as i64),
75 options.dilation.map(|i| i as i64),
76 options.groups as i64,
77 );
78
79 TchTensor::new(tensor)
80 }
81
82 fn conv2d(
83 x: TchTensor,
84 weight: TchTensor,
85 bias: Option<TchTensor>,
86 options: ConvOptions<2>,
87 ) -> TchTensor {
88 let tensor = tch::Tensor::conv2d(
89 &x.tensor,
90 &weight.tensor,
91 bias.map(|t| t.tensor),
92 options.stride.map(|i| i as i64),
93 options.padding.map(|i| i as i64),
94 options.dilation.map(|i| i as i64),
95 options.groups as i64,
96 );
97
98 TchTensor::new(tensor)
99 }
100
101 fn conv3d(
102 x: TchTensor,
103 weight: TchTensor,
104 bias: Option<TchTensor>,
105 options: ConvOptions<3>,
106 ) -> TchTensor {
107 let tensor = tch::Tensor::conv3d(
108 &x.tensor,
109 &weight.tensor,
110 bias.map(|t| t.tensor),
111 options.stride.map(|i| i as i64),
112 options.padding.map(|i| i as i64),
113 options.dilation.map(|i| i as i64),
114 options.groups as i64,
115 );
116
117 TchTensor::new(tensor)
118 }
119
120 fn deform_conv2d(
121 _x: TchTensor,
122 _offset: TchTensor,
123 _weight: TchTensor,
124 _mask: Option<TchTensor>,
125 _bias: Option<TchTensor>,
126 _options: DeformConvOptions<2>,
127 ) -> TchTensor {
128 unimplemented!("Torch bindings don't support deform_conv2d");
129 }
130
131 fn deform_conv2d_backward(
132 _x: TchTensor,
133 _offset: TchTensor,
134 _weight: TchTensor,
135 _mask: Option<TchTensor>,
136 _bias: Option<TchTensor>,
137 _out_grad: TchTensor,
138 _options: DeformConvOptions<2>,
139 ) -> DeformConv2dBackward<Self> {
140 unimplemented!("Torch bindings don't support deform_conv2d");
141 }
142
143 fn conv_transpose1d(
144 x: TchTensor,
145 weight: TchTensor,
146 bias: Option<TchTensor>,
147 options: ConvTransposeOptions<1>,
148 ) -> TchTensor {
149 let tensor = tch::Tensor::conv_transpose1d(
150 &x.tensor,
151 &weight.tensor,
152 bias.map(|t| t.tensor),
153 options.stride.map(|i| i as i64),
154 options.padding.map(|i| i as i64),
155 options.padding_out.map(|i| i as i64),
156 options.groups as i64,
157 options.dilation.map(|i| i as i64),
158 );
159
160 TchTensor::new(tensor)
161 }
162
163 fn conv_transpose2d(
164 x: TchTensor,
165 weight: TchTensor,
166 bias: Option<TchTensor>,
167 options: ConvTransposeOptions<2>,
168 ) -> TchTensor {
169 let tensor = tch::Tensor::conv_transpose2d(
170 &x.tensor,
171 &weight.tensor,
172 bias.map(|t| t.tensor),
173 options.stride.map(|i| i as i64),
174 options.padding.map(|i| i as i64),
175 options.padding_out.map(|i| i as i64),
176 options.groups as i64,
177 options.dilation.map(|i| i as i64),
178 );
179
180 TchTensor::new(tensor)
181 }
182
183 fn conv_transpose3d(
184 x: TchTensor,
185 weight: TchTensor,
186 bias: Option<TchTensor>,
187 options: ConvTransposeOptions<3>,
188 ) -> TchTensor {
189 let tensor = tch::Tensor::conv_transpose3d(
190 &x.tensor,
191 &weight.tensor,
192 bias.map(|t| t.tensor),
193 options.stride.map(|i| i as i64),
194 options.padding.map(|i| i as i64),
195 options.padding_out.map(|i| i as i64),
196 options.groups as i64,
197 options.dilation.map(|i| i as i64),
198 );
199
200 TchTensor::new(tensor)
201 }
202
203 fn avg_pool1d(
204 x: TchTensor,
205 kernel_size: usize,
206 stride: usize,
207 padding: usize,
208 count_include_pad: bool,
209 ceil_mode: bool,
210 ) -> TchTensor {
211 let tensor = tch::Tensor::avg_pool1d(
212 &x.tensor,
213 [kernel_size as i64],
214 [stride as i64],
215 [padding as i64],
216 ceil_mode,
217 count_include_pad,
218 );
219
220 TchTensor::new(tensor)
221 }
222 fn avg_pool2d(
223 x: TchTensor,
224 kernel_size: [usize; 2],
225 stride: [usize; 2],
226 padding: [usize; 2],
227 count_include_pad: bool,
228 ceil_mode: bool,
229 ) -> TchTensor {
230 let tensor = tch::Tensor::avg_pool2d(
231 &x.tensor,
232 [kernel_size[0] as i64, kernel_size[1] as i64],
233 [stride[0] as i64, stride[1] as i64],
234 [padding[0] as i64, padding[1] as i64],
235 ceil_mode,
236 count_include_pad,
237 None,
238 );
239
240 TchTensor::new(tensor)
241 }
242
243 fn avg_pool2d_backward(
244 x: TchTensor,
245 grad: TchTensor,
246 kernel_size: [usize; 2],
247 stride: [usize; 2],
248 padding: [usize; 2],
249 count_include_pad: bool,
250 ceil_mode: bool,
251 ) -> TchTensor {
252 let tensor = tch::Tensor::avg_pool2d_backward(
253 &x.tensor,
254 &grad.tensor,
255 [kernel_size[0] as i64, kernel_size[1] as i64],
256 [stride[0] as i64, stride[1] as i64],
257 [padding[0] as i64, padding[1] as i64],
258 ceil_mode,
259 count_include_pad,
260 None,
261 );
262
263 TchTensor::new(tensor)
264 }
265
266 fn max_pool1d(
267 x: TchTensor,
268 kernel_size: usize,
269 stride: usize,
270 padding: usize,
271 dilation: usize,
272 ceil_mode: bool,
273 ) -> TchTensor {
274 let tensor = tch::Tensor::max_pool1d(
275 &x.tensor,
276 kernel_size as i64,
277 stride as i64,
278 padding as i64,
279 dilation as i64,
280 ceil_mode,
281 );
282
283 TchTensor::new(tensor)
284 }
285
286 fn max_pool1d_with_indices(
287 x: TchTensor,
288 kernel_size: usize,
289 stride: usize,
290 padding: usize,
291 dilation: usize,
292 ceil_mode: bool,
293 ) -> MaxPool1dWithIndices<Self> {
294 let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
295 &x.tensor,
296 kernel_size as i64,
297 stride as i64,
298 padding as i64,
299 dilation as i64,
300 ceil_mode,
301 );
302
303 MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
304 }
305
306 fn max_pool2d(
307 x: TchTensor,
308 kernel_size: [usize; 2],
309 stride: [usize; 2],
310 padding: [usize; 2],
311 dilation: [usize; 2],
312 ceil_mode: bool,
313 ) -> TchTensor {
314 let tensor = tch::Tensor::max_pool2d(
315 &x.tensor,
316 [kernel_size[0] as i64, kernel_size[1] as i64],
317 [stride[0] as i64, stride[1] as i64],
318 [padding[0] as i64, padding[1] as i64],
319 [dilation[0] as i64, dilation[1] as i64],
320 ceil_mode,
321 );
322
323 TchTensor::new(tensor)
324 }
325
326 fn max_pool2d_with_indices(
327 x: TchTensor,
328 kernel_size: [usize; 2],
329 stride: [usize; 2],
330 padding: [usize; 2],
331 dilation: [usize; 2],
332 ceil_mode: bool,
333 ) -> MaxPool2dWithIndices<Self> {
334 let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
335 &x.tensor,
336 [kernel_size[0] as i64, kernel_size[1] as i64],
337 [stride[0] as i64, stride[1] as i64],
338 [padding[0] as i64, padding[1] as i64],
339 [dilation[0] as i64, dilation[1] as i64],
340 ceil_mode,
341 );
342
343 MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
344 }
345
346 fn max_pool2d_with_indices_backward(
347 x: TchTensor,
348 kernel_size: [usize; 2],
349 stride: [usize; 2],
350 padding: [usize; 2],
351 dilation: [usize; 2],
352 ceil_mode: bool,
353 output_grad: TchTensor,
354 indices: TchTensor,
355 ) -> MaxPool2dBackward<Self> {
356 let grad = tch::Tensor::max_pool2d_with_indices_backward(
357 &x.tensor,
358 &output_grad.tensor,
359 [kernel_size[0] as i64, kernel_size[1] as i64],
360 [stride[0] as i64, stride[1] as i64],
361 [padding[0] as i64, padding[1] as i64],
362 [dilation[0] as i64, dilation[1] as i64],
363 ceil_mode,
364 &indices.tensor,
365 );
366
367 MaxPool2dBackward::new(TchTensor::new(grad))
368 }
369
370 fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
371 let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
372
373 TchTensor::new(tensor)
374 }
375
376 fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
377 let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
378
379 TchTensor::new(tensor)
380 }
381
382 fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
383 let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);
384
385 TchTensor::new(tensor)
386 }
387
388 fn interpolate(
389 x: TchTensor,
390 output_size: [usize; 2],
391 options: InterpolateOptions,
392 ) -> TchTensor {
393 let output_size = output_size.map(|e| e as i64);
394
395 let align_corners = options.align_corners;
396 let tensor = match options.mode {
397 InterpolateMode::Nearest => {
398 tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
399 }
400 InterpolateMode::Bilinear => {
401 tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)
402 }
403 InterpolateMode::Bicubic => {
404 tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, align_corners, None, None)
405 }
406 InterpolateMode::Lanczos3 => {
407 panic!("lanczos3 interpolation is not supported by PyTorch/tch backend")
408 }
409 };
410
411 TchTensor::new(tensor)
412 }
413
414 fn interpolate_backward(
415 x: TchTensor,
416 grad: TchTensor,
417 output_size: [usize; 2],
418 options: InterpolateOptions,
419 ) -> TchTensor {
420 let output_size = output_size.map(|e| e as i64);
421 let [n, c, h_in, w_in] = x.shape().dims();
422 let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];
423 let align_corners = options.align_corners;
424
425 let tensor = match options.mode {
426 InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
427 &grad.tensor,
428 output_size,
429 input_size,
430 None,
431 None,
432 ),
433 InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
434 &grad.tensor,
435 output_size,
436 input_size,
437 align_corners,
438 None,
439 None,
440 ),
441 InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
442 &grad.tensor,
443 output_size,
444 input_size,
445 align_corners,
446 None,
447 None,
448 ),
449 InterpolateMode::Lanczos3 => {
450 panic!("lanczos3 interpolation backward is not supported by PyTorch/tch backend")
451 }
452 };
453
454 TchTensor::new(tensor)
455 }
456
457 fn attention(
458 query: TchTensor,
459 key: TchTensor,
460 value: TchTensor,
461 mask: Option<TchTensor>,
462 attn_bias: Option<TchTensor>,
463 options: AttentionModuleOptions,
464 ) -> TchTensor {
465 if attn_bias.is_some() {
466 return attention_fallback::<Self>(query, key, value, mask, attn_bias, options);
467 }
468
469 TchTensor::new(tch::Tensor::scaled_dot_product_attention(
470 &query.tensor,
471 &key.tensor,
472 &value.tensor,
473 mask.map(|m| m.tensor),
474 0.,
475 options.is_causal,
476 options.scale,
477 false,
478 ))
479 }
480
481 fn layer_norm(
482 tensor: TchTensor,
483 gamma: TchTensor,
484 beta: Option<TchTensor>,
485 epsilon: f64,
486 ) -> TchTensor {
487 let shape = tensor.shape();
488 let last_dim = shape[shape.num_dims() - 1] as i64;
489
490 let tensor = tensor.tensor.layer_norm(
491 [last_dim],
492 Some(&gamma.tensor),
493 beta.as_ref().map(|b| &b.tensor),
494 epsilon,
495 true,
496 );
497
498 TchTensor::new(tensor)
499 }
500
501 fn has_ctc_loss_backward() -> bool {
502 true
503 }
504
505 fn ctc_loss(
506 log_probs: FloatTensor<Self>,
507 targets: IntTensor<Self>,
508 input_lengths: IntTensor<Self>,
509 target_lengths: IntTensor<Self>,
510 blank: usize,
511 ) -> FloatTensor<Self> {
512 let targets_i64 = targets.tensor.to_kind(tch::Kind::Int64);
514 let input_lengths_i64 = input_lengths.tensor.to_kind(tch::Kind::Int64);
515 let target_lengths_i64 = target_lengths.tensor.to_kind(tch::Kind::Int64);
516
517 let tensor = tch::Tensor::ctc_loss_tensor(
519 &log_probs.tensor,
520 &targets_i64,
521 &input_lengths_i64,
522 &target_lengths_i64,
523 blank as i64,
524 tch::Reduction::None,
525 false,
526 );
527
528 TchTensor::new(tensor)
529 }
530
531 fn ctc_loss_backward(
532 log_probs: FloatTensor<Self>,
533 targets: IntTensor<Self>,
534 input_lengths: IntTensor<Self>,
535 target_lengths: IntTensor<Self>,
536 grad_loss: FloatTensor<Self>,
537 blank: usize,
538 ) -> FloatTensor<Self> {
539 let targets_i64 = targets.tensor.to_kind(tch::Kind::Int64);
540 let input_lengths_i64 = input_lengths.tensor.to_kind(tch::Kind::Int64);
541 let target_lengths_i64 = target_lengths.tensor.to_kind(tch::Kind::Int64);
542
543 let (neg_log_likelihood, log_alpha) = tch::Tensor::internal_ctc_loss_tensor(
549 &log_probs.tensor,
550 &targets_i64,
551 &input_lengths_i64,
552 &target_lengths_i64,
553 blank as i64,
554 false,
555 );
556
557 let grad = tch::Tensor::internal_ctc_loss_backward_tensor(
558 &grad_loss.tensor,
559 &log_probs.tensor,
560 &targets_i64,
561 &input_lengths_i64,
562 &target_lengths_i64,
563 &neg_log_likelihood,
564 &log_alpha,
565 blank as i64,
566 false,
567 );
568
569 TchTensor::new(grad)
570 }
571
572 fn rfft(
573 signal: FloatTensor<Self>,
574 dim: usize,
575 n: Option<usize>,
576 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
577 let complex = signal
578 .tensor
579 .fft_rfft(n.map(|v| v as i64), dim as i64, "backward");
580 let re = TchTensor::new(complex.real().contiguous());
581 let im = TchTensor::new(complex.imag().contiguous());
582 (re, im)
583 }
584
585 fn irfft(
586 spectrum_re: FloatTensor<Self>,
587 spectrum_im: FloatTensor<Self>,
588 dim: usize,
589 n: Option<usize>,
590 ) -> FloatTensor<Self> {
591 let complex = tch::Tensor::complex(&spectrum_re.tensor, &spectrum_im.tensor);
592 TchTensor::new(complex.fft_irfft(n.map(|v| v as i64), dim as i64, "backward"))
593 }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{
        TensorData, Tolerance,
        ops::{FloatTensorOps, IntTensorOps},
        read_sync,
    };

    type B = crate::LibTorch<f32>;

    /// Uniform two-class distribution over three timesteps, target `[1, 1]`, blank `0`.
    ///
    /// The only length-3 alignment that collapses to `1 1` is `1 _ 1`. Its probability is
    /// (1/2)^3, so the unreduced loss must be `-ln(1/8) = 3 ln 2`.
    #[test]
    fn ctc_loss_uniform() {
        let device = Default::default();

        let (time_steps, num_classes) = (3usize, 2usize);
        let uniform_log_prob = (0.5f32).ln();
        let log_probs = B::float_from_data(
            TensorData::new(
                vec![uniform_log_prob; time_steps * num_classes],
                [time_steps, 1, num_classes],
            ),
            &device,
        );

        let targets = B::int_from_data(TensorData::from([[1i64, 1]]), &device);
        let input_lengths = B::int_from_data(TensorData::from([3i64]), &device);
        let target_lengths = B::int_from_data(TensorData::from([2i64]), &device);

        let per_sample_loss =
            <B as ModuleOps<B>>::ctc_loss(log_probs, targets, input_lengths, target_lengths, 0);
        let actual = read_sync(B::float_into_data(per_sample_loss)).unwrap();

        let expected = TensorData::from([3.0f32 * 2.0f32.ln()]);
        actual.assert_approx_eq::<f32>(&expected, Tolerance::rel_abs(1e-3, 1e-3));
    }
}