burn_tch/ops/module.rs

use crate::{LibTorch, TchTensor, element::TchElement};
use burn_backend::{
    TensorMetadata,
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
        InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward,
        MaxPool2dWithIndices, ModuleOps,
    },
};

impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
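    /// Embedding lookup: selects rows of `weights` at the given `indices`.
    ///
    /// On MPS this falls back to a CPU round-trip to dodge the lazy-allocation
    /// bug linked below; everywhere else it calls `tch::Tensor::embedding`
    /// directly (padding_idx = -1, no gradient scaling, dense gradients).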
    fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
        // Workaround for MPS "Placeholder storage has not been allocated" error.
        // See: https://github.com/pytorch/pytorch/issues/123995
        // MPS uses lazy allocation and the embedding operation (which uses index_select)
        // can fail if the tensors haven't been materialized yet.
        // We work around this by performing the embedding on CPU and transferring back to MPS.
        if matches!(weights.tensor.device(), tch::Device::Mps) {
            let cpu_weights = weights.tensor.to(tch::Device::Cpu);
            let cpu_indices = indices.tensor.to(tch::Device::Cpu);
            let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)
                .to(tch::Device::Mps);
            return TchTensor::new(result);
        }

        let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
        TchTensor::new(tensor)
    }

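    /// Gradient of `embedding` with respect to `weights`.
    ///
    /// The vocabulary size `n_embedding` is read from the weight shape so tch
    /// can accumulate the output gradient into a dense weight gradient. The
    /// same MPS CPU round-trip as in the forward pass applies.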
    fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
        let [n_embedding, _d_model] = weights.shape().dims();

        // Workaround for MPS "Placeholder storage has not been allocated" error.
        // See: https://github.com/pytorch/pytorch/issues/123995
        if matches!(output.tensor.device(), tch::Device::Mps) {
            let cpu_output = output.tensor.to(tch::Device::Cpu);
            let cpu_indices = indices.tensor.to(tch::Device::Cpu);
            let result = tch::Tensor::embedding_backward(
                &cpu_output,
                &cpu_indices,
                n_embedding as i64,
                -1,
                false,
                false,
            )
            .to(tch::Device::Mps);
            return TchTensor::new(result);
        }

        let tensor = tch::Tensor::embedding_backward(
            &output.tensor,
            &indices.tensor,
            n_embedding as i64,
            -1,
            false,
            false,
        );

        TchTensor::new(tensor)
    }

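    /// 1D convolution. `conv1d`, `conv2d`, and `conv3d` all follow the same
    /// pattern: map Burn's `usize` stride/padding/dilation arrays to the `i64`
    /// lists tch expects and delegate to the matching `tch::Tensor` kernel.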
    fn conv1d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<1>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv1d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    fn conv2d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<2>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv2d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    fn conv3d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<3>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv3d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

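    /// Deformable convolution is not exposed by the torch bindings, so both
    /// the forward and backward passes panic via `unimplemented!`.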
    fn deform_conv2d(
        _x: TchTensor,
        _offset: TchTensor,
        _weight: TchTensor,
        _mask: Option<TchTensor>,
        _bias: Option<TchTensor>,
        _options: DeformConvOptions<2>,
    ) -> TchTensor {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }

    fn deform_conv2d_backward(
        _x: TchTensor,
        _offset: TchTensor,
        _weight: TchTensor,
        _mask: Option<TchTensor>,
        _bias: Option<TchTensor>,
        _out_grad: TchTensor,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }

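    /// 1D transposed convolution. Note the tch argument order: output padding
    /// (`padding_out`) and `groups` come *before* `dilation`, unlike the
    /// forward convolutions above. `conv_transpose2d` and `conv_transpose3d`
    /// mirror this.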
    fn conv_transpose1d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<1>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose1d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64),
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    fn conv_transpose2d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<2>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose2d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64),
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    fn conv_transpose3d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<3>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose3d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64),
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

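    /// 1D average pooling. Careful with the trailing booleans: Burn passes
    /// `count_include_pad` before `ceil_mode`, while tch's `avg_pool*` take
    /// `ceil_mode` first, so the two are swapped at the call sites below.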
    fn avg_pool1d(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool1d(
            &x.tensor,
            [kernel_size as i64],
            [stride as i64],
            [padding as i64],
            ceil_mode,
            count_include_pad,
        );

        TchTensor::new(tensor)
    }

    fn avg_pool2d(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool2d(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            ceil_mode,
            count_include_pad,
            None, // divisor_override: keep the default kernel-area divisor
        );

        TchTensor::new(tensor)
    }

    fn avg_pool2d_backward(
        x: TchTensor,
        grad: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool2d_backward(
            &x.tensor,
            &grad.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            ceil_mode,
            count_include_pad,
            None, // divisor_override: keep the default kernel-area divisor
        );

        TchTensor::new(tensor)
    }

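    /// 1D max pooling without indices; the `_with_indices` variants below
    /// also return the argmax locations that the backward pass needs.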
    fn max_pool1d(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::max_pool1d(
            &x.tensor,
            kernel_size as i64,
            stride as i64,
            padding as i64,
            dilation as i64,
            ceil_mode,
        );

        TchTensor::new(tensor)
    }

    fn max_pool1d_with_indices(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<LibTorch<E>> {
        let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
            &x.tensor,
            kernel_size as i64,
            stride as i64,
            padding as i64,
            dilation as i64,
            ceil_mode,
        );

        MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
    }

    fn max_pool2d(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::max_pool2d(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            ceil_mode,
        );

        TchTensor::new(tensor)
    }

    fn max_pool2d_with_indices(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<LibTorch<E>> {
        let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            ceil_mode,
        );

        MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
    }

    fn max_pool2d_with_indices_backward(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: TchTensor,
        indices: TchTensor,
    ) -> MaxPool2dBackward<LibTorch<E>> {
        let grad = tch::Tensor::max_pool2d_with_indices_backward(
            &x.tensor,
            &output_grad.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            ceil_mode,
            &indices.tensor,
        );

        MaxPool2dBackward::new(TchTensor::new(grad))
    }

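    /// Adaptive average pooling to a fixed `output_size`; the backward pass
    /// delegates to torch's internal `_adaptive_avg_pool2d_backward` kernel.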
    fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
        let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));

        TchTensor::new(tensor)
    }

    fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
        let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);

        TchTensor::new(tensor)
    }

    fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
        let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);

        TchTensor::new(tensor)
    }

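    /// 2D interpolation to `output_size`. Bilinear and bicubic are invoked
    /// with `align_corners = true`; nearest takes no such flag. Scale factors
    /// are left as `None` because the target size is given explicitly.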
    fn interpolate(
        x: TchTensor,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> TchTensor {
        let output_size = output_size.map(|e| e as i64);

        let tensor = match options.mode {
            InterpolateMode::Nearest => {
                tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
            }
            InterpolateMode::Bilinear => {
                tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, true, None, None)
            }
            InterpolateMode::Bicubic => {
                tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, true, None, None)
            }
        };

        TchTensor::new(tensor)
    }

    fn interpolate_backward(
        x: TchTensor,
        grad: TchTensor,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> TchTensor {
        let output_size = output_size.map(|e| e as i64);
        let [n, c, h_in, w_in] = x.shape().dims();
        let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];

        let tensor = match options.mode {
            InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                None,
                None,
            ),
            InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                true,
                None,
                None,
            ),
            InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                true,
                None,
                None,
            ),
        };

        TchTensor::new(tensor)
    }
}
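
// A minimal usage sketch, not part of the original file: it drives `conv2d`
// directly with raw tch tensors to sanity-check the output shape. The
// `ConvOptions::new(stride, padding, dilation, groups)` constructor is an
// assumption about the burn_backend API; in practice this code is reached
// through Burn's high-level module API instead.
#[cfg(test)]
mod conv2d_sketch {
    use super::*;

    #[test]
    fn conv2d_preserves_spatial_size_with_same_padding() {
        // NCHW input: batch 1, 3 channels, 8x8 spatial.
        let x = TchTensor::new(tch::Tensor::randn([1, 3, 8, 8], tch::kind::FLOAT_CPU));
        // OIHW weight: 4 output channels, 3 input channels, 3x3 kernel.
        let weight = TchTensor::new(tch::Tensor::randn([4, 3, 3, 3], tch::kind::FLOAT_CPU));

        // Stride 1, padding 1, dilation 1, groups 1 => "same" output size.
        let out = LibTorch::<f32>::conv2d(
            x,
            weight,
            None,
            ConvOptions::new([1, 1], [1, 1], [1, 1], 1), // assumed constructor
        );

        assert_eq!(out.tensor.size(), vec![1, 4, 8, 8]);
    }
}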
447}