1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 execute_with_dtype,
5 kernel::{
6 self,
7 conv::{ConvStrategy, ConvTranspose2dStrategy},
8 },
9};
10use burn_tensor::ops::{
11 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
12 MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
13};
14use burn_tensor::ops::{FloatTensor, IntTensor};
15
16impl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>
17where
18 R: CubeRuntime,
19 F: FloatElement,
20 I: IntElement,
21 BT: BoolElement,
22{
23 fn conv1d(
24 x: FloatTensor<Self>,
25 weight: FloatTensor<Self>,
26 bias: Option<FloatTensor<Self>>,
27 options: ConvOptions<1>,
28 ) -> FloatTensor<Self> {
29 execute_with_dtype!(
30 float(x.dtype),
31 E,
32 kernel::conv::conv::<R, E, 1>(x, weight, bias, options, ConvStrategy::default())
33 .unwrap()
34 )
35 }
36
37 fn conv2d(
38 x: FloatTensor<Self>,
39 weight: FloatTensor<Self>,
40 bias: Option<FloatTensor<Self>>,
41 options: ConvOptions<2>,
42 ) -> FloatTensor<Self> {
43 execute_with_dtype!(
44 float(x.dtype),
45 E,
46 kernel::conv::conv::<R, E, 2>(x, weight, bias, options, ConvStrategy::default())
47 .unwrap()
48 )
49 }
50
51 fn deform_conv2d(
52 x: FloatTensor<Self>,
53 offset: FloatTensor<Self>,
54 weight: FloatTensor<Self>,
55 mask: Option<FloatTensor<Self>>,
56 bias: Option<FloatTensor<Self>>,
57 options: DeformConvOptions<2>,
58 ) -> FloatTensor<Self> {
59 execute_with_dtype!(
60 float(x.dtype),
61 E,
62 kernel::conv::deform_conv2d::<R, E>(x, offset, weight, mask, bias, options).unwrap()
63 )
64 }
65
66 fn deform_conv2d_backward(
67 x: FloatTensor<Self>,
68 offset: FloatTensor<Self>,
69 weight: FloatTensor<Self>,
70 mask: Option<FloatTensor<Self>>,
71 bias: Option<FloatTensor<Self>>,
72 output_grad: FloatTensor<Self>,
73 options: DeformConvOptions<2>,
74 ) -> DeformConv2dBackward<Self> {
75 execute_with_dtype!(float(x.dtype), E, {
76 let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward::<R, E, I, BT>(
77 x,
78 offset,
79 weight,
80 mask,
81 bias,
82 output_grad,
83 options,
84 )
85 .unwrap();
86 DeformConv2dBackward::new(x, o, w, m, b)
87 })
88 }
89
90 fn conv3d(
91 x: FloatTensor<Self>,
92 weight: FloatTensor<Self>,
93 bias: Option<FloatTensor<Self>>,
94 options: ConvOptions<3>,
95 ) -> FloatTensor<Self> {
96 execute_with_dtype!(
97 float(x.dtype),
98 E,
99 kernel::conv::conv::<R, E, 3>(x, weight, bias, options, ConvStrategy::Direct).unwrap()
100 )
101 }
102
103 fn conv_transpose2d(
104 x: FloatTensor<Self>,
105 weight: FloatTensor<Self>,
106 bias: Option<FloatTensor<Self>>,
107 options: ConvTransposeOptions<2>,
108 ) -> FloatTensor<Self> {
109 execute_with_dtype!(
110 float(x.dtype),
111 E,
112 kernel::conv::conv_transpose2d::<R, E, I>(
113 x,
114 weight,
115 bias,
116 options,
117 ConvTranspose2dStrategy::default(),
118 )
119 .unwrap()
120 )
121 }
122
123 fn conv_transpose3d(
124 x: FloatTensor<Self>,
125 weight: FloatTensor<Self>,
126 bias: Option<FloatTensor<Self>>,
127 options: ConvTransposeOptions<3>,
128 ) -> FloatTensor<Self> {
129 execute_with_dtype!(
130 float(x.dtype),
131 E,
132 kernel::conv::conv_transpose3d::<R, E>(x, weight, bias, options)
133 )
134 }
135
136 fn avg_pool2d(
137 x: FloatTensor<Self>,
138 kernel_size: [usize; 2],
139 stride: [usize; 2],
140 padding: [usize; 2],
141 count_include_pad: bool,
142 ) -> FloatTensor<Self> {
143 execute_with_dtype!(
144 float(x.dtype),
145 E,
146 kernel::pool::avg_pool2d::<R, E>(x, kernel_size, stride, padding, count_include_pad)
147 )
148 }
149
150 fn avg_pool2d_backward(
151 x: FloatTensor<Self>,
152 grad: FloatTensor<Self>,
153 kernel_size: [usize; 2],
154 stride: [usize; 2],
155 padding: [usize; 2],
156 count_include_pad: bool,
157 ) -> FloatTensor<Self> {
158 execute_with_dtype!(
159 float(x.dtype),
160 E,
161 kernel::pool::avg_pool2d_backward::<R, E>(
162 x,
163 grad,
164 kernel_size,
165 stride,
166 padding,
167 count_include_pad,
168 )
169 )
170 }
171
172 fn max_pool2d(
173 x: FloatTensor<Self>,
174 kernel_size: [usize; 2],
175 stride: [usize; 2],
176 padding: [usize; 2],
177 dilation: [usize; 2],
178 ) -> FloatTensor<Self> {
179 execute_with_dtype!(
180 float(x.dtype),
181 E,
182 kernel::pool::max_pool2d::<R, E>(x, kernel_size, stride, padding, dilation)
183 )
184 }
185
186 fn max_pool2d_with_indices(
187 x: FloatTensor<Self>,
188 kernel_size: [usize; 2],
189 stride: [usize; 2],
190 padding: [usize; 2],
191 dilation: [usize; 2],
192 ) -> MaxPool2dWithIndices<Self> {
193 execute_with_dtype!(float(x.dtype), E, {
194 let (output, indices) = kernel::pool::max_pool2d_with_indices::<R, E, I>(
195 x,
196 kernel_size,
197 stride,
198 padding,
199 dilation,
200 );
201
202 MaxPool2dWithIndices::new(output, indices)
203 })
204 }
205
206 fn max_pool2d_with_indices_backward(
207 x: FloatTensor<Self>,
208 kernel_size: [usize; 2],
209 stride: [usize; 2],
210 padding: [usize; 2],
211 dilation: [usize; 2],
212 output_grad: FloatTensor<Self>,
213 indices: IntTensor<Self>,
214 ) -> MaxPool2dBackward<Self> {
215 execute_with_dtype!(
216 int(indices.dtype),
217 I,
218 execute_with_dtype!(
219 float(x.dtype),
220 E,
221 MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::<R, E, I>(
222 x,
223 output_grad,
224 indices,
225 kernel_size,
226 stride,
227 padding,
228 dilation,
229 ))
230 )
231 )
232 }
233
234 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
235 execute_with_dtype!(
236 float(x.dtype),
237 E,
238 kernel::pool::adaptive_avg_pool2d::<R, E>(x, output_size)
239 )
240 }
241
242 fn adaptive_avg_pool2d_backward(
243 x: FloatTensor<Self>,
244 grad: FloatTensor<Self>,
245 ) -> FloatTensor<Self> {
246 execute_with_dtype!(
247 float(x.dtype),
248 E,
249 kernel::pool::adaptive_avg_pool2d_backward::<R, E>(x, grad)
250 )
251 }
252
253 fn interpolate(
254 x: FloatTensor<Self>,
255 output_size: [usize; 2],
256 options: InterpolateOptions,
257 ) -> FloatTensor<Self> {
258 execute_with_dtype!(
259 float(x.dtype),
260 E,
261 kernel::interpolate::interpolate::<R, E>(x, output_size, options)
262 )
263 }
264
265 fn interpolate_backward(
266 x: FloatTensor<Self>,
267 grad: FloatTensor<Self>,
268 output_size: [usize; 2],
269 options: InterpolateOptions,
270 ) -> FloatTensor<Self> {
271 execute_with_dtype!(
272 float(x.dtype),
273 E,
274 kernel::interpolate::interpolate_backward::<R, E>(x, grad, output_size, options)
275 )
276 }
277}