1use crate::{
2 element::BoolElement,
3 kernel::{
4 self,
5 conv::{Conv2dStrategy, ConvTranspose2dStrategy},
6 },
7 FloatElement, IntElement, JitBackend, JitRuntime,
8};
9use burn_tensor::ops::{
10 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
11 MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
12};
13use burn_tensor::ops::{FloatTensor, IntTensor};
14
15impl<R, F, I, BT> ModuleOps<Self> for JitBackend<R, F, I, BT>
16where
17 R: JitRuntime,
18 F: FloatElement,
19 I: IntElement,
20 BT: BoolElement,
21{
22 fn conv2d(
23 x: FloatTensor<Self>,
24 weight: FloatTensor<Self>,
25 bias: Option<FloatTensor<Self>>,
26 options: ConvOptions<2>,
27 ) -> FloatTensor<Self> {
28 kernel::conv::conv2d::<R, F>(x, weight, bias, options, Conv2dStrategy::default()).unwrap()
29 }
30
31 fn deform_conv2d(
32 x: FloatTensor<Self>,
33 offset: FloatTensor<Self>,
34 weight: FloatTensor<Self>,
35 mask: Option<FloatTensor<Self>>,
36 bias: Option<FloatTensor<Self>>,
37 options: DeformConvOptions<2>,
38 ) -> FloatTensor<Self> {
39 kernel::conv::deform_conv2d::<R, F>(x, offset, weight, mask, bias, options).unwrap()
40 }
41
42 fn deform_conv2d_backward(
43 x: FloatTensor<Self>,
44 offset: FloatTensor<Self>,
45 weight: FloatTensor<Self>,
46 mask: Option<FloatTensor<Self>>,
47 bias: Option<FloatTensor<Self>>,
48 output_grad: FloatTensor<Self>,
49 options: DeformConvOptions<2>,
50 ) -> DeformConv2dBackward<Self> {
51 kernel::conv::deform_conv2d_backward::<R, F, I, BT>(
52 x,
53 offset,
54 weight,
55 mask,
56 bias,
57 output_grad,
58 options,
59 )
60 .unwrap()
61 }
62
63 fn conv3d(
64 x: FloatTensor<Self>,
65 weight: FloatTensor<Self>,
66 bias: Option<FloatTensor<Self>>,
67 options: ConvOptions<3>,
68 ) -> FloatTensor<Self> {
69 kernel::conv::conv3d::<R, F>(x, weight, bias, options)
70 }
71
72 fn conv_transpose2d(
73 x: FloatTensor<Self>,
74 weight: FloatTensor<Self>,
75 bias: Option<FloatTensor<Self>>,
76 options: ConvTransposeOptions<2>,
77 ) -> FloatTensor<Self> {
78 kernel::conv::conv_transpose2d::<R, F, I>(
79 x,
80 weight,
81 bias,
82 options,
83 ConvTranspose2dStrategy::default(),
84 )
85 .unwrap()
86 }
87
88 fn conv_transpose3d(
89 x: FloatTensor<Self>,
90 weight: FloatTensor<Self>,
91 bias: Option<FloatTensor<Self>>,
92 options: ConvTransposeOptions<3>,
93 ) -> FloatTensor<Self> {
94 kernel::conv::conv_transpose3d::<R, F>(x, weight, bias, options)
95 }
96
97 fn avg_pool2d(
98 x: FloatTensor<Self>,
99 kernel_size: [usize; 2],
100 stride: [usize; 2],
101 padding: [usize; 2],
102 count_include_pad: bool,
103 ) -> FloatTensor<Self> {
104 kernel::pool::avg_pool2d::<R, F>(x, kernel_size, stride, padding, count_include_pad)
105 }
106
107 fn avg_pool2d_backward(
108 x: FloatTensor<Self>,
109 grad: FloatTensor<Self>,
110 kernel_size: [usize; 2],
111 stride: [usize; 2],
112 padding: [usize; 2],
113 count_include_pad: bool,
114 ) -> FloatTensor<Self> {
115 kernel::pool::avg_pool2d_backward::<R, F>(
116 x,
117 grad,
118 kernel_size,
119 stride,
120 padding,
121 count_include_pad,
122 )
123 }
124
125 fn max_pool2d(
126 x: FloatTensor<Self>,
127 kernel_size: [usize; 2],
128 stride: [usize; 2],
129 padding: [usize; 2],
130 dilation: [usize; 2],
131 ) -> FloatTensor<Self> {
132 kernel::pool::max_pool2d::<R, F>(x, kernel_size, stride, padding, dilation)
133 }
134
135 fn max_pool2d_with_indices(
136 x: FloatTensor<Self>,
137 kernel_size: [usize; 2],
138 stride: [usize; 2],
139 padding: [usize; 2],
140 dilation: [usize; 2],
141 ) -> MaxPool2dWithIndices<Self> {
142 let (output, indices) = kernel::pool::max_pool2d_with_indices::<R, F, I>(
143 x,
144 kernel_size,
145 stride,
146 padding,
147 dilation,
148 );
149
150 MaxPool2dWithIndices::new(output, indices)
151 }
152
153 fn max_pool2d_with_indices_backward(
154 x: FloatTensor<Self>,
155 kernel_size: [usize; 2],
156 stride: [usize; 2],
157 padding: [usize; 2],
158 dilation: [usize; 2],
159 output_grad: FloatTensor<Self>,
160 indices: IntTensor<Self>,
161 ) -> MaxPool2dBackward<Self> {
162 MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::<R, F, I>(
163 x,
164 output_grad,
165 indices,
166 kernel_size,
167 stride,
168 padding,
169 dilation,
170 ))
171 }
172
173 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
174 kernel::pool::adaptive_avg_pool2d::<R, F>(x, output_size)
175 }
176
177 fn adaptive_avg_pool2d_backward(
178 x: FloatTensor<Self>,
179 grad: FloatTensor<Self>,
180 ) -> FloatTensor<Self> {
181 kernel::pool::adaptive_avg_pool2d_backward::<R, F>(x, grad)
182 }
183
184 fn interpolate(
185 x: FloatTensor<Self>,
186 output_size: [usize; 2],
187 options: InterpolateOptions,
188 ) -> FloatTensor<Self> {
189 kernel::interpolate::interpolate::<R, F>(x, output_size, options)
190 }
191
192 fn interpolate_backward(
193 x: FloatTensor<Self>,
194 grad: FloatTensor<Self>,
195 output_size: [usize; 2],
196 options: InterpolateOptions,
197 ) -> FloatTensor<Self> {
198 kernel::interpolate::interpolate_backward::<R, F>(x, grad, output_size, options)
199 }
200}