// burn_cubecl/ops/module.rs

1use crate::{
2    CubeBackend, CubeRuntime, FloatElement, IntElement,
3    element::BoolElement,
4    kernel::{self, conv::ConvTranspose2dStrategy},
5};
6use burn_backend::ops::{
7    ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
8    MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
9};
10use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};
11
12impl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>
13where
14    R: CubeRuntime,
15    F: FloatElement,
16    I: IntElement,
17    BT: BoolElement,
18{
19    fn conv1d(
20        x: FloatTensor<Self>,
21        weight: FloatTensor<Self>,
22        bias: Option<FloatTensor<Self>>,
23        options: ConvOptions<1>,
24    ) -> FloatTensor<Self> {
25        kernel::conv::conv_forward::<R, 1>(x, weight, bias, options, Default::default()).unwrap()
26    }
27
28    fn conv1d_x_backward(
29        x: FloatTensor<Self>,
30        weight: FloatTensor<Self>,
31        output_grad: FloatTensor<Self>,
32        options: ConvOptions<1>,
33    ) -> FloatTensor<Self> {
34        kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default())
35            .unwrap()
36    }
37
38    fn conv1d_weight_backward(
39        x: FloatTensor<Self>,
40        weight: FloatTensor<Self>,
41        output_grad: FloatTensor<Self>,
42        options: ConvOptions<1>,
43    ) -> FloatTensor<Self> {
44        kernel::conv::conv_weight_backward::<R, 1>(
45            x,
46            output_grad,
47            weight.shape.clone(),
48            options,
49            Default::default(),
50        )
51        .unwrap()
52    }
53
54    fn conv2d(
55        x: FloatTensor<Self>,
56        weight: FloatTensor<Self>,
57        bias: Option<FloatTensor<Self>>,
58        options: ConvOptions<2>,
59    ) -> FloatTensor<Self> {
60        kernel::conv::conv_forward::<R, 2>(x, weight, bias, options, Default::default()).unwrap()
61    }
62
63    fn conv2d_x_backward(
64        x: FloatTensor<Self>,
65        weight: FloatTensor<Self>,
66        output_grad: FloatTensor<Self>,
67        options: ConvOptions<2>,
68    ) -> FloatTensor<Self> {
69        kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default())
70            .unwrap()
71    }
72
73    fn conv2d_weight_backward(
74        x: FloatTensor<Self>,
75        weight: FloatTensor<Self>,
76        output_grad: FloatTensor<Self>,
77        options: ConvOptions<2>,
78    ) -> FloatTensor<Self> {
79        kernel::conv::conv_weight_backward::<R, 2>(
80            x,
81            output_grad,
82            weight.shape.clone(),
83            options,
84            Default::default(),
85        )
86        .unwrap()
87    }
88
89    fn deform_conv2d(
90        x: FloatTensor<Self>,
91        offset: FloatTensor<Self>,
92        weight: FloatTensor<Self>,
93        mask: Option<FloatTensor<Self>>,
94        bias: Option<FloatTensor<Self>>,
95        options: DeformConvOptions<2>,
96    ) -> FloatTensor<Self> {
97        kernel::conv::deform_conv2d(x, offset, weight, mask, bias, options).unwrap()
98    }
99
100    fn deform_conv2d_backward(
101        x: FloatTensor<Self>,
102        offset: FloatTensor<Self>,
103        weight: FloatTensor<Self>,
104        mask: Option<FloatTensor<Self>>,
105        bias: Option<FloatTensor<Self>>,
106        output_grad: FloatTensor<Self>,
107        options: DeformConvOptions<2>,
108    ) -> DeformConv2dBackward<Self> {
109        let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward(
110            x,
111            offset,
112            weight,
113            mask,
114            bias,
115            output_grad,
116            options,
117        )
118        .unwrap();
119        DeformConv2dBackward::new(x, o, w, m, b)
120    }
121
122    fn conv3d(
123        x: FloatTensor<Self>,
124        weight: FloatTensor<Self>,
125        bias: Option<FloatTensor<Self>>,
126        options: ConvOptions<3>,
127    ) -> FloatTensor<Self> {
128        kernel::conv::conv_forward::<R, 3>(x, weight, bias, options, Default::default()).unwrap()
129    }
130
131    fn conv3d_x_backward(
132        x: FloatTensor<Self>,
133        weight: FloatTensor<Self>,
134        output_grad: FloatTensor<Self>,
135        options: ConvOptions<3>,
136    ) -> FloatTensor<Self> {
137        kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default())
138            .unwrap()
139    }
140
141    fn conv3d_weight_backward(
142        x: FloatTensor<Self>,
143        weight: FloatTensor<Self>,
144        output_grad: FloatTensor<Self>,
145        options: ConvOptions<3>,
146    ) -> FloatTensor<Self> {
147        kernel::conv::conv_weight_backward::<R, 3>(
148            x,
149            output_grad,
150            weight.shape.clone(),
151            options,
152            Default::default(),
153        )
154        .unwrap()
155    }
156
157    fn conv_transpose2d(
158        x: FloatTensor<Self>,
159        weight: FloatTensor<Self>,
160        bias: Option<FloatTensor<Self>>,
161        options: ConvTransposeOptions<2>,
162    ) -> FloatTensor<Self> {
163        kernel::conv::conv_transpose2d(x, weight, bias, options, ConvTranspose2dStrategy::default())
164            .unwrap()
165    }
166
167    fn conv_transpose3d(
168        x: FloatTensor<Self>,
169        weight: FloatTensor<Self>,
170        bias: Option<FloatTensor<Self>>,
171        options: ConvTransposeOptions<3>,
172    ) -> FloatTensor<Self> {
173        kernel::conv::conv_transpose3d(x, weight, bias, options).expect("Kernel to never fail")
174    }
175
176    fn avg_pool2d(
177        x: FloatTensor<Self>,
178        kernel_size: [usize; 2],
179        stride: [usize; 2],
180        padding: [usize; 2],
181        count_include_pad: bool,
182        ceil_mode: bool,
183    ) -> FloatTensor<Self> {
184        kernel::pool::avg_pool2d(
185            x,
186            kernel_size,
187            stride,
188            padding,
189            count_include_pad,
190            ceil_mode,
191        )
192    }
193
194    fn avg_pool2d_backward(
195        x: FloatTensor<Self>,
196        grad: FloatTensor<Self>,
197        kernel_size: [usize; 2],
198        stride: [usize; 2],
199        padding: [usize; 2],
200        count_include_pad: bool,
201        ceil_mode: bool,
202    ) -> FloatTensor<Self> {
203        kernel::pool::avg_pool2d_backward(
204            x,
205            grad,
206            kernel_size,
207            stride,
208            padding,
209            count_include_pad,
210            ceil_mode,
211        )
212    }
213
214    fn max_pool2d(
215        x: FloatTensor<Self>,
216        kernel_size: [usize; 2],
217        stride: [usize; 2],
218        padding: [usize; 2],
219        dilation: [usize; 2],
220        ceil_mode: bool,
221    ) -> FloatTensor<Self> {
222        kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)
223    }
224
225    fn max_pool2d_with_indices(
226        x: FloatTensor<Self>,
227        kernel_size: [usize; 2],
228        stride: [usize; 2],
229        padding: [usize; 2],
230        dilation: [usize; 2],
231        ceil_mode: bool,
232    ) -> MaxPool2dWithIndices<Self> {
233        let (output, indices) = kernel::pool::max_pool2d_with_indices(
234            x,
235            kernel_size,
236            stride,
237            padding,
238            dilation,
239            ceil_mode,
240            I::dtype(),
241        );
242
243        MaxPool2dWithIndices::new(output, indices)
244    }
245
246    fn max_pool2d_with_indices_backward(
247        x: FloatTensor<Self>,
248        kernel_size: [usize; 2],
249        stride: [usize; 2],
250        padding: [usize; 2],
251        dilation: [usize; 2],
252        ceil_mode: bool,
253        output_grad: FloatTensor<Self>,
254        indices: IntTensor<Self>,
255    ) -> MaxPool2dBackward<Self> {
256        MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward(
257            x,
258            output_grad,
259            indices,
260            kernel_size,
261            stride,
262            padding,
263            dilation,
264            ceil_mode,
265        ))
266    }
267
268    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
269        kernel::pool::adaptive_avg_pool2d(x, output_size)
270    }
271
272    fn adaptive_avg_pool2d_backward(
273        x: FloatTensor<Self>,
274        grad: FloatTensor<Self>,
275    ) -> FloatTensor<Self> {
276        kernel::pool::adaptive_avg_pool2d_backward(x, grad)
277    }
278
279    fn interpolate(
280        x: FloatTensor<Self>,
281        output_size: [usize; 2],
282        options: InterpolateOptions,
283    ) -> FloatTensor<Self> {
284        kernel::interpolate::interpolate(x, output_size, options)
285    }
286
287    fn interpolate_backward(
288        x: FloatTensor<Self>,
289        grad: FloatTensor<Self>,
290        output_size: [usize; 2],
291        options: InterpolateOptions,
292    ) -> FloatTensor<Self> {
293        kernel::interpolate::interpolate_backward(x, grad, output_size, options)
294    }
295
296    fn attention(
297        query: FloatTensor<Self>,
298        key: FloatTensor<Self>,
299        value: FloatTensor<Self>,
300        mask: Option<BoolTensor<Self>>,
301    ) -> FloatTensor<Self> {
302        let out_dtype = query.dtype;
303        kernel::attention::flash_attention(query, key, value, mask, out_dtype)
304            .expect("Kernel to never fail")
305    }
306}