// burn_jit/ops/module_ops.rs

1use crate::{
2    element::BoolElement,
3    kernel::{
4        self,
5        conv::{Conv2dStrategy, ConvTranspose2dStrategy},
6    },
7    FloatElement, IntElement, JitBackend, JitRuntime,
8};
9use burn_tensor::ops::{
10    ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
11    MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
12};
13use burn_tensor::ops::{FloatTensor, IntTensor};
14
15impl<R, F, I, BT> ModuleOps<Self> for JitBackend<R, F, I, BT>
16where
17    R: JitRuntime,
18    F: FloatElement,
19    I: IntElement,
20    BT: BoolElement,
21{
22    fn conv2d(
23        x: FloatTensor<Self>,
24        weight: FloatTensor<Self>,
25        bias: Option<FloatTensor<Self>>,
26        options: ConvOptions<2>,
27    ) -> FloatTensor<Self> {
28        kernel::conv::conv2d::<R, F>(x, weight, bias, options, Conv2dStrategy::default()).unwrap()
29    }
30
31    fn deform_conv2d(
32        x: FloatTensor<Self>,
33        offset: FloatTensor<Self>,
34        weight: FloatTensor<Self>,
35        mask: Option<FloatTensor<Self>>,
36        bias: Option<FloatTensor<Self>>,
37        options: DeformConvOptions<2>,
38    ) -> FloatTensor<Self> {
39        kernel::conv::deform_conv2d::<R, F>(x, offset, weight, mask, bias, options).unwrap()
40    }
41
42    fn deform_conv2d_backward(
43        x: FloatTensor<Self>,
44        offset: FloatTensor<Self>,
45        weight: FloatTensor<Self>,
46        mask: Option<FloatTensor<Self>>,
47        bias: Option<FloatTensor<Self>>,
48        output_grad: FloatTensor<Self>,
49        options: DeformConvOptions<2>,
50    ) -> DeformConv2dBackward<Self> {
51        kernel::conv::deform_conv2d_backward::<R, F, I, BT>(
52            x,
53            offset,
54            weight,
55            mask,
56            bias,
57            output_grad,
58            options,
59        )
60        .unwrap()
61    }
62
63    fn conv3d(
64        x: FloatTensor<Self>,
65        weight: FloatTensor<Self>,
66        bias: Option<FloatTensor<Self>>,
67        options: ConvOptions<3>,
68    ) -> FloatTensor<Self> {
69        kernel::conv::conv3d::<R, F>(x, weight, bias, options)
70    }
71
72    fn conv_transpose2d(
73        x: FloatTensor<Self>,
74        weight: FloatTensor<Self>,
75        bias: Option<FloatTensor<Self>>,
76        options: ConvTransposeOptions<2>,
77    ) -> FloatTensor<Self> {
78        kernel::conv::conv_transpose2d::<R, F, I>(
79            x,
80            weight,
81            bias,
82            options,
83            ConvTranspose2dStrategy::default(),
84        )
85        .unwrap()
86    }
87
88    fn conv_transpose3d(
89        x: FloatTensor<Self>,
90        weight: FloatTensor<Self>,
91        bias: Option<FloatTensor<Self>>,
92        options: ConvTransposeOptions<3>,
93    ) -> FloatTensor<Self> {
94        kernel::conv::conv_transpose3d::<R, F>(x, weight, bias, options)
95    }
96
97    fn avg_pool2d(
98        x: FloatTensor<Self>,
99        kernel_size: [usize; 2],
100        stride: [usize; 2],
101        padding: [usize; 2],
102        count_include_pad: bool,
103    ) -> FloatTensor<Self> {
104        kernel::pool::avg_pool2d::<R, F>(x, kernel_size, stride, padding, count_include_pad)
105    }
106
107    fn avg_pool2d_backward(
108        x: FloatTensor<Self>,
109        grad: FloatTensor<Self>,
110        kernel_size: [usize; 2],
111        stride: [usize; 2],
112        padding: [usize; 2],
113        count_include_pad: bool,
114    ) -> FloatTensor<Self> {
115        kernel::pool::avg_pool2d_backward::<R, F>(
116            x,
117            grad,
118            kernel_size,
119            stride,
120            padding,
121            count_include_pad,
122        )
123    }
124
125    fn max_pool2d(
126        x: FloatTensor<Self>,
127        kernel_size: [usize; 2],
128        stride: [usize; 2],
129        padding: [usize; 2],
130        dilation: [usize; 2],
131    ) -> FloatTensor<Self> {
132        kernel::pool::max_pool2d::<R, F>(x, kernel_size, stride, padding, dilation)
133    }
134
135    fn max_pool2d_with_indices(
136        x: FloatTensor<Self>,
137        kernel_size: [usize; 2],
138        stride: [usize; 2],
139        padding: [usize; 2],
140        dilation: [usize; 2],
141    ) -> MaxPool2dWithIndices<Self> {
142        let (output, indices) = kernel::pool::max_pool2d_with_indices::<R, F, I>(
143            x,
144            kernel_size,
145            stride,
146            padding,
147            dilation,
148        );
149
150        MaxPool2dWithIndices::new(output, indices)
151    }
152
153    fn max_pool2d_with_indices_backward(
154        x: FloatTensor<Self>,
155        kernel_size: [usize; 2],
156        stride: [usize; 2],
157        padding: [usize; 2],
158        dilation: [usize; 2],
159        output_grad: FloatTensor<Self>,
160        indices: IntTensor<Self>,
161    ) -> MaxPool2dBackward<Self> {
162        MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::<R, F, I>(
163            x,
164            output_grad,
165            indices,
166            kernel_size,
167            stride,
168            padding,
169            dilation,
170        ))
171    }
172
173    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
174        kernel::pool::adaptive_avg_pool2d::<R, F>(x, output_size)
175    }
176
177    fn adaptive_avg_pool2d_backward(
178        x: FloatTensor<Self>,
179        grad: FloatTensor<Self>,
180    ) -> FloatTensor<Self> {
181        kernel::pool::adaptive_avg_pool2d_backward::<R, F>(x, grad)
182    }
183
184    fn interpolate(
185        x: FloatTensor<Self>,
186        output_size: [usize; 2],
187        options: InterpolateOptions,
188    ) -> FloatTensor<Self> {
189        kernel::interpolate::interpolate::<R, F>(x, output_size, options)
190    }
191
192    fn interpolate_backward(
193        x: FloatTensor<Self>,
194        grad: FloatTensor<Self>,
195        output_size: [usize; 2],
196        options: InterpolateOptions,
197    ) -> FloatTensor<Self> {
198        kernel::interpolate::interpolate_backward::<R, F>(x, grad, output_size, options)
199    }
200}