1use crate::{
2 CubeBackend, CubeRuntime, FloatElement, IntElement,
3 element::BoolElement,
4 kernel::{self, conv::ConvTranspose2dStrategy},
5};
6use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
7use burn_backend::{
8 TensorMetadata,
9 ops::{
10 AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
11 DeformConvOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
12 },
13};
14
15impl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>
16where
17 R: CubeRuntime,
18 F: FloatElement,
19 I: IntElement,
20 BT: BoolElement,
21{
22 fn conv1d(
23 x: FloatTensor<Self>,
24 weight: FloatTensor<Self>,
25 bias: Option<FloatTensor<Self>>,
26 options: ConvOptions<1>,
27 ) -> FloatTensor<Self> {
28 kernel::conv::conv_forward::<R, 1>(x, weight, bias, options, Default::default()).unwrap()
29 }
30
31 fn conv1d_x_backward(
32 x: FloatTensor<Self>,
33 weight: FloatTensor<Self>,
34 output_grad: FloatTensor<Self>,
35 options: ConvOptions<1>,
36 ) -> FloatTensor<Self> {
37 kernel::conv::conv_data_backward(
38 output_grad,
39 weight,
40 x.shape(),
41 options,
42 Default::default(),
43 )
44 .unwrap()
45 }
46
47 fn conv1d_weight_backward(
48 x: FloatTensor<Self>,
49 weight: FloatTensor<Self>,
50 output_grad: FloatTensor<Self>,
51 options: ConvOptions<1>,
52 ) -> FloatTensor<Self> {
53 kernel::conv::conv_weight_backward::<R, 1>(
54 x,
55 output_grad,
56 weight.shape(),
57 options,
58 Default::default(),
59 )
60 .unwrap()
61 }
62
63 fn conv2d(
64 x: FloatTensor<Self>,
65 weight: FloatTensor<Self>,
66 bias: Option<FloatTensor<Self>>,
67 options: ConvOptions<2>,
68 ) -> FloatTensor<Self> {
69 kernel::conv::conv_forward::<R, 2>(x, weight, bias, options, Default::default()).unwrap()
70 }
71
72 fn conv2d_x_backward(
73 x: FloatTensor<Self>,
74 weight: FloatTensor<Self>,
75 output_grad: FloatTensor<Self>,
76 options: ConvOptions<2>,
77 ) -> FloatTensor<Self> {
78 kernel::conv::conv_data_backward(
79 output_grad,
80 weight,
81 x.shape(),
82 options,
83 Default::default(),
84 )
85 .unwrap()
86 }
87
88 fn conv2d_weight_backward(
89 x: FloatTensor<Self>,
90 weight: FloatTensor<Self>,
91 output_grad: FloatTensor<Self>,
92 options: ConvOptions<2>,
93 ) -> FloatTensor<Self> {
94 kernel::conv::conv_weight_backward::<R, 2>(
95 x,
96 output_grad,
97 weight.shape(),
98 options,
99 Default::default(),
100 )
101 .unwrap()
102 }
103
104 fn deform_conv2d(
105 x: FloatTensor<Self>,
106 offset: FloatTensor<Self>,
107 weight: FloatTensor<Self>,
108 mask: Option<FloatTensor<Self>>,
109 bias: Option<FloatTensor<Self>>,
110 options: DeformConvOptions<2>,
111 ) -> FloatTensor<Self> {
112 kernel::conv::deform_conv2d(x, offset, weight, mask, bias, options).unwrap()
113 }
114
115 fn deform_conv2d_backward(
116 x: FloatTensor<Self>,
117 offset: FloatTensor<Self>,
118 weight: FloatTensor<Self>,
119 mask: Option<FloatTensor<Self>>,
120 bias: Option<FloatTensor<Self>>,
121 output_grad: FloatTensor<Self>,
122 options: DeformConvOptions<2>,
123 ) -> DeformConv2dBackward<Self> {
124 let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward(
125 x,
126 offset,
127 weight,
128 mask,
129 bias,
130 output_grad,
131 options,
132 )
133 .unwrap();
134 DeformConv2dBackward::new(x, o, w, m, b)
135 }
136
137 fn conv3d(
138 x: FloatTensor<Self>,
139 weight: FloatTensor<Self>,
140 bias: Option<FloatTensor<Self>>,
141 options: ConvOptions<3>,
142 ) -> FloatTensor<Self> {
143 kernel::conv::conv_forward::<R, 3>(x, weight, bias, options, Default::default()).unwrap()
144 }
145
146 fn conv3d_x_backward(
147 x: FloatTensor<Self>,
148 weight: FloatTensor<Self>,
149 output_grad: FloatTensor<Self>,
150 options: ConvOptions<3>,
151 ) -> FloatTensor<Self> {
152 kernel::conv::conv_data_backward(
153 output_grad,
154 weight,
155 x.shape(),
156 options,
157 Default::default(),
158 )
159 .unwrap()
160 }
161
162 fn conv3d_weight_backward(
163 x: FloatTensor<Self>,
164 weight: FloatTensor<Self>,
165 output_grad: FloatTensor<Self>,
166 options: ConvOptions<3>,
167 ) -> FloatTensor<Self> {
168 kernel::conv::conv_weight_backward::<R, 3>(
169 x,
170 output_grad,
171 weight.shape(),
172 options,
173 Default::default(),
174 )
175 .unwrap()
176 }
177
178 fn conv_transpose2d(
179 x: FloatTensor<Self>,
180 weight: FloatTensor<Self>,
181 bias: Option<FloatTensor<Self>>,
182 options: ConvTransposeOptions<2>,
183 ) -> FloatTensor<Self> {
184 kernel::conv::conv_transpose2d(x, weight, bias, options, ConvTranspose2dStrategy::default())
185 .unwrap()
186 }
187
188 fn conv_transpose3d(
189 x: FloatTensor<Self>,
190 weight: FloatTensor<Self>,
191 bias: Option<FloatTensor<Self>>,
192 options: ConvTransposeOptions<3>,
193 ) -> FloatTensor<Self> {
194 kernel::conv::conv_transpose3d(x, weight, bias, options).expect("Kernel to never fail")
195 }
196
197 fn avg_pool2d(
198 x: FloatTensor<Self>,
199 kernel_size: [usize; 2],
200 stride: [usize; 2],
201 padding: [usize; 2],
202 count_include_pad: bool,
203 ceil_mode: bool,
204 ) -> FloatTensor<Self> {
205 kernel::pool::avg_pool2d(
206 x,
207 kernel_size,
208 stride,
209 padding,
210 count_include_pad,
211 ceil_mode,
212 )
213 }
214
215 fn avg_pool2d_backward(
216 x: FloatTensor<Self>,
217 grad: FloatTensor<Self>,
218 kernel_size: [usize; 2],
219 stride: [usize; 2],
220 padding: [usize; 2],
221 count_include_pad: bool,
222 ceil_mode: bool,
223 ) -> FloatTensor<Self> {
224 kernel::pool::avg_pool2d_backward(
225 x,
226 grad,
227 kernel_size,
228 stride,
229 padding,
230 count_include_pad,
231 ceil_mode,
232 )
233 }
234
235 fn max_pool2d(
236 x: FloatTensor<Self>,
237 kernel_size: [usize; 2],
238 stride: [usize; 2],
239 padding: [usize; 2],
240 dilation: [usize; 2],
241 ceil_mode: bool,
242 ) -> FloatTensor<Self> {
243 kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
244 }
245
246 fn max_pool2d_with_indices(
247 x: FloatTensor<Self>,
248 kernel_size: [usize; 2],
249 stride: [usize; 2],
250 padding: [usize; 2],
251 dilation: [usize; 2],
252 ceil_mode: bool,
253 ) -> MaxPool2dWithIndices<Self> {
254 let (output, indices) = kernel::pool::max_pool2d_with_indices(
255 x,
256 kernel_size,
257 stride,
258 padding,
259 dilation,
260 ceil_mode,
261 I::dtype(),
262 );
263
264 MaxPool2dWithIndices::new(output, indices)
265 }
266
267 fn max_pool2d_with_indices_backward(
268 x: FloatTensor<Self>,
269 kernel_size: [usize; 2],
270 stride: [usize; 2],
271 padding: [usize; 2],
272 dilation: [usize; 2],
273 ceil_mode: bool,
274 output_grad: FloatTensor<Self>,
275 indices: IntTensor<Self>,
276 ) -> MaxPool2dBackward<Self> {
277 MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward(
278 x,
279 output_grad,
280 indices,
281 kernel_size,
282 stride,
283 padding,
284 dilation,
285 ceil_mode,
286 ))
287 }
288
289 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
290 kernel::pool::adaptive_avg_pool2d(x, output_size)
291 }
292
293 fn adaptive_avg_pool2d_backward(
294 x: FloatTensor<Self>,
295 grad: FloatTensor<Self>,
296 ) -> FloatTensor<Self> {
297 kernel::pool::adaptive_avg_pool2d_backward(x, grad)
298 }
299
300 fn interpolate(
301 x: FloatTensor<Self>,
302 output_size: [usize; 2],
303 options: InterpolateOptions,
304 ) -> FloatTensor<Self> {
305 kernel::interpolate::interpolate(x, output_size, options)
306 }
307
308 fn interpolate_backward(
309 x: FloatTensor<Self>,
310 grad: FloatTensor<Self>,
311 output_size: [usize; 2],
312 options: InterpolateOptions,
313 ) -> FloatTensor<Self> {
314 kernel::interpolate::interpolate_backward(x, grad, output_size, options)
315 }
316
317 fn attention(
318 query: FloatTensor<Self>,
319 key: FloatTensor<Self>,
320 value: FloatTensor<Self>,
321 mask: Option<BoolTensor<Self>>,
322 attn_bias: Option<FloatTensor<Self>>,
323 options: AttentionModuleOptions,
324 ) -> FloatTensor<Self> {
325 if attn_bias.is_some() || options.softcap.is_some() || options.scale.is_some() {
327 return burn_backend::ops::attention::attention_fallback::<Self>(
328 query, key, value, mask, attn_bias, options,
329 );
330 }
331
332 kernel::attention::attention(
333 query,
334 key,
335 value,
336 mask,
337 attn_bias,
338 options,
339 Default::default(),
340 )
341 .expect("Kernel to never fail")
342 }
343
344 fn has_ctc_loss_backward() -> bool {
345 true
346 }
347
348 fn ctc_loss(
349 log_probs: FloatTensor<Self>,
350 targets: IntTensor<Self>,
351 input_lengths: IntTensor<Self>,
352 target_lengths: IntTensor<Self>,
353 blank: usize,
354 ) -> FloatTensor<Self> {
355 kernel::ctc::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank)
356 }
357
358 fn ctc_loss_backward(
359 log_probs: FloatTensor<Self>,
360 targets: IntTensor<Self>,
361 input_lengths: IntTensor<Self>,
362 target_lengths: IntTensor<Self>,
363 grad_loss: FloatTensor<Self>,
364 blank: usize,
365 ) -> FloatTensor<Self> {
366 let (log_alpha_full, log_beta_full, nll) = kernel::ctc::ctc_alpha_beta(
367 log_probs.clone(),
368 targets.clone(),
369 input_lengths.clone(),
370 target_lengths,
371 blank,
372 );
373 burn_backend::ops::ctc::ctc_grad_from_alpha_beta_default::<Self>(
374 log_probs,
375 targets,
376 input_lengths,
377 grad_loss,
378 log_alpha_full,
379 log_beta_full,
380 nll,
381 blank,
382 )
383 }
384
385 fn rfft(
386 signal: FloatTensor<Self>,
387 dim: usize,
388 n: Option<usize>,
389 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
390 kernel::fft::rfft(signal, dim, n)
391 }
392
393 fn irfft(
394 spectrum_re: FloatTensor<Self>,
395 spectrum_im: FloatTensor<Self>,
396 dim: usize,
397 n: Option<usize>,
398 ) -> FloatTensor<Self> {
399 kernel::fft::irfft(spectrum_re, spectrum_im, dim, n)
400 }
401}