//! LibTorch (`tch`) implementations of Burn's `ModuleOps` — burn_tch/ops/module.rs.

1use crate::{LibTorch, TchTensor, element::TchElement};
2use burn_tensor::{
3    TensorMetadata,
4    ops::{
5        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
6        InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward,
7        MaxPool2dWithIndices, ModuleOps,
8    },
9};
10
/// Neural-network module operations for the LibTorch backend.
///
/// Every method is a thin delegation to the corresponding `tch::Tensor`
/// function. The bare `false`/`None`/`true` arguments below are positional
/// flags of the underlying libtorch C API (e.g. `ceil_mode`,
/// `divisor_override`, `align_corners`); each is annotated where it appears.
impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
    /// Embedding lookup: gathers rows of `weights` selected by `indices`.
    ///
    /// Extra tch arguments: `padding_idx = -1` (padding disabled),
    /// `scale_grad_by_freq = false`, `sparse = false`.
    fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
        let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);

        TchTensor::new(tensor)
    }

    /// Gradient of the embedding lookup w.r.t. `weights`.
    ///
    /// `weights` itself is only used to recover the vocabulary size
    /// (`n_embedding`); the flags mirror those passed in `embedding`.
    fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
        // Vocabulary size comes from the first dim of the (2-D) weight matrix.
        let [n_embedding, _d_model] = weights.shape().dims();
        let tensor = tch::Tensor::embedding_backward(
            &output.tensor,
            &indices.tensor,
            n_embedding as i64,
            -1,    // padding_idx: disabled
            false, // scale_grad_by_freq
            false, // sparse
        );

        TchTensor::new(tensor)
    }

    /// 1-D convolution; per-dimension options are widened from `usize` to
    /// the `i64` arrays that tch expects.
    fn conv1d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<1>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv1d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    /// 2-D convolution; see `conv1d` for the argument widening pattern.
    fn conv2d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<2>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv2d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    /// 3-D convolution; see `conv1d` for the argument widening pattern.
    fn conv3d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvOptions<3>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv3d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.dilation.map(|i| i as i64),
            options.groups as i64,
        );

        TchTensor::new(tensor)
    }

    /// Deformable convolution is not exposed by the tch bindings; callers
    /// must use a backend that supports it.
    fn deform_conv2d(
        _x: TchTensor,
        _offset: TchTensor,
        _weight: TchTensor,
        _mask: Option<TchTensor>,
        _bias: Option<TchTensor>,
        _options: DeformConvOptions<2>,
    ) -> TchTensor {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }

    /// See `deform_conv2d`: unsupported by the tch bindings.
    fn deform_conv2d_backward(
        _x: TchTensor,
        _offset: TchTensor,
        _weight: TchTensor,
        _mask: Option<TchTensor>,
        _bias: Option<TchTensor>,
        _out_grad: TchTensor,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }

    /// 1-D transposed convolution.
    ///
    /// NOTE: tch's `conv_transpose*` takes `groups` BEFORE `dilation`
    /// (unlike `conv*`), hence the argument order below.
    fn conv_transpose1d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<1>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose1d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64), // output_padding
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    /// 2-D transposed convolution; see `conv_transpose1d` for the
    /// `groups`-before-`dilation` argument-order note.
    fn conv_transpose2d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<2>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose2d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64), // output_padding
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    /// 3-D transposed convolution; see `conv_transpose1d` for the
    /// `groups`-before-`dilation` argument-order note.
    fn conv_transpose3d(
        x: TchTensor,
        weight: TchTensor,
        bias: Option<TchTensor>,
        options: ConvTransposeOptions<3>,
    ) -> TchTensor {
        let tensor = tch::Tensor::conv_transpose3d(
            &x.tensor,
            &weight.tensor,
            bias.map(|t| t.tensor),
            options.stride.map(|i| i as i64),
            options.padding.map(|i| i as i64),
            options.padding_out.map(|i| i as i64), // output_padding
            options.groups as i64,
            options.dilation.map(|i| i as i64),
        );

        TchTensor::new(tensor)
    }

    /// 1-D average pooling.
    ///
    /// The hard-coded `false` is tch's `ceil_mode` flag (floor rounding of
    /// the output size); `count_include_pad` is forwarded from the caller.
    fn avg_pool1d(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool1d(
            &x.tensor,
            [kernel_size as i64],
            [stride as i64],
            [padding as i64],
            false, // ceil_mode
            count_include_pad,
        );

        TchTensor::new(tensor)
    }
    /// 2-D average pooling.
    ///
    /// `false` is `ceil_mode`; the trailing `None` is `divisor_override`
    /// (use the default pooling-window divisor).
    fn avg_pool2d(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool2d(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            false, // ceil_mode
            count_include_pad,
            None, // divisor_override
        );

        TchTensor::new(tensor)
    }

    /// Gradient of 2-D average pooling w.r.t. its input.
    ///
    /// The flags must match those used in the forward `avg_pool2d` call
    /// (`ceil_mode = false`, default `divisor_override`).
    fn avg_pool2d_backward(
        x: TchTensor,
        grad: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> TchTensor {
        let tensor = tch::Tensor::avg_pool2d_backward(
            &x.tensor,
            &grad.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            false, // ceil_mode
            count_include_pad,
            None, // divisor_override
        );

        TchTensor::new(tensor)
    }

    /// 1-D max pooling; the trailing `false` is `ceil_mode`.
    fn max_pool1d(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> TchTensor {
        let tensor = tch::Tensor::max_pool1d(
            &x.tensor,
            kernel_size as i64,
            stride as i64,
            padding as i64,
            dilation as i64,
            false, // ceil_mode
        );

        TchTensor::new(tensor)
    }

    /// 1-D max pooling that also returns the argmax indices (as used by
    /// the unpooling/backward path). `false` is `ceil_mode`.
    fn max_pool1d_with_indices(
        x: TchTensor,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> MaxPool1dWithIndices<LibTorch<E>> {
        let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
            &x.tensor,
            kernel_size as i64,
            stride as i64,
            padding as i64,
            dilation as i64,
            false, // ceil_mode
        );

        MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
    }

    /// 2-D max pooling; the trailing `false` is `ceil_mode`.
    fn max_pool2d(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> TchTensor {
        let tensor = tch::Tensor::max_pool2d(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            false, // ceil_mode
        );

        TchTensor::new(tensor)
    }

    /// 2-D max pooling returning both the pooled values and the argmax
    /// indices consumed by `max_pool2d_with_indices_backward`.
    fn max_pool2d_with_indices(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<LibTorch<E>> {
        let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
            &x.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            false, // ceil_mode
        );

        MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
    }

    /// Gradient of 2-D max pooling, routed through the argmax `indices`
    /// produced by the forward pass. The pooling parameters and
    /// `ceil_mode = false` must match the forward call.
    fn max_pool2d_with_indices_backward(
        x: TchTensor,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: TchTensor,
        indices: TchTensor,
    ) -> MaxPool2dBackward<LibTorch<E>> {
        let grad = tch::Tensor::max_pool2d_with_indices_backward(
            &x.tensor,
            &output_grad.tensor,
            [kernel_size[0] as i64, kernel_size[1] as i64],
            [stride[0] as i64, stride[1] as i64],
            [padding[0] as i64, padding[1] as i64],
            [dilation[0] as i64, dilation[1] as i64],
            false, // ceil_mode
            &indices.tensor,
        );

        MaxPool2dBackward::new(TchTensor::new(grad))
    }

    /// Adaptive 2-D average pooling to the given spatial `output_size`.
    fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
        let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));

        TchTensor::new(tensor)
    }

    /// Gradient of adaptive 2-D average pooling. Uses the internal ATen
    /// backward kernel, which infers the output size from `grad`.
    fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
        let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);

        TchTensor::new(tensor)
    }

    /// Adaptive 1-D average pooling to the given `output_size`.
    fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
        let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);

        TchTensor::new(tensor)
    }

    /// Spatial resize (upsample/downsample) of a 4-D tensor to
    /// `output_size`, dispatching on the interpolation mode.
    ///
    /// For bilinear/bicubic the `true` argument is `align_corners`; the
    /// trailing `None`s are the optional explicit height/width scale
    /// factors (derived from `output_size` instead).
    /// NOTE(review): `align_corners = true` is this backend's convention —
    /// confirm it matches the other burn backends' interpolate semantics.
    fn interpolate(
        x: TchTensor,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> TchTensor {
        let output_size = output_size.map(|e| e as i64);

        let tensor = match options.mode {
            InterpolateMode::Nearest => {
                tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
            }
            InterpolateMode::Bilinear => {
                tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, true, None, None)
            }
            InterpolateMode::Bicubic => {
                tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, true, None, None)
            }
        };

        TchTensor::new(tensor)
    }

    /// Gradient of `interpolate` w.r.t. its input.
    ///
    /// The backward kernels need the full NCHW `input_size`, recovered
    /// from `x`; `align_corners = true` matches the forward call.
    fn interpolate_backward(
        x: TchTensor,
        grad: TchTensor,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> TchTensor {
        let output_size = output_size.map(|e| e as i64);
        // Forward input is expected to be 4-D (batch, channels, height, width).
        let [n, c, h_in, w_in] = x.shape().dims();
        let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];

        let tensor = match options.mode {
            InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                None, // scales_h
                None, // scales_w
            ),
            InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                true, // align_corners (must match forward)
                None, // scales_h
                None, // scales_w
            ),
            InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
                &grad.tensor,
                output_size,
                input_size,
                true, // align_corners (must match forward)
                None, // scales_h
                None, // scales_w
            ),
        };

        TchTensor::new(tensor)
    }
}