1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 kernel::{self, conv::ConvTranspose2dStrategy},
5};
6use burn_backend::ops::{
7 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
8 MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
9};
10use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
11
12impl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>
13where
14 R: CubeRuntime,
15 F: FloatElement,
16 I: IntElement,
17 BT: BoolElement,
18{
19 fn conv1d(
20 x: FloatTensor<Self>,
21 weight: FloatTensor<Self>,
22 bias: Option<FloatTensor<Self>>,
23 options: ConvOptions<1>,
24 ) -> FloatTensor<Self> {
25 kernel::conv::conv_forward::<R, 1>(x, weight, bias, options, Default::default()).unwrap()
26 }
27
28 fn conv1d_x_backward(
29 x: FloatTensor<Self>,
30 weight: FloatTensor<Self>,
31 output_grad: FloatTensor<Self>,
32 options: ConvOptions<1>,
33 ) -> FloatTensor<Self> {
34 kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default())
35 .unwrap()
36 }
37
38 fn conv1d_weight_backward(
39 x: FloatTensor<Self>,
40 weight: FloatTensor<Self>,
41 output_grad: FloatTensor<Self>,
42 options: ConvOptions<1>,
43 ) -> FloatTensor<Self> {
44 kernel::conv::conv_weight_backward::<R, 1>(
45 x,
46 output_grad,
47 weight.shape.clone(),
48 options,
49 Default::default(),
50 )
51 .unwrap()
52 }
53
54 fn conv2d(
55 x: FloatTensor<Self>,
56 weight: FloatTensor<Self>,
57 bias: Option<FloatTensor<Self>>,
58 options: ConvOptions<2>,
59 ) -> FloatTensor<Self> {
60 kernel::conv::conv_forward::<R, 2>(x, weight, bias, options, Default::default()).unwrap()
61 }
62
63 fn conv2d_x_backward(
64 x: FloatTensor<Self>,
65 weight: FloatTensor<Self>,
66 output_grad: FloatTensor<Self>,
67 options: ConvOptions<2>,
68 ) -> FloatTensor<Self> {
69 kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default())
70 .unwrap()
71 }
72
73 fn conv2d_weight_backward(
74 x: FloatTensor<Self>,
75 weight: FloatTensor<Self>,
76 output_grad: FloatTensor<Self>,
77 options: ConvOptions<2>,
78 ) -> FloatTensor<Self> {
79 kernel::conv::conv_weight_backward::<R, 2>(
80 x,
81 output_grad,
82 weight.shape.clone(),
83 options,
84 Default::default(),
85 )
86 .unwrap()
87 }
88
89 fn deform_conv2d(
90 x: FloatTensor<Self>,
91 offset: FloatTensor<Self>,
92 weight: FloatTensor<Self>,
93 mask: Option<FloatTensor<Self>>,
94 bias: Option<FloatTensor<Self>>,
95 options: DeformConvOptions<2>,
96 ) -> FloatTensor<Self> {
97 kernel::conv::deform_conv2d(x, offset, weight, mask, bias, options).unwrap()
98 }
99
100 fn deform_conv2d_backward(
101 x: FloatTensor<Self>,
102 offset: FloatTensor<Self>,
103 weight: FloatTensor<Self>,
104 mask: Option<FloatTensor<Self>>,
105 bias: Option<FloatTensor<Self>>,
106 output_grad: FloatTensor<Self>,
107 options: DeformConvOptions<2>,
108 ) -> DeformConv2dBackward<Self> {
109 let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward(
110 x,
111 offset,
112 weight,
113 mask,
114 bias,
115 output_grad,
116 options,
117 )
118 .unwrap();
119 DeformConv2dBackward::new(x, o, w, m, b)
120 }
121
122 fn conv3d(
123 x: FloatTensor<Self>,
124 weight: FloatTensor<Self>,
125 bias: Option<FloatTensor<Self>>,
126 options: ConvOptions<3>,
127 ) -> FloatTensor<Self> {
128 kernel::conv::conv_forward::<R, 3>(x, weight, bias, options, Default::default()).unwrap()
129 }
130
131 fn conv3d_x_backward(
132 x: FloatTensor<Self>,
133 weight: FloatTensor<Self>,
134 output_grad: FloatTensor<Self>,
135 options: ConvOptions<3>,
136 ) -> FloatTensor<Self> {
137 kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default())
138 .unwrap()
139 }
140
141 fn conv3d_weight_backward(
142 x: FloatTensor<Self>,
143 weight: FloatTensor<Self>,
144 output_grad: FloatTensor<Self>,
145 options: ConvOptions<3>,
146 ) -> FloatTensor<Self> {
147 kernel::conv::conv_weight_backward::<R, 3>(
148 x,
149 output_grad,
150 weight.shape.clone(),
151 options,
152 Default::default(),
153 )
154 .unwrap()
155 }
156
157 fn conv_transpose2d(
158 x: FloatTensor<Self>,
159 weight: FloatTensor<Self>,
160 bias: Option<FloatTensor<Self>>,
161 options: ConvTransposeOptions<2>,
162 ) -> FloatTensor<Self> {
163 kernel::conv::conv_transpose2d(x, weight, bias, options, ConvTranspose2dStrategy::default())
164 .unwrap()
165 }
166
167 fn conv_transpose3d(
168 x: FloatTensor<Self>,
169 weight: FloatTensor<Self>,
170 bias: Option<FloatTensor<Self>>,
171 options: ConvTransposeOptions<3>,
172 ) -> FloatTensor<Self> {
173 kernel::conv::conv_transpose3d(x, weight, bias, options).expect("Kernel to never fail")
174 }
175
176 fn avg_pool2d(
177 x: FloatTensor<Self>,
178 kernel_size: [usize; 2],
179 stride: [usize; 2],
180 padding: [usize; 2],
181 count_include_pad: bool,
182 ceil_mode: bool,
183 ) -> FloatTensor<Self> {
184 kernel::pool::avg_pool2d(
185 x,
186 kernel_size,
187 stride,
188 padding,
189 count_include_pad,
190 ceil_mode,
191 )
192 }
193
194 fn avg_pool2d_backward(
195 x: FloatTensor<Self>,
196 grad: FloatTensor<Self>,
197 kernel_size: [usize; 2],
198 stride: [usize; 2],
199 padding: [usize; 2],
200 count_include_pad: bool,
201 ceil_mode: bool,
202 ) -> FloatTensor<Self> {
203 kernel::pool::avg_pool2d_backward(
204 x,
205 grad,
206 kernel_size,
207 stride,
208 padding,
209 count_include_pad,
210 ceil_mode,
211 )
212 }
213
214 fn max_pool2d(
215 x: FloatTensor<Self>,
216 kernel_size: [usize; 2],
217 stride: [usize; 2],
218 padding: [usize; 2],
219 dilation: [usize; 2],
220 ceil_mode: bool,
221 ) -> FloatTensor<Self> {
222 kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
223 }
224
225 fn max_pool2d_with_indices(
226 x: FloatTensor<Self>,
227 kernel_size: [usize; 2],
228 stride: [usize; 2],
229 padding: [usize; 2],
230 dilation: [usize; 2],
231 ceil_mode: bool,
232 ) -> MaxPool2dWithIndices<Self> {
233 let (output, indices) = kernel::pool::max_pool2d_with_indices(
234 x,
235 kernel_size,
236 stride,
237 padding,
238 dilation,
239 ceil_mode,
240 I::dtype(),
241 );
242
243 MaxPool2dWithIndices::new(output, indices)
244 }
245
246 fn max_pool2d_with_indices_backward(
247 x: FloatTensor<Self>,
248 kernel_size: [usize; 2],
249 stride: [usize; 2],
250 padding: [usize; 2],
251 dilation: [usize; 2],
252 ceil_mode: bool,
253 output_grad: FloatTensor<Self>,
254 indices: IntTensor<Self>,
255 ) -> MaxPool2dBackward<Self> {
256 MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward(
257 x,
258 output_grad,
259 indices,
260 kernel_size,
261 stride,
262 padding,
263 dilation,
264 ceil_mode,
265 ))
266 }
267
268 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
269 kernel::pool::adaptive_avg_pool2d(x, output_size)
270 }
271
272 fn adaptive_avg_pool2d_backward(
273 x: FloatTensor<Self>,
274 grad: FloatTensor<Self>,
275 ) -> FloatTensor<Self> {
276 kernel::pool::adaptive_avg_pool2d_backward(x, grad)
277 }
278
279 fn interpolate(
280 x: FloatTensor<Self>,
281 output_size: [usize; 2],
282 options: InterpolateOptions,
283 ) -> FloatTensor<Self> {
284 kernel::interpolate::interpolate(x, output_size, options)
285 }
286
287 fn interpolate_backward(
288 x: FloatTensor<Self>,
289 grad: FloatTensor<Self>,
290 output_size: [usize; 2],
291 options: InterpolateOptions,
292 ) -> FloatTensor<Self> {
293 kernel::interpolate::interpolate_backward(x, grad, output_size, options)
294 }
295
296 fn attention(
297 query: FloatTensor<Self>,
298 key: FloatTensor<Self>,
299 value: FloatTensor<Self>,
300 mask: Option<BoolTensor<Self>>,
301 ) -> FloatTensor<Self> {
302 let out_dtype = query.dtype;
303 kernel::attention::flash_attention(query, key, value, mask, out_dtype)
304 .expect("Kernel to never fail")
305 }
306}