1use crate::{LibTorch, TchTensor, element::TchElement};
2use burn_tensor::{
3 TensorMetadata,
4 ops::{
5 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
6 InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward,
7 MaxPool2dWithIndices, ModuleOps,
8 },
9};
10
11impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
12 fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {
13 let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
14
15 TchTensor::new(tensor)
16 }
17
18 fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {
19 let [n_embedding, _d_model] = weights.shape().dims();
20 let tensor = tch::Tensor::embedding_backward(
21 &output.tensor,
22 &indices.tensor,
23 n_embedding as i64,
24 -1,
25 false,
26 false,
27 );
28
29 TchTensor::new(tensor)
30 }
31
32 fn conv1d(
33 x: TchTensor,
34 weight: TchTensor,
35 bias: Option<TchTensor>,
36 options: ConvOptions<1>,
37 ) -> TchTensor {
38 let tensor = tch::Tensor::conv1d(
39 &x.tensor,
40 &weight.tensor,
41 bias.map(|t| t.tensor),
42 options.stride.map(|i| i as i64),
43 options.padding.map(|i| i as i64),
44 options.dilation.map(|i| i as i64),
45 options.groups as i64,
46 );
47
48 TchTensor::new(tensor)
49 }
50
51 fn conv2d(
52 x: TchTensor,
53 weight: TchTensor,
54 bias: Option<TchTensor>,
55 options: ConvOptions<2>,
56 ) -> TchTensor {
57 let tensor = tch::Tensor::conv2d(
58 &x.tensor,
59 &weight.tensor,
60 bias.map(|t| t.tensor),
61 options.stride.map(|i| i as i64),
62 options.padding.map(|i| i as i64),
63 options.dilation.map(|i| i as i64),
64 options.groups as i64,
65 );
66
67 TchTensor::new(tensor)
68 }
69
70 fn conv3d(
71 x: TchTensor,
72 weight: TchTensor,
73 bias: Option<TchTensor>,
74 options: ConvOptions<3>,
75 ) -> TchTensor {
76 let tensor = tch::Tensor::conv3d(
77 &x.tensor,
78 &weight.tensor,
79 bias.map(|t| t.tensor),
80 options.stride.map(|i| i as i64),
81 options.padding.map(|i| i as i64),
82 options.dilation.map(|i| i as i64),
83 options.groups as i64,
84 );
85
86 TchTensor::new(tensor)
87 }
88
89 fn deform_conv2d(
90 _x: TchTensor,
91 _offset: TchTensor,
92 _weight: TchTensor,
93 _mask: Option<TchTensor>,
94 _bias: Option<TchTensor>,
95 _options: DeformConvOptions<2>,
96 ) -> TchTensor {
97 unimplemented!("Torch bindings don't support deform_conv2d");
98 }
99
100 fn deform_conv2d_backward(
101 _x: TchTensor,
102 _offset: TchTensor,
103 _weight: TchTensor,
104 _mask: Option<TchTensor>,
105 _bias: Option<TchTensor>,
106 _out_grad: TchTensor,
107 _options: DeformConvOptions<2>,
108 ) -> DeformConv2dBackward<Self> {
109 unimplemented!("Torch bindings don't support deform_conv2d");
110 }
111
112 fn conv_transpose1d(
113 x: TchTensor,
114 weight: TchTensor,
115 bias: Option<TchTensor>,
116 options: ConvTransposeOptions<1>,
117 ) -> TchTensor {
118 let tensor = tch::Tensor::conv_transpose1d(
119 &x.tensor,
120 &weight.tensor,
121 bias.map(|t| t.tensor),
122 options.stride.map(|i| i as i64),
123 options.padding.map(|i| i as i64),
124 options.padding_out.map(|i| i as i64),
125 options.groups as i64,
126 options.dilation.map(|i| i as i64),
127 );
128
129 TchTensor::new(tensor)
130 }
131
132 fn conv_transpose2d(
133 x: TchTensor,
134 weight: TchTensor,
135 bias: Option<TchTensor>,
136 options: ConvTransposeOptions<2>,
137 ) -> TchTensor {
138 let tensor = tch::Tensor::conv_transpose2d(
139 &x.tensor,
140 &weight.tensor,
141 bias.map(|t| t.tensor),
142 options.stride.map(|i| i as i64),
143 options.padding.map(|i| i as i64),
144 options.padding_out.map(|i| i as i64),
145 options.groups as i64,
146 options.dilation.map(|i| i as i64),
147 );
148
149 TchTensor::new(tensor)
150 }
151
152 fn conv_transpose3d(
153 x: TchTensor,
154 weight: TchTensor,
155 bias: Option<TchTensor>,
156 options: ConvTransposeOptions<3>,
157 ) -> TchTensor {
158 let tensor = tch::Tensor::conv_transpose3d(
159 &x.tensor,
160 &weight.tensor,
161 bias.map(|t| t.tensor),
162 options.stride.map(|i| i as i64),
163 options.padding.map(|i| i as i64),
164 options.padding_out.map(|i| i as i64),
165 options.groups as i64,
166 options.dilation.map(|i| i as i64),
167 );
168
169 TchTensor::new(tensor)
170 }
171
172 fn avg_pool1d(
173 x: TchTensor,
174 kernel_size: usize,
175 stride: usize,
176 padding: usize,
177 count_include_pad: bool,
178 ) -> TchTensor {
179 let tensor = tch::Tensor::avg_pool1d(
180 &x.tensor,
181 [kernel_size as i64],
182 [stride as i64],
183 [padding as i64],
184 false,
185 count_include_pad,
186 );
187
188 TchTensor::new(tensor)
189 }
190 fn avg_pool2d(
191 x: TchTensor,
192 kernel_size: [usize; 2],
193 stride: [usize; 2],
194 padding: [usize; 2],
195 count_include_pad: bool,
196 ) -> TchTensor {
197 let tensor = tch::Tensor::avg_pool2d(
198 &x.tensor,
199 [kernel_size[0] as i64, kernel_size[1] as i64],
200 [stride[0] as i64, stride[1] as i64],
201 [padding[0] as i64, padding[1] as i64],
202 false,
203 count_include_pad,
204 None,
205 );
206
207 TchTensor::new(tensor)
208 }
209
210 fn avg_pool2d_backward(
211 x: TchTensor,
212 grad: TchTensor,
213 kernel_size: [usize; 2],
214 stride: [usize; 2],
215 padding: [usize; 2],
216 count_include_pad: bool,
217 ) -> TchTensor {
218 let tensor = tch::Tensor::avg_pool2d_backward(
219 &x.tensor,
220 &grad.tensor,
221 [kernel_size[0] as i64, kernel_size[1] as i64],
222 [stride[0] as i64, stride[1] as i64],
223 [padding[0] as i64, padding[1] as i64],
224 false,
225 count_include_pad,
226 None,
227 );
228
229 TchTensor::new(tensor)
230 }
231
232 fn max_pool1d(
233 x: TchTensor,
234 kernel_size: usize,
235 stride: usize,
236 padding: usize,
237 dilation: usize,
238 ) -> TchTensor {
239 let tensor = tch::Tensor::max_pool1d(
240 &x.tensor,
241 kernel_size as i64,
242 stride as i64,
243 padding as i64,
244 dilation as i64,
245 false,
246 );
247
248 TchTensor::new(tensor)
249 }
250
251 fn max_pool1d_with_indices(
252 x: TchTensor,
253 kernel_size: usize,
254 stride: usize,
255 padding: usize,
256 dilation: usize,
257 ) -> MaxPool1dWithIndices<LibTorch<E>> {
258 let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(
259 &x.tensor,
260 kernel_size as i64,
261 stride as i64,
262 padding as i64,
263 dilation as i64,
264 false,
265 );
266
267 MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
268 }
269
270 fn max_pool2d(
271 x: TchTensor,
272 kernel_size: [usize; 2],
273 stride: [usize; 2],
274 padding: [usize; 2],
275 dilation: [usize; 2],
276 ) -> TchTensor {
277 let tensor = tch::Tensor::max_pool2d(
278 &x.tensor,
279 [kernel_size[0] as i64, kernel_size[1] as i64],
280 [stride[0] as i64, stride[1] as i64],
281 [padding[0] as i64, padding[1] as i64],
282 [dilation[0] as i64, dilation[1] as i64],
283 false,
284 );
285
286 TchTensor::new(tensor)
287 }
288
289 fn max_pool2d_with_indices(
290 x: TchTensor,
291 kernel_size: [usize; 2],
292 stride: [usize; 2],
293 padding: [usize; 2],
294 dilation: [usize; 2],
295 ) -> MaxPool2dWithIndices<LibTorch<E>> {
296 let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
297 &x.tensor,
298 [kernel_size[0] as i64, kernel_size[1] as i64],
299 [stride[0] as i64, stride[1] as i64],
300 [padding[0] as i64, padding[1] as i64],
301 [dilation[0] as i64, dilation[1] as i64],
302 false,
303 );
304
305 MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
306 }
307
308 fn max_pool2d_with_indices_backward(
309 x: TchTensor,
310 kernel_size: [usize; 2],
311 stride: [usize; 2],
312 padding: [usize; 2],
313 dilation: [usize; 2],
314 output_grad: TchTensor,
315 indices: TchTensor,
316 ) -> MaxPool2dBackward<LibTorch<E>> {
317 let grad = tch::Tensor::max_pool2d_with_indices_backward(
318 &x.tensor,
319 &output_grad.tensor,
320 [kernel_size[0] as i64, kernel_size[1] as i64],
321 [stride[0] as i64, stride[1] as i64],
322 [padding[0] as i64, padding[1] as i64],
323 [dilation[0] as i64, dilation[1] as i64],
324 false,
325 &indices.tensor,
326 );
327
328 MaxPool2dBackward::new(TchTensor::new(grad))
329 }
330
331 fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {
332 let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));
333
334 TchTensor::new(tensor)
335 }
336
337 fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {
338 let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);
339
340 TchTensor::new(tensor)
341 }
342
343 fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {
344 let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);
345
346 TchTensor::new(tensor)
347 }
348
349 fn interpolate(
350 x: TchTensor,
351 output_size: [usize; 2],
352 options: InterpolateOptions,
353 ) -> TchTensor {
354 let output_size = output_size.map(|e| e as i64);
355
356 let tensor = match options.mode {
357 InterpolateMode::Nearest => {
358 tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
359 }
360 InterpolateMode::Bilinear => {
361 tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, true, None, None)
362 }
363 InterpolateMode::Bicubic => {
364 tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, true, None, None)
365 }
366 };
367
368 TchTensor::new(tensor)
369 }
370
371 fn interpolate_backward(
372 x: TchTensor,
373 grad: TchTensor,
374 output_size: [usize; 2],
375 options: InterpolateOptions,
376 ) -> TchTensor {
377 let output_size = output_size.map(|e| e as i64);
378 let [n, c, h_in, w_in] = x.shape().dims();
379 let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];
380
381 let tensor = match options.mode {
382 InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(
383 &grad.tensor,
384 output_size,
385 input_size,
386 None,
387 None,
388 ),
389 InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
390 &grad.tensor,
391 output_size,
392 input_size,
393 true,
394 None,
395 None,
396 ),
397 InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(
398 &grad.tensor,
399 output_size,
400 input_size,
401 true,
402 None,
403 None,
404 ),
405 };
406
407 TchTensor::new(tensor)
408 }
409}