// burn_candle/ops/module.rs

use burn_tensor::{
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
        IntTensor, InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
        ModuleOps, UnfoldOptions,
    },
    Shape,
};
use candle_core::ToUsize2;

use crate::{
    element::{CandleElement, FloatCandleElement, IntCandleElement},
    ops::base::reshape,
    Candle, CandleTensor,
};

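// Neural-network module operations for the Candle backend, implemented by
// delegating to `candle_core` kernels. Candle exposes a narrower surface than
// burn's `ModuleOps` trait, so the operations it lacks (deformable and 3D
// convolutions, pooling backward passes, adaptive pooling, non-nearest
// interpolation) panic instead of computing a result.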
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {
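    // Candle's `conv1d` takes scalar padding/stride/dilation, so the 1-D
    // options map directly. The 1-D bias `[C_out]` is unsqueezed to
    // `[C_out, 1]` so it broadcasts over the `(batch, C_out, length)` output.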
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        let conv = x
            .tensor
            .conv1d(
                &weight.tensor,
                options.padding[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv
                .broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
                .unwrap(),
            None => conv,
        })
    }

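    // Candle's `conv2d` only accepts a single scalar per option, hence the
    // symmetry assert below. The bias is reshaped to `[1, C_out, 1, 1]` via
    // three unsqueezes so it broadcasts over the NCHW output.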
    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        assert!(
            options.dilation[0] == options.dilation[1]
                && options.padding[0] == options.padding[1]
                && options.stride[0] == options.stride[1],
            "Candle does not support per-dimension options in convolutions"
        );
        let conv = x
            .tensor
            .conv2d(
                &weight.tensor,
                options.padding[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv
                .broadcast_add(
                    &bias
                        .tensor
                        .unsqueeze(0)
                        .unwrap()
                        .unsqueeze(2)
                        .unwrap()
                        .unsqueeze(3)
                        .unwrap(),
                )
                .unwrap(),
            None => conv,
        })
    }

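    // Deformable convolution has no counterpart in candle_core, so both the
    // forward and backward passes are unimplemented.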
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        unimplemented!("Candle does not support deformable convolutions")
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Candle does not support deformable convolutions")
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        panic!("Candle does not support 3D convolutions");
    }

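    // Transposed 1-D convolution; burn's `padding_out` maps to candle's
    // output padding. The bias is reshaped to `[1, C_out, 1]` so it
    // broadcasts over the `(batch, C_out, length)` output.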
    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        let conv_transpose = x
            .tensor
            .conv_transpose1d(
                &weight.tensor,
                options.padding[0],
                options.padding_out[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv_transpose
                .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
                .unwrap(),
            None => conv_transpose,
        })
    }

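    // Candle's `conv_transpose2d` takes scalar options and has no `groups`
    // parameter, hence the two asserts below.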
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        assert!(
            options.dilation[0] == options.dilation[1]
                && options.padding[0] == options.padding[1]
                && options.padding_out[0] == options.padding_out[1]
                && options.stride[0] == options.stride[1],
            "Candle does not support per-dimension options in transposed convolutions"
        );
        assert!(
            options.groups == 1,
            "Candle does not support groups in transposed convolutions"
        );
        let conv_transpose = x
            .tensor
            .conv_transpose2d(
                &weight.tensor,
                options.padding[0],
                options.padding_out[0],
                options.stride[0],
                options.dilation[0],
            )
            .unwrap();
        CandleTensor::new(match bias {
            Some(bias) => conv_transpose
                .broadcast_add(
                    &bias
                        .tensor
                        .unsqueeze(0)
                        .unwrap()
                        .unsqueeze(2)
                        .unwrap()
                        .unsqueeze(3)
                        .unwrap(),
                )
                .unwrap(),
            None => conv_transpose,
        })
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        panic!("Candle does not support 3D transposed convolutions");
    }

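    // Candle's `avg_pool2d_with_stride` supports neither padding nor
    // excluding padded elements from the average, so anything other than
    // zero padding with `count_include_pad == true` is rejected up front.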
    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        assert!(
            padding[0] == 0 && padding[1] == 0,
            "Candle does not support padding in pooling"
        );
        assert!(
            count_include_pad,
            "Candle does not support excluding pad count in pooling"
        );
        CandleTensor::new(
            x.tensor
                .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
                .unwrap(),
        )
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        panic!("avg_pool2d_backward is not supported by Candle")
    }

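    // Like average pooling, candle's `max_pool2d_with_stride` only handles
    // the unpadded, undilated case.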
    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<Self> {
        assert!(
            padding[0] == 0 && padding[1] == 0,
            "Candle does not support padding in pooling"
        );
        assert!(
            dilation[0] == 1 && dilation[1] == 1,
            "Candle does not support dilation in pooling"
        );
        CandleTensor::new(
            x.tensor
                .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
                .unwrap(),
        )
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<Candle<F, I>> {
        panic!("max_pool2d_with_indices is not supported by Candle")
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<Self>,
        indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Candle<F, I>> {
        panic!("max_pool2d_with_indices_backward is not supported by Candle")
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        panic!("adaptive_avg_pool2d is not supported by Candle")
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        panic!("adaptive_avg_pool2d_backward is not supported by Candle")
    }

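    // Candle only provides nearest-neighbor upsampling (`upsample_nearest2d`);
    // the bilinear and bicubic modes have no corresponding kernel.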
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        let tensor = match options.mode {
            InterpolateMode::Nearest => x
                .tensor
                .upsample_nearest2d(output_size[0], output_size[1])
                .unwrap(),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation is not supported by Candle")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation is not supported by Candle")
            }
        };

        CandleTensor::new(tensor)
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        panic!("interpolate_backward is not supported by Candle")
    }
}