Skip to main content

burn_candle/ops/
module.rs

1use burn_backend::{
2    Shape,
3    ops::{
4        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
5        InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
6        UnfoldOptions,
7    },
8    tensor::{FloatTensor, IntTensor},
9};
10use candle_core::ToUsize2;
11
12use crate::{
13    Candle, CandleTensor,
14    element::{CandleElement, FloatCandleElement, IntCandleElement},
15    ops::base::reshape,
16};
17
18impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {
19    fn conv1d(
20        x: FloatTensor<Self>,
21        weight: FloatTensor<Self>,
22        bias: Option<FloatTensor<Self>>,
23        options: ConvOptions<1>,
24    ) -> FloatTensor<Self> {
25        let conv = x
26            .tensor
27            .conv1d(
28                &weight.tensor,
29                options.padding[0],
30                options.stride[0],
31                options.dilation[0],
32                options.groups,
33            )
34            .unwrap();
35        CandleTensor::new(match bias {
36            Some(bias) => conv
37                .broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
38                .unwrap(),
39            None => conv,
40        })
41    }
42
43    fn conv2d(
44        x: FloatTensor<Self>,
45        weight: FloatTensor<Self>,
46        bias: Option<FloatTensor<Self>>,
47        options: ConvOptions<2>,
48    ) -> FloatTensor<Self> {
49        assert!(
50            options.dilation[0] == options.dilation[1]
51                && options.padding[0] == options.padding[1]
52                && options.stride[0] == options.stride[1],
53            "Candle does not support per dimension options in convolutions"
54        );
55        let conv = x
56            .tensor
57            .conv2d(
58                &weight.tensor,
59                options.padding[0],
60                options.stride[0],
61                options.dilation[0],
62                options.groups,
63            )
64            .unwrap();
65        CandleTensor::new(match bias {
66            Some(bias) => conv
67                .broadcast_add(
68                    &bias
69                        .tensor
70                        .unsqueeze(0)
71                        .unwrap()
72                        .unsqueeze(2)
73                        .unwrap()
74                        .unsqueeze(3)
75                        .unwrap(),
76                )
77                .unwrap(),
78            None => conv,
79        })
80    }
81
82    fn deform_conv2d(
83        x: FloatTensor<Self>,
84        offset: FloatTensor<Self>,
85        weight: FloatTensor<Self>,
86        mask: Option<FloatTensor<Self>>,
87        bias: Option<FloatTensor<Self>>,
88        options: DeformConvOptions<2>,
89    ) -> FloatTensor<Self> {
90        unimplemented!("Candle does not support deformable convolutions")
91    }
92
93    fn deform_conv2d_backward(
94        x: FloatTensor<Self>,
95        offset: FloatTensor<Self>,
96        weight: FloatTensor<Self>,
97        mask: Option<FloatTensor<Self>>,
98        bias: Option<FloatTensor<Self>>,
99        output_grad: FloatTensor<Self>,
100        options: DeformConvOptions<2>,
101    ) -> DeformConv2dBackward<Self> {
102        unimplemented!("Candle does not support deformable convolutions")
103    }
104
105    fn conv3d(
106        x: FloatTensor<Self>,
107        weight: FloatTensor<Self>,
108        bias: Option<FloatTensor<Self>>,
109        options: ConvOptions<3>,
110    ) -> FloatTensor<Self> {
111        panic!("Candle does not support 3D convolutions");
112    }
113
114    fn conv_transpose1d(
115        x: FloatTensor<Self>,
116        weight: FloatTensor<Self>,
117        bias: Option<FloatTensor<Self>>,
118        options: ConvTransposeOptions<1>,
119    ) -> FloatTensor<Self> {
120        let conv_transpose = x
121            .tensor
122            .conv_transpose1d(
123                &weight.tensor,
124                options.padding[0],
125                options.padding_out[0],
126                options.stride[0],
127                options.dilation[0],
128                options.groups,
129            )
130            .unwrap();
131        CandleTensor::new(match bias {
132            Some(bias) => conv_transpose
133                .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
134                .unwrap(),
135            None => conv_transpose,
136        })
137    }
138
139    fn conv_transpose2d(
140        x: FloatTensor<Self>,
141        weight: FloatTensor<Self>,
142        bias: Option<FloatTensor<Self>>,
143        options: ConvTransposeOptions<2>,
144    ) -> FloatTensor<Self> {
145        assert!(
146            options.dilation[0] == options.dilation[1]
147                && options.padding[0] == options.padding[1]
148                && options.padding_out[0] == options.padding_out[1]
149                && options.stride[0] == options.stride[1],
150            "Candle does not support per dimension options in transposed convolutions"
151        );
152        assert!(
153            options.groups == 1,
154            "Candle does not support groups in transposed convolutions"
155        );
156        let conv_transpose = x
157            .tensor
158            .conv_transpose2d(
159                &weight.tensor,
160                options.padding[0],
161                options.padding_out[0],
162                options.stride[0],
163                options.dilation[0],
164            )
165            .unwrap();
166        CandleTensor::new(match bias {
167            Some(bias) => conv_transpose
168                .broadcast_add(
169                    &bias
170                        .tensor
171                        .unsqueeze(0)
172                        .unwrap()
173                        .unsqueeze(2)
174                        .unwrap()
175                        .unsqueeze(3)
176                        .unwrap(),
177                )
178                .unwrap(),
179            None => conv_transpose,
180        })
181    }
182
183    fn conv_transpose3d(
184        x: FloatTensor<Self>,
185        weight: FloatTensor<Self>,
186        bias: Option<FloatTensor<Self>>,
187        options: ConvTransposeOptions<3>,
188    ) -> FloatTensor<Self> {
189        panic!("Candle does not support 3D transposed convolutions");
190    }
191
192    fn avg_pool2d(
193        x: FloatTensor<Self>,
194        kernel_size: [usize; 2],
195        stride: [usize; 2],
196        padding: [usize; 2],
197        count_include_pad: bool,
198        ceil_mode: bool,
199    ) -> FloatTensor<Self> {
200        assert!(
201            padding[0] == 0 && padding[1] == 0,
202            "Candle does not support padding in pooling"
203        );
204        assert!(
205            count_include_pad,
206            "Candle does not support excluding pad count in pooling"
207        );
208        assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
209        CandleTensor::new(
210            x.tensor
211                .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
212                .unwrap(),
213        )
214    }
215
216    fn avg_pool2d_backward(
217        x: FloatTensor<Self>,
218        grad: FloatTensor<Self>,
219        kernel_size: [usize; 2],
220        stride: [usize; 2],
221        padding: [usize; 2],
222        count_include_pad: bool,
223        _ceil_mode: bool,
224    ) -> FloatTensor<Self> {
225        panic!("avg_pool2d_backward is not supported by Candle")
226    }
227
228    fn max_pool2d(
229        x: FloatTensor<Self>,
230        kernel_size: [usize; 2],
231        stride: [usize; 2],
232        padding: [usize; 2],
233        dilation: [usize; 2],
234        ceil_mode: bool,
235    ) -> FloatTensor<Self> {
236        assert!(
237            padding[0] == 0 && padding[1] == 0,
238            "Candle does not support padding in pooling"
239        );
240        assert!(
241            dilation[0] == 1 && dilation[1] == 1,
242            "Candle does not support dilation in pooling"
243        );
244        assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
245        CandleTensor::new(
246            x.tensor
247                .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
248                .unwrap(),
249        )
250    }
251
252    fn max_pool2d_with_indices(
253        x: FloatTensor<Self>,
254        kernel_size: [usize; 2],
255        stride: [usize; 2],
256        padding: [usize; 2],
257        dilation: [usize; 2],
258        _ceil_mode: bool,
259    ) -> MaxPool2dWithIndices<Candle<F, I>> {
260        panic!("max_pool2d_with_indices is not supported by Candle")
261    }
262
263    fn max_pool2d_with_indices_backward(
264        x: FloatTensor<Self>,
265        kernel_size: [usize; 2],
266        stride: [usize; 2],
267        padding: [usize; 2],
268        dilation: [usize; 2],
269        _ceil_mode: bool,
270        output_grad: FloatTensor<Self>,
271        indices: IntTensor<Self>,
272    ) -> MaxPool2dBackward<Candle<F, I>> {
273        panic!("max_pool2d_with_indices_backward is not supported by Candle")
274    }
275
276    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
277        panic!("adaptive_avg_pool2 is not supported by Candle")
278    }
279
280    fn adaptive_avg_pool2d_backward(
281        x: FloatTensor<Self>,
282        grad: FloatTensor<Self>,
283    ) -> FloatTensor<Self> {
284        panic!("adaptive_avg_pool2d_backward is not supported by Candle")
285    }
286
287    fn interpolate(
288        x: FloatTensor<Self>,
289        output_size: [usize; 2],
290        options: InterpolateOptions,
291    ) -> FloatTensor<Self> {
292        let tensor = match options.mode {
293            InterpolateMode::Nearest => x
294                .tensor
295                .upsample_nearest2d(output_size[0], output_size[1])
296                .unwrap(),
297            InterpolateMode::Bilinear => {
298                panic!("bilinear interpolation is not supported by Candle")
299            }
300            InterpolateMode::Bicubic => {
301                panic!("bicubic interpolation is not supported by Candle")
302            }
303        };
304
305        CandleTensor::new(tensor)
306    }
307
308    fn interpolate_backward(
309        x: FloatTensor<Self>,
310        grad: FloatTensor<Self>,
311        output_size: [usize; 2],
312        options: InterpolateOptions,
313    ) -> FloatTensor<Self> {
314        panic!("interpolate_backward is not supported by Candle")
315    }
316}