// burn_cubecl/ops/module.rs

use crate::{
    CubeBackend, CubeRuntime, FloatElement, IntElement,
    element::BoolElement,
    kernel::{self, conv::ConvTranspose2dStrategy},
};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
use burn_backend::{
    TensorMetadata,
    ops::{
        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,
        DeformConvOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
    },
};

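/// [`ModuleOps`] implementation for the CubeCL backend, dispatching module
/// operations (convolution, pooling, interpolation, attention, CTC loss, FFT)
/// to the corresponding GPU kernels in [`kernel`].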
impl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
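    /// Forward 1D convolution, delegating to the generic `conv_forward` kernel
    /// with the default algorithm-selection strategy.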
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_forward::<R, 1>(x, weight, bias, options, Default::default()).unwrap()
    }

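    /// Gradient of the 1D convolution with respect to the input, computed from
    /// the output gradient and the weights and shaped like `x`.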
    fn conv1d_x_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_data_backward(
            output_grad,
            weight,
            x.shape(),
            options,
            Default::default(),
        )
        .unwrap()
    }

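    /// Gradient of the 1D convolution with respect to the weights, shaped like
    /// `weight`.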
    fn conv1d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_weight_backward::<R, 1>(
            x,
            output_grad,
            weight.shape(),
            options,
            Default::default(),
        )
        .unwrap()
    }

    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_forward::<R, 2>(x, weight, bias, options, Default::default()).unwrap()
    }

    fn conv2d_x_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_data_backward(
            output_grad,
            weight,
            x.shape(),
            options,
            Default::default(),
        )
        .unwrap()
    }

    fn conv2d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_weight_backward::<R, 2>(
            x,
            output_grad,
            weight.shape(),
            options,
            Default::default(),
        )
        .unwrap()
    }

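    /// Deformable 2D convolution: sampling positions are displaced by the
    /// learned `offset` tensor and optionally modulated by `mask`.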
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        kernel::conv::deform_conv2d(x, offset, weight, mask, bias, options).unwrap()
    }

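    /// Backward pass of the deformable convolution; the kernel returns the
    /// gradients for input, offset, weight, mask, and bias in that order.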
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward(
            x,
            offset,
            weight,
            mask,
            bias,
            output_grad,
            options,
        )
        .unwrap();
        DeformConv2dBackward::new(x, o, w, m, b)
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_forward::<R, 3>(x, weight, bias, options, Default::default()).unwrap()
    }

    fn conv3d_x_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_data_backward(
            output_grad,
            weight,
            x.shape(),
            options,
            Default::default(),
        )
        .unwrap()
    }

    fn conv3d_weight_backward(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        output_grad: FloatTensor<Self>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_weight_backward::<R, 3>(
            x,
            output_grad,
            weight.shape(),
            options,
            Default::default(),
        )
        .unwrap()
    }

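    /// Transposed 2D convolution, run with the default
    /// [`ConvTranspose2dStrategy`].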
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_transpose2d(x, weight, bias, options, ConvTranspose2dStrategy::default())
            .unwrap()
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        kernel::conv::conv_transpose3d(x, weight, bias, options).expect("Kernel to never fail")
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        kernel::pool::avg_pool2d(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        kernel::pool::avg_pool2d_backward(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
    }

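    /// Max pooling that also returns the index of each maximum, materialized
    /// with the backend's integer element type (`I::dtype()`).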
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Self> {
        let (output, indices) = kernel::pool::max_pool2d_with_indices(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            I::dtype(),
        );

        MaxPool2dWithIndices::new(output, indices)
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Self> {
        MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward(
            x,
            output_grad,
            indices,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        ))
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        kernel::pool::adaptive_avg_pool2d(x, output_size)
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        kernel::pool::adaptive_avg_pool2d_backward(x, grad)
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        kernel::interpolate::interpolate(x, output_size, options)
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        kernel::interpolate::interpolate_backward(x, grad, output_size, options)
    }

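    /// Scaled dot-product attention. Uses the flash kernel when possible and a
    /// naive fallback for options the kernel does not support.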
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        // Fall back to naive attention for features the flash kernel doesn't support.
        if attn_bias.is_some() || options.softcap.is_some() || options.scale.is_some() {
            return burn_backend::ops::attention::attention_fallback::<Self>(
                query, key, value, mask, attn_bias, options,
            );
        }

        kernel::attention::attention(
            query,
            key,
            value,
            mask,
            attn_bias,
            options,
            Default::default(),
        )
        .expect("Kernel to never fail")
    }

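    /// Advertises that this backend implements `ctc_loss_backward` directly.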
    fn has_ctc_loss_backward() -> bool {
        true
    }

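    /// Connectionist temporal classification (CTC) loss, evaluated on-device.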
    fn ctc_loss(
        log_probs: FloatTensor<Self>,
        targets: IntTensor<Self>,
        input_lengths: IntTensor<Self>,
        target_lengths: IntTensor<Self>,
        blank: usize,
    ) -> FloatTensor<Self> {
        kernel::ctc::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank)
    }

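    /// CTC backward pass: recomputes the forward/backward (alpha/beta) lattices
    /// on-device, then derives the input gradient via the shared default routine.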
    fn ctc_loss_backward(
        log_probs: FloatTensor<Self>,
        targets: IntTensor<Self>,
        input_lengths: IntTensor<Self>,
        target_lengths: IntTensor<Self>,
        grad_loss: FloatTensor<Self>,
        blank: usize,
    ) -> FloatTensor<Self> {
        let (log_alpha_full, log_beta_full, nll) = kernel::ctc::ctc_alpha_beta(
            log_probs.clone(),
            targets.clone(),
            input_lengths.clone(),
            target_lengths,
            blank,
        );
        burn_backend::ops::ctc::ctc_grad_from_alpha_beta_default::<Self>(
            log_probs,
            targets,
            input_lengths,
            grad_loss,
            log_alpha_full,
            log_beta_full,
            nll,
            blank,
        )
    }

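    /// Real-input FFT along `dim`, returning the real and imaginary spectrum
    /// parts as separate tensors.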
    fn rfft(
        signal: FloatTensor<Self>,
        dim: usize,
        n: Option<usize>,
    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
        kernel::fft::rfft(signal, dim, n)
    }

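    /// Inverse of [`Self::rfft`]: reconstructs a real signal from the real and
    /// imaginary spectrum parts.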
    fn irfft(
        spectrum_re: FloatTensor<Self>,
        spectrum_im: FloatTensor<Self>,
        dim: usize,
        n: Option<usize>,
    ) -> FloatTensor<Self> {
        kernel::fft::irfft(spectrum_re, spectrum_im, dim, n)
    }
}
401}