use burn_backend::{
    Shape,
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,
        InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
        UnfoldOptions,
    },
    tensor::{FloatTensor, IntTensor},
};
use candle_core::ToUsize2;

use crate::{
    Candle, CandleTensor,
    element::{CandleElement, FloatCandleElement, IntCandleElement},
    ops::base::reshape,
};

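// Candle-backed implementations of Burn's `ModuleOps`. Candle exposes a narrower
// option surface than Burn's module API, so several methods assert away the
// unsupported options (per-dimension strides, pooling padding, ...) and the
// operations Candle lacks entirely panic with an explanatory message.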
impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {
    fn conv1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<Self> {
        let conv = x
            .tensor
            .conv1d(
                &weight.tensor,
                options.padding[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
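        // The bias has shape [channels_out]; unsqueeze it to [channels_out, 1] so it
        // broadcasts over the batch and length dimensions of the
        // [batch, channels_out, length] output.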
        CandleTensor::new(match bias {
            Some(bias) => conv
                .broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
                .unwrap(),
            None => conv,
        })
    }

    fn conv2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<Self> {
        assert!(
            options.dilation[0] == options.dilation[1]
                && options.padding[0] == options.padding[1]
                && options.stride[0] == options.stride[1],
            "Candle does not support per-dimension options in convolutions"
        );
        let conv = x
            .tensor
            .conv2d(
                &weight.tensor,
                options.padding[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
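        // Reshape the bias from [channels_out] to [1, channels_out, 1, 1] so it
        // broadcasts over the NCHW output.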
        CandleTensor::new(match bias {
            Some(bias) => conv
                .broadcast_add(
                    &bias
                        .tensor
                        .unsqueeze(0)
                        .unwrap()
                        .unsqueeze(2)
                        .unwrap()
                        .unsqueeze(3)
                        .unwrap(),
                )
                .unwrap(),
            None => conv,
        })
    }

    fn deform_conv2d(
        _x: FloatTensor<Self>,
        _offset: FloatTensor<Self>,
        _weight: FloatTensor<Self>,
        _mask: Option<FloatTensor<Self>>,
        _bias: Option<FloatTensor<Self>>,
        _options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        unimplemented!("Candle does not support deformable convolutions")
    }

    fn deform_conv2d_backward(
        _x: FloatTensor<Self>,
        _offset: FloatTensor<Self>,
        _weight: FloatTensor<Self>,
        _mask: Option<FloatTensor<Self>>,
        _bias: Option<FloatTensor<Self>>,
        _output_grad: FloatTensor<Self>,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Candle does not support deformable convolutions")
    }

    fn conv3d(
        _x: FloatTensor<Self>,
        _weight: FloatTensor<Self>,
        _bias: Option<FloatTensor<Self>>,
        _options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        panic!("Candle does not support 3D convolutions");
    }

    fn conv_transpose1d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<Self> {
        let conv_transpose = x
            .tensor
            .conv_transpose1d(
                &weight.tensor,
                options.padding[0],
                options.padding_out[0],
                options.stride[0],
                options.dilation[0],
                options.groups,
            )
            .unwrap();
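        // Reshape the bias from [channels_out] to [1, channels_out, 1] so it
        // broadcasts over the [batch, channels_out, length] output.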
        CandleTensor::new(match bias {
            Some(bias) => conv_transpose
                .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
                .unwrap(),
            None => conv_transpose,
        })
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        assert!(
            options.dilation[0] == options.dilation[1]
                && options.padding[0] == options.padding[1]
                && options.padding_out[0] == options.padding_out[1]
                && options.stride[0] == options.stride[1],
            "Candle does not support per-dimension options in transposed convolutions"
        );
        assert!(
            options.groups == 1,
            "Candle does not support groups in transposed convolutions"
        );
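        // Unlike conv2d above, Candle's conv_transpose2d takes no groups argument,
        // which is why groups == 1 is asserted first.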
        let conv_transpose = x
            .tensor
            .conv_transpose2d(
                &weight.tensor,
                options.padding[0],
                options.padding_out[0],
                options.stride[0],
                options.dilation[0],
            )
            .unwrap();
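        // Reshape the bias from [channels_out] to [1, channels_out, 1, 1] so it
        // broadcasts over the NCHW output.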
        CandleTensor::new(match bias {
            Some(bias) => conv_transpose
                .broadcast_add(
                    &bias
                        .tensor
                        .unsqueeze(0)
                        .unwrap()
                        .unsqueeze(2)
                        .unwrap()
                        .unsqueeze(3)
                        .unwrap(),
                )
                .unwrap(),
            None => conv_transpose,
        })
    }

    fn conv_transpose3d(
        _x: FloatTensor<Self>,
        _weight: FloatTensor<Self>,
        _bias: Option<FloatTensor<Self>>,
        _options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        panic!("Candle does not support 3D transposed convolutions");
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        assert!(
            padding[0] == 0 && padding[1] == 0,
            "Candle does not support padding in pooling"
        );
        assert!(
            count_include_pad,
            "Candle does not support excluding pad count in pooling"
        );
        assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
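        // With padding, pad exclusion, and ceil_mode ruled out above, the operation
        // maps directly onto Candle's avg_pool2d_with_stride.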
        CandleTensor::new(
            x.tensor
                .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
                .unwrap(),
        )
    }

    fn avg_pool2d_backward(
        _x: FloatTensor<Self>,
        _grad: FloatTensor<Self>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _count_include_pad: bool,
        _ceil_mode: bool,
    ) -> FloatTensor<Self> {
        panic!("avg_pool2d_backward is not supported by Candle")
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        assert!(
            padding[0] == 0 && padding[1] == 0,
            "Candle does not support padding in pooling"
        );
        assert!(
            dilation[0] == 1 && dilation[1] == 1,
            "Candle does not support dilation in pooling"
        );
        assert!(!ceil_mode, "Candle does not support ceil_mode in pooling");
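        // As with avg_pool2d, only the unpadded, undilated case reaches Candle's
        // max_pool2d_with_stride.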
        CandleTensor::new(
            x.tensor
                .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
                .unwrap(),
        )
    }

    fn max_pool2d_with_indices(
        _x: FloatTensor<Self>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _dilation: [usize; 2],
        _ceil_mode: bool,
    ) -> MaxPool2dWithIndices<Candle<F, I>> {
        panic!("max_pool2d_with_indices is not supported by Candle")
    }

    fn max_pool2d_with_indices_backward(
        _x: FloatTensor<Self>,
        _kernel_size: [usize; 2],
        _stride: [usize; 2],
        _padding: [usize; 2],
        _dilation: [usize; 2],
        _ceil_mode: bool,
        _output_grad: FloatTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> MaxPool2dBackward<Candle<F, I>> {
        panic!("max_pool2d_with_indices_backward is not supported by Candle")
    }

    fn adaptive_avg_pool2d(_x: FloatTensor<Self>, _output_size: [usize; 2]) -> FloatTensor<Self> {
        panic!("adaptive_avg_pool2d is not supported by Candle")
    }

    fn adaptive_avg_pool2d_backward(
        _x: FloatTensor<Self>,
        _grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        panic!("adaptive_avg_pool2d_backward is not supported by Candle")
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
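        // Only nearest-neighbor 2D upsampling is available in Candle; the bilinear
        // and bicubic modes therefore panic.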
        let tensor = match options.mode {
            InterpolateMode::Nearest => x
                .tensor
                .upsample_nearest2d(output_size[0], output_size[1])
                .unwrap(),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation is not supported by Candle")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation is not supported by Candle")
            }
        };

        CandleTensor::new(tensor)
    }

    fn interpolate_backward(
        _x: FloatTensor<Self>,
        _grad: FloatTensor<Self>,
        _output_size: [usize; 2],
        _options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        panic!("interpolate_backward is not supported by Candle")
    }
}