burn_cubecl/ops/module_ops.rs

1use crate::{
2    CubeBackend, CubeRuntime, FloatElement, IntElement,
3    element::BoolElement,
4    execute_with_dtype,
5    kernel::{
6        self,
7        conv::{ConvStrategy, ConvTranspose2dStrategy},
8    },
9};
10use burn_tensor::ops::{
11    ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
12    MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
13};
14use burn_tensor::ops::{FloatTensor, IntTensor};
15
16impl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>
17where
18    R: CubeRuntime,
19    F: FloatElement,
20    I: IntElement,
21    BT: BoolElement,
22{
23    fn conv1d(
24        x: FloatTensor<Self>,
25        weight: FloatTensor<Self>,
26        bias: Option<FloatTensor<Self>>,
27        options: ConvOptions<1>,
28    ) -> FloatTensor<Self> {
29        execute_with_dtype!(
30            float(x.dtype),
31            E,
32            kernel::conv::conv::<R, E, 1>(x, weight, bias, options, ConvStrategy::default())
33                .unwrap()
34        )
35    }
36
37    fn conv2d(
38        x: FloatTensor<Self>,
39        weight: FloatTensor<Self>,
40        bias: Option<FloatTensor<Self>>,
41        options: ConvOptions<2>,
42    ) -> FloatTensor<Self> {
43        execute_with_dtype!(
44            float(x.dtype),
45            E,
46            kernel::conv::conv::<R, E, 2>(x, weight, bias, options, ConvStrategy::default())
47                .unwrap()
48        )
49    }
50
51    fn deform_conv2d(
52        x: FloatTensor<Self>,
53        offset: FloatTensor<Self>,
54        weight: FloatTensor<Self>,
55        mask: Option<FloatTensor<Self>>,
56        bias: Option<FloatTensor<Self>>,
57        options: DeformConvOptions<2>,
58    ) -> FloatTensor<Self> {
59        execute_with_dtype!(
60            float(x.dtype),
61            E,
62            kernel::conv::deform_conv2d::<R, E>(x, offset, weight, mask, bias, options).unwrap()
63        )
64    }
65
66    fn deform_conv2d_backward(
67        x: FloatTensor<Self>,
68        offset: FloatTensor<Self>,
69        weight: FloatTensor<Self>,
70        mask: Option<FloatTensor<Self>>,
71        bias: Option<FloatTensor<Self>>,
72        output_grad: FloatTensor<Self>,
73        options: DeformConvOptions<2>,
74    ) -> DeformConv2dBackward<Self> {
75        execute_with_dtype!(float(x.dtype), E, {
76            let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward::<R, E, I, BT>(
77                x,
78                offset,
79                weight,
80                mask,
81                bias,
82                output_grad,
83                options,
84            )
85            .unwrap();
86            DeformConv2dBackward::new(x, o, w, m, b)
87        })
88    }
89
90    fn conv3d(
91        x: FloatTensor<Self>,
92        weight: FloatTensor<Self>,
93        bias: Option<FloatTensor<Self>>,
94        options: ConvOptions<3>,
95    ) -> FloatTensor<Self> {
96        execute_with_dtype!(
97            float(x.dtype),
98            E,
99            kernel::conv::conv::<R, E, 3>(x, weight, bias, options, ConvStrategy::Direct).unwrap()
100        )
101    }
102
103    fn conv_transpose2d(
104        x: FloatTensor<Self>,
105        weight: FloatTensor<Self>,
106        bias: Option<FloatTensor<Self>>,
107        options: ConvTransposeOptions<2>,
108    ) -> FloatTensor<Self> {
109        execute_with_dtype!(
110            float(x.dtype),
111            E,
112            kernel::conv::conv_transpose2d::<R, E, I>(
113                x,
114                weight,
115                bias,
116                options,
117                ConvTranspose2dStrategy::default(),
118            )
119            .unwrap()
120        )
121    }
122
123    fn conv_transpose3d(
124        x: FloatTensor<Self>,
125        weight: FloatTensor<Self>,
126        bias: Option<FloatTensor<Self>>,
127        options: ConvTransposeOptions<3>,
128    ) -> FloatTensor<Self> {
129        execute_with_dtype!(
130            float(x.dtype),
131            E,
132            kernel::conv::conv_transpose3d::<R, E>(x, weight, bias, options)
133        )
134    }
135
136    fn avg_pool2d(
137        x: FloatTensor<Self>,
138        kernel_size: [usize; 2],
139        stride: [usize; 2],
140        padding: [usize; 2],
141        count_include_pad: bool,
142    ) -> FloatTensor<Self> {
143        execute_with_dtype!(
144            float(x.dtype),
145            E,
146            kernel::pool::avg_pool2d::<R, E>(x, kernel_size, stride, padding, count_include_pad)
147        )
148    }
149
150    fn avg_pool2d_backward(
151        x: FloatTensor<Self>,
152        grad: FloatTensor<Self>,
153        kernel_size: [usize; 2],
154        stride: [usize; 2],
155        padding: [usize; 2],
156        count_include_pad: bool,
157    ) -> FloatTensor<Self> {
158        execute_with_dtype!(
159            float(x.dtype),
160            E,
161            kernel::pool::avg_pool2d_backward::<R, E>(
162                x,
163                grad,
164                kernel_size,
165                stride,
166                padding,
167                count_include_pad,
168            )
169        )
170    }
171
172    fn max_pool2d(
173        x: FloatTensor<Self>,
174        kernel_size: [usize; 2],
175        stride: [usize; 2],
176        padding: [usize; 2],
177        dilation: [usize; 2],
178    ) -> FloatTensor<Self> {
179        execute_with_dtype!(
180            float(x.dtype),
181            E,
182            kernel::pool::max_pool2d::<R, E>(x, kernel_size, stride, padding, dilation)
183        )
184    }
185
186    fn max_pool2d_with_indices(
187        x: FloatTensor<Self>,
188        kernel_size: [usize; 2],
189        stride: [usize; 2],
190        padding: [usize; 2],
191        dilation: [usize; 2],
192    ) -> MaxPool2dWithIndices<Self> {
193        execute_with_dtype!(float(x.dtype), E, {
194            let (output, indices) = kernel::pool::max_pool2d_with_indices::<R, E, I>(
195                x,
196                kernel_size,
197                stride,
198                padding,
199                dilation,
200            );
201
202            MaxPool2dWithIndices::new(output, indices)
203        })
204    }
205
206    fn max_pool2d_with_indices_backward(
207        x: FloatTensor<Self>,
208        kernel_size: [usize; 2],
209        stride: [usize; 2],
210        padding: [usize; 2],
211        dilation: [usize; 2],
212        output_grad: FloatTensor<Self>,
213        indices: IntTensor<Self>,
214    ) -> MaxPool2dBackward<Self> {
215        execute_with_dtype!(
216            int(indices.dtype),
217            I,
218            execute_with_dtype!(
219                float(x.dtype),
220                E,
221                MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::<R, E, I>(
222                    x,
223                    output_grad,
224                    indices,
225                    kernel_size,
226                    stride,
227                    padding,
228                    dilation,
229                ))
230            )
231        )
232    }
233
234    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
235        execute_with_dtype!(
236            float(x.dtype),
237            E,
238            kernel::pool::adaptive_avg_pool2d::<R, E>(x, output_size)
239        )
240    }
241
242    fn adaptive_avg_pool2d_backward(
243        x: FloatTensor<Self>,
244        grad: FloatTensor<Self>,
245    ) -> FloatTensor<Self> {
246        execute_with_dtype!(
247            float(x.dtype),
248            E,
249            kernel::pool::adaptive_avg_pool2d_backward::<R, E>(x, grad)
250        )
251    }
252
253    fn interpolate(
254        x: FloatTensor<Self>,
255        output_size: [usize; 2],
256        options: InterpolateOptions,
257    ) -> FloatTensor<Self> {
258        execute_with_dtype!(
259            float(x.dtype),
260            E,
261            kernel::interpolate::interpolate::<R, E>(x, output_size, options)
262        )
263    }
264
265    fn interpolate_backward(
266        x: FloatTensor<Self>,
267        grad: FloatTensor<Self>,
268        output_size: [usize; 2],
269        options: InterpolateOptions,
270    ) -> FloatTensor<Self> {
271        execute_with_dtype!(
272            float(x.dtype),
273            E,
274            kernel::interpolate::interpolate_backward::<R, E>(x, grad, output_size, options)
275        )
276    }
277}