Skip to main content

burn_tch/ops/
module.rs

1use crate::{LibTorch, TchTensor, element::TchElement};
2use burn_backend::{
3    TensorMetadata,
4    ops::{
5        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
6        DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
7        MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, attention::attention_fallback,
8    },
9    tensor::{FloatTensor, IntTensor},
10};
11
12impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
13    fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
14        // Workaround for MPS "Placeholder storage has not been allocated" error.
15        // See: https://github.com/pytorch/pytorch/issues/123995
16        // MPS uses lazy allocation and the embedding operation (which uses index_select)
17        // can fail if the tensors haven't been materialized yet.
18        // We work around this by performing the embedding on CPU and transferring back to MPS.
19        if matches!(weights.tensor.device(), tch::Device::Mps) {
20            let cpu_weights = weights.tensor.to(tch::Device::Cpu);
21            let cpu_indices = indices.tensor.to(tch::Device::Cpu);
22            let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)
23                .to(tch::Device::Mps);
24            return TchTensor::new(result);
25        }
26
27        let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
28        TchTensor::new(tensor)
29    }
30
31    fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
32        let [n_embedding, _d_model] = weights.shape().dims();
33
34        // Workaround for MPS "Placeholder storage has not been allocated" error.
35        // See: https://github.com/pytorch/pytorch/issues/123995
36        if matches!(output.tensor.device(), tch::Device::Mps) {
37            let cpu_output = output.tensor.to(tch::Device::Cpu);
38            let cpu_indices = indices.tensor.to(tch::Device::Cpu);
39            let result = tch::Tensor::embedding_backward(
40                &cpu_output,
41                &cpu_indices,
42                n_embedding as i64,
43                -1,
44                false,
45                false,
46            )
47            .to(tch::Device::Mps);
48            return TchTensor::new(result);
49        }
50
51        let tensor = tch::Tensor::embedding_backward(
52            &output.tensor,
53            &indices.tensor,
54            n_embedding as i64,
55            -1,
56            false,
57            false,
58        );
59
60        TchTensor::new(tensor)
61    }
62
63    fn conv1d(
64        x: TchTensor,
65        weight: TchTensor,
66        bias: Option<TchTensor>,
67        options: ConvOptions<1>,
68    ) -> TchTensor {
69        let tensor = tch::Tensor::conv1d(
70            &x.tensor,
71            &weight.tensor,
72            bias.map(|t| t.tensor),
73            options.stride.map(|i| i as i64),
74            options.padding.map(|i| i as i64),
75            options.dilation.map(|i| i as i64),
76            options.groups as i64,
77        );
78
79        TchTensor::new(tensor)
80    }
81
82    fn conv2d(
83        x: TchTensor,
84        weight: TchTensor,
85        bias: Option<TchTensor>,
86        options: ConvOptions<2>,
87    ) -> TchTensor {
88        let tensor = tch::Tensor::conv2d(
89            &x.tensor,
90            &weight.tensor,
91            bias.map(|t| t.tensor),
92            options.stride.map(|i| i as i64),
93            options.padding.map(|i| i as i64),
94            options.dilation.map(|i| i as i64),
95            options.groups as i64,
96        );
97
98        TchTensor::new(tensor)
99    }
100
101    fn conv3d(
102        x: TchTensor,
103        weight: TchTensor,
104        bias: Option<TchTensor>,
105        options: ConvOptions<3>,
106    ) -> TchTensor {
107        let tensor = tch::Tensor::conv3d(
108            &x.tensor,
109            &weight.tensor,
110            bias.map(|t| t.tensor),
111            options.stride.map(|i| i as i64),
112            options.padding.map(|i| i as i64),
113            options.dilation.map(|i| i as i64),
114            options.groups as i64,
115        );
116
117        TchTensor::new(tensor)
118    }
119
120    fn deform_conv2d(
121        _x: TchTensor,
122        _offset: TchTensor,
123        _weight: TchTensor,
124        _mask: Option<TchTensor>,
125        _bias: Option<TchTensor>,
126        _options: DeformConvOptions<2>,
127    ) -> TchTensor {
128        unimplemented!("Torch bindings don't support deform_conv2d");
129    }
130
131    fn deform_conv2d_backward(
132        _x: TchTensor,
133        _offset: TchTensor,
134        _weight: TchTensor,
135        _mask: Option<TchTensor>,
136        _bias: Option<TchTensor>,
137        _out_grad: TchTensor,
138        _options: DeformConvOptions<2>,
139    ) -> DeformConv2dBackward<Self> {
140        unimplemented!("Torch bindings don't support deform_conv2d");
141    }
142
143    fn conv_transpose1d(
144        x: TchTensor,
145        weight: TchTensor,
146        bias: Option<TchTensor>,
147        options: ConvTransposeOptions<1>,
148    ) -> TchTensor {
149        let tensor = tch::Tensor::conv_transpose1d(
150            &x.tensor,
151            &weight.tensor,
152            bias.map(|t| t.tensor),
153            options.stride.map(|i| i as i64),
154            options.padding.map(|i| i as i64),
155            options.padding_out.map(|i| i as i64),
156            options.groups as i64,
157            options.dilation.map(|i| i as i64),
158        );
159
160        TchTensor::new(tensor)
161    }
162
163    fn conv_transpose2d(
164        x: TchTensor,
165        weight: TchTensor,
166        bias: Option<TchTensor>,
167        options: ConvTransposeOptions<2>,
168    ) -> TchTensor {
169        let tensor = tch::Tensor::conv_transpose2d(
170            &x.tensor,
171            &weight.tensor,
172            bias.map(|t| t.tensor),
173            options.stride.map(|i| i as i64),
174            options.padding.map(|i| i as i64),
175            options.padding_out.map(|i| i as i64),
176            options.groups as i64,
177            options.dilation.map(|i| i as i64),
178        );
179
180        TchTensor::new(tensor)
181    }
182
183    fn conv_transpose3d(
184        x: TchTensor,
185        weight: TchTensor,
186        bias: Option<TchTensor>,
187        options: ConvTransposeOptions<3>,
188    ) -> TchTensor {
189        let tensor = tch::Tensor::conv_transpose3d(
190            &x.tensor,
191            &weight.tensor,
192            bias.map(|t| t.tensor),
193            options.stride.map(|i| i as i64),
194            options.padding.map(|i| i as i64),
195            options.padding_out.map(|i| i as i64),
196            options.groups as i64,
197            options.dilation.map(|i| i as i64),
198        );
199
200        TchTensor::new(tensor)
201    }
202
203    fn avg_pool1d(
204        x: TchTensor,
205        kernel_size: usize,
206        stride: usize,
207        padding: usize,
208        count_include_pad: bool,
209        ceil_mode: bool,
210    ) -> TchTensor {
211        let tensor = tch::Tensor::avg_pool1d(
212            &x.tensor,
213            [kernel_size as i64],
214            [stride as i64],
215            [padding as i64],
216            ceil_mode,
217            count_include_pad,
218        );
219
220        TchTensor::new(tensor)
221    }
222    fn avg_pool2d(
223        x: TchTensor,
224        kernel_size: [usize; 2],
225        stride: [usize; 2],
226        padding: [usize; 2],
227        count_include_pad: bool,
228        ceil_mode: bool,
229    ) -> TchTensor {
230        let tensor = tch::Tensor::avg_pool2d(
231            &x.tensor,
232            [kernel_size[0] as i64, kernel_size[1] as i64],
233            [stride[0] as i64, stride[1] as i64],
234            [padding[0] as i64, padding[1] as i64],
235            ceil_mode,
236            count_include_pad,
237            None,
238        );
239
240        TchTensor::new(tensor)
241    }
242
243    fn avg_pool2d_backward(
244        x: TchTensor,
245        grad: TchTensor,
246        kernel_size: [usize; 2],
247        stride: [usize; 2],
248        padding: [usize; 2],
249        count_include_pad: bool,
250        ceil_mode: bool,
251    ) -> TchTensor {
252        let tensor = tch::Tensor::avg_pool2d_backward(
253            &x.tensor,
254            &grad.tensor,
255            [kernel_size[0] as i64, kernel_size[1] as i64],
256            [stride[0] as i64, stride[1] as i64],
257            [padding[0] as i64, padding[1] as i64],
258            ceil_mode,
259            count_include_pad,
260            None,
261        );
262
263        TchTensor::new(tensor)
264    }
265
266    fn max_pool1d(
267        x: TchTensor,
268        kernel_size: usize,
269        stride: usize,
270        padding: usize,
271        dilation: usize,
272        ceil_mode: bool,
273    ) -> TchTensor {
274        let tensor = tch::Tensor::max_pool1d(
275            &x.tensor,
276            kernel_size as i64,
277            stride as i64,
278            padding as i64,
279            dilation as i64,
280            ceil_mode,
281        );
282
283        TchTensor::new(tensor)
284    }
285
286    fn max_pool1d_with_indices(
287        x: TchTensor,
288        kernel_size: usize,
289        stride: usize,
290        padding: usize,
291        dilation: usize,
292        ceil_mode: bool,
293    ) -> MaxPool1dWithIndices<Self> {
294        let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
295            &x.tensor,
296            kernel_size as i64,
297            stride as i64,
298            padding as i64,
299            dilation as i64,
300            ceil_mode,
301        );
302
303        MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
304    }
305
306    fn max_pool2d(
307        x: TchTensor,
308        kernel_size: [usize; 2],
309        stride: [usize; 2],
310        padding: [usize; 2],
311        dilation: [usize; 2],
312        ceil_mode: bool,
313    ) -> TchTensor {
314        let tensor = tch::Tensor::max_pool2d(
315            &x.tensor,
316            [kernel_size[0] as i64, kernel_size[1] as i64],
317            [stride[0] as i64, stride[1] as i64],
318            [padding[0] as i64, padding[1] as i64],
319            [dilation[0] as i64, dilation[1] as i64],
320            ceil_mode,
321        );
322
323        TchTensor::new(tensor)
324    }
325
326    fn max_pool2d_with_indices(
327        x: TchTensor,
328        kernel_size: [usize; 2],
329        stride: [usize; 2],
330        padding: [usize; 2],
331        dilation: [usize; 2],
332        ceil_mode: bool,
333    ) -> MaxPool2dWithIndices<Self> {
334        let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
335            &x.tensor,
336            [kernel_size[0] as i64, kernel_size[1] as i64],
337            [stride[0] as i64, stride[1] as i64],
338            [padding[0] as i64, padding[1] as i64],
339            [dilation[0] as i64, dilation[1] as i64],
340            ceil_mode,
341        );
342
343        MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
344    }
345
346    fn max_pool2d_with_indices_backward(
347        x: TchTensor,
348        kernel_size: [usize; 2],
349        stride: [usize; 2],
350        padding: [usize; 2],
351        dilation: [usize; 2],
352        ceil_mode: bool,
353        output_grad: TchTensor,
354        indices: TchTensor,
355    ) -> MaxPool2dBackward<Self> {
356        let grad = tch::Tensor::max_pool2d_with_indices_backward(
357            &x.tensor,
358            &output_grad.tensor,
359            [kernel_size[0] as i64, kernel_size[1] as i64],
360            [stride[0] as i64, stride[1] as i64],
361            [padding[0] as i64, padding[1] as i64],
362            [dilation[0] as i64, dilation[1] as i64],
363            ceil_mode,
364            &indices.tensor,
365        );
366
367        MaxPool2dBackward::new(TchTensor::new(grad))
368    }
369
370    fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
371        let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
372
373        TchTensor::new(tensor)
374    }
375
376    fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
377        let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
378
379        TchTensor::new(tensor)
380    }
381
382    fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
383        let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);
384
385        TchTensor::new(tensor)
386    }
387
388    fn interpolate(
389        x: TchTensor,
390        output_size: [usize; 2],
391        options: InterpolateOptions,
392    ) -> TchTensor {
393        let output_size = output_size.map(|e| e as i64);
394
395        let align_corners = options.align_corners;
396        let tensor = match options.mode {
397            InterpolateMode::Nearest => {
398                tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
399            }
400            InterpolateMode::Bilinear => {
401                tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)
402            }
403            InterpolateMode::Bicubic => {
404                tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, align_corners, None, None)
405            }
406            InterpolateMode::Lanczos3 => {
407                panic!("lanczos3 interpolation is not supported by PyTorch/tch backend")
408            }
409        };
410
411        TchTensor::new(tensor)
412    }
413
414    fn interpolate_backward(
415        x: TchTensor,
416        grad: TchTensor,
417        output_size: [usize; 2],
418        options: InterpolateOptions,
419    ) -> TchTensor {
420        let output_size = output_size.map(|e| e as i64);
421        let [n, c, h_in, w_in] = x.shape().dims();
422        let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];
423        let align_corners = options.align_corners;
424
425        let tensor = match options.mode {
426            InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
427                &grad.tensor,
428                output_size,
429                input_size,
430                None,
431                None,
432            ),
433            InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
434                &grad.tensor,
435                output_size,
436                input_size,
437                align_corners,
438                None,
439                None,
440            ),
441            InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
442                &grad.tensor,
443                output_size,
444                input_size,
445                align_corners,
446                None,
447                None,
448            ),
449            InterpolateMode::Lanczos3 => {
450                panic!("lanczos3 interpolation backward is not supported by PyTorch/tch backend")
451            }
452        };
453
454        TchTensor::new(tensor)
455    }
456
457    fn attention(
458        query: TchTensor,
459        key: TchTensor,
460        value: TchTensor,
461        mask: Option<TchTensor>,
462        attn_bias: Option<TchTensor>,
463        options: AttentionModuleOptions,
464    ) -> TchTensor {
465        if attn_bias.is_some() {
466            return attention_fallback::<Self>(query, key, value, mask, attn_bias, options);
467        }
468
469        TchTensor::new(tch::Tensor::scaled_dot_product_attention(
470            &query.tensor,
471            &key.tensor,
472            &value.tensor,
473            mask.map(|m| m.tensor),
474            0.,
475            options.is_causal,
476            options.scale,
477            false,
478        ))
479    }
480
481    fn layer_norm(
482        tensor: TchTensor,
483        gamma: TchTensor,
484        beta: Option<TchTensor>,
485        epsilon: f64,
486    ) -> TchTensor {
487        let shape = tensor.shape();
488        let last_dim = shape[shape.num_dims() - 1] as i64;
489
490        let tensor = tensor.tensor.layer_norm(
491            [last_dim],
492            Some(&gamma.tensor),
493            beta.as_ref().map(|b| &b.tensor),
494            epsilon,
495            true,
496        );
497
498        TchTensor::new(tensor)
499    }
500
501    fn has_ctc_loss_backward() -> bool {
502        true
503    }
504
505    fn ctc_loss(
506        log_probs: FloatTensor<Self>,
507        targets: IntTensor<Self>,
508        input_lengths: IntTensor<Self>,
509        target_lengths: IntTensor<Self>,
510        blank: usize,
511    ) -> FloatTensor<Self> {
512        // PyTorch's CTC requires int64 for targets and length tensors.
513        let targets_i64 = targets.tensor.to_kind(tch::Kind::Int64);
514        let input_lengths_i64 = input_lengths.tensor.to_kind(tch::Kind::Int64);
515        let target_lengths_i64 = target_lengths.tensor.to_kind(tch::Kind::Int64);
516
517        // Reduction::None returns per-sample losses [N], matching the trait contract.
518        let tensor = tch::Tensor::ctc_loss_tensor(
519            &log_probs.tensor,
520            &targets_i64,
521            &input_lengths_i64,
522            &target_lengths_i64,
523            blank as i64,
524            tch::Reduction::None,
525            false,
526        );
527
528        TchTensor::new(tensor)
529    }
530
531    fn ctc_loss_backward(
532        log_probs: FloatTensor<Self>,
533        targets: IntTensor<Self>,
534        input_lengths: IntTensor<Self>,
535        target_lengths: IntTensor<Self>,
536        grad_loss: FloatTensor<Self>,
537        blank: usize,
538    ) -> FloatTensor<Self> {
539        let targets_i64 = targets.tensor.to_kind(tch::Kind::Int64);
540        let input_lengths_i64 = input_lengths.tensor.to_kind(tch::Kind::Int64);
541        let target_lengths_i64 = target_lengths.tensor.to_kind(tch::Kind::Int64);
542
543        // Recompute forward to get neg_log_likelihood and log_alpha (LibTorch's
544        // backward needs both). PyTorch caches log_alpha during the autograd
545        // forward; our trait has no caching slot for it, so we redo the alpha
546        // recursion here. This is still a single-call into LibTorch's fused
547        // kernel and avoids the ~40T host-side dispatches.
548        let (neg_log_likelihood, log_alpha) = tch::Tensor::internal_ctc_loss_tensor(
549            &log_probs.tensor,
550            &targets_i64,
551            &input_lengths_i64,
552            &target_lengths_i64,
553            blank as i64,
554            false,
555        );
556
557        let grad = tch::Tensor::internal_ctc_loss_backward_tensor(
558            &grad_loss.tensor,
559            &log_probs.tensor,
560            &targets_i64,
561            &input_lengths_i64,
562            &target_lengths_i64,
563            &neg_log_likelihood,
564            &log_alpha,
565            blank as i64,
566            false,
567        );
568
569        TchTensor::new(grad)
570    }
571
572    fn rfft(
573        signal: FloatTensor<Self>,
574        dim: usize,
575        n: Option<usize>,
576    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
577        let complex = signal
578            .tensor
579            .fft_rfft(n.map(|v| v as i64), dim as i64, "backward");
580        let re = TchTensor::new(complex.real().contiguous());
581        let im = TchTensor::new(complex.imag().contiguous());
582        (re, im)
583    }
584
585    fn irfft(
586        spectrum_re: FloatTensor<Self>,
587        spectrum_im: FloatTensor<Self>,
588        dim: usize,
589        n: Option<usize>,
590    ) -> FloatTensor<Self> {
591        let complex = tch::Tensor::complex(&spectrum_re.tensor, &spectrum_im.tensor);
592        TchTensor::new(complex.fft_irfft(n.map(|v| v as i64), dim as i64, "backward"))
593    }
594}
595
#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::{
        TensorData, Tolerance,
        ops::{FloatTensorOps, IntTensorOps},
        read_sync,
    };

    type B = crate::LibTorch<f32>;

    /// CTC loss on a uniform two-class distribution over three frames.
    ///
    /// Setup: `T = 3`, `N = 1`, `C = 2`, `blank = 0`, target `[1, 1]`. The repeated label
    /// forces a blank between the two `1`s, so `(1, 0, 1)` is the only valid alignment.
    /// Every frame has probability `1/2`, giving a loss of `-ln(1/8) = 3 ln 2`.
    #[test]
    fn ctc_loss_uniform() {
        let device = Default::default();

        let frame_log_prob = (0.5f32).ln();
        let log_probs = B::float_from_data(
            TensorData::new(vec![frame_log_prob; 3 * 2], [3, 1, 2]),
            &device,
        );
        let targets = B::int_from_data(TensorData::from([[1i64, 1]]), &device);
        let input_lengths = B::int_from_data(TensorData::from([3i64]), &device);
        let target_lengths = B::int_from_data(TensorData::from([2i64]), &device);

        let blank = 0;
        let loss = <B as ModuleOps<B>>::ctc_loss(
            log_probs,
            targets,
            input_lengths,
            target_lengths,
            blank,
        );

        let expected = TensorData::from([3.0f32 * 2.0f32.ln()]);
        let actual = read_sync(B::float_into_data(loss)).unwrap();
        actual.assert_approx_eq::<f32>(&expected, Tolerance::rel_abs(1e-3, 1e-3));
    }
}