1use burn_tensor::{
2 ops::{
3 ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
4 IntTensor, InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
5 ModuleOps, UnfoldOptions,
6 },
7 Shape,
8};
9use candle_core::ToUsize2;
10
11use crate::{
12 element::{CandleElement, FloatCandleElement, IntCandleElement},
13 ops::base::reshape,
14 Candle, CandleTensor,
15};
16
17impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {
18 fn conv1d(
19 x: FloatTensor<Self>,
20 weight: FloatTensor<Self>,
21 bias: Option<FloatTensor<Self>>,
22 options: ConvOptions<1>,
23 ) -> FloatTensor<Self> {
24 let conv = x
25 .tensor
26 .conv1d(
27 &weight.tensor,
28 options.padding[0],
29 options.stride[0],
30 options.dilation[0],
31 options.groups,
32 )
33 .unwrap();
34 CandleTensor::new(match bias {
35 Some(bias) => conv
36 .broadcast_add(&bias.tensor.unsqueeze(1).unwrap())
37 .unwrap(),
38 None => conv,
39 })
40 }
41
42 fn conv2d(
43 x: FloatTensor<Self>,
44 weight: FloatTensor<Self>,
45 bias: Option<FloatTensor<Self>>,
46 options: ConvOptions<2>,
47 ) -> FloatTensor<Self> {
48 assert!(
49 options.dilation[0] == options.dilation[1]
50 && options.padding[0] == options.padding[1]
51 && options.stride[0] == options.stride[1],
52 "Candle does not support per dimension options in convolutions"
53 );
54 let conv = x
55 .tensor
56 .conv2d(
57 &weight.tensor,
58 options.padding[0],
59 options.stride[0],
60 options.dilation[0],
61 options.groups,
62 )
63 .unwrap();
64 CandleTensor::new(match bias {
65 Some(bias) => conv
66 .broadcast_add(
67 &bias
68 .tensor
69 .unsqueeze(0)
70 .unwrap()
71 .unsqueeze(2)
72 .unwrap()
73 .unsqueeze(3)
74 .unwrap(),
75 )
76 .unwrap(),
77 None => conv,
78 })
79 }
80
81 fn deform_conv2d(
82 x: FloatTensor<Self>,
83 offset: FloatTensor<Self>,
84 weight: FloatTensor<Self>,
85 mask: Option<FloatTensor<Self>>,
86 bias: Option<FloatTensor<Self>>,
87 options: DeformConvOptions<2>,
88 ) -> FloatTensor<Self> {
89 unimplemented!("Candle does not support deformable convolutions")
90 }
91
92 fn deform_conv2d_backward(
93 x: FloatTensor<Self>,
94 offset: FloatTensor<Self>,
95 weight: FloatTensor<Self>,
96 mask: Option<FloatTensor<Self>>,
97 bias: Option<FloatTensor<Self>>,
98 output_grad: FloatTensor<Self>,
99 options: DeformConvOptions<2>,
100 ) -> DeformConv2dBackward<Self> {
101 unimplemented!("Candle does not support deformable convolutions")
102 }
103
104 fn conv3d(
105 x: FloatTensor<Self>,
106 weight: FloatTensor<Self>,
107 bias: Option<FloatTensor<Self>>,
108 options: ConvOptions<3>,
109 ) -> FloatTensor<Self> {
110 panic!("Candle does not support 3D convolutions");
111 }
112
113 fn conv_transpose1d(
114 x: FloatTensor<Self>,
115 weight: FloatTensor<Self>,
116 bias: Option<FloatTensor<Self>>,
117 options: ConvTransposeOptions<1>,
118 ) -> FloatTensor<Self> {
119 let conv_transpose = x
120 .tensor
121 .conv_transpose1d(
122 &weight.tensor,
123 options.padding[0],
124 options.padding_out[0],
125 options.stride[0],
126 options.dilation[0],
127 options.groups,
128 )
129 .unwrap();
130 CandleTensor::new(match bias {
131 Some(bias) => conv_transpose
132 .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
133 .unwrap(),
134 None => conv_transpose,
135 })
136 }
137
138 fn conv_transpose2d(
139 x: FloatTensor<Self>,
140 weight: FloatTensor<Self>,
141 bias: Option<FloatTensor<Self>>,
142 options: ConvTransposeOptions<2>,
143 ) -> FloatTensor<Self> {
144 assert!(
145 options.dilation[0] == options.dilation[1]
146 && options.padding[0] == options.padding[1]
147 && options.padding_out[0] == options.padding_out[1]
148 && options.stride[0] == options.stride[1],
149 "Candle does not support per dimension options in transposed convolutions"
150 );
151 assert!(
152 options.groups == 1,
153 "Candle does not support groups in transposed convolutions"
154 );
155 let conv_transpose = x
156 .tensor
157 .conv_transpose2d(
158 &weight.tensor,
159 options.padding[0],
160 options.padding_out[0],
161 options.stride[0],
162 options.dilation[0],
163 )
164 .unwrap();
165 CandleTensor::new(match bias {
166 Some(bias) => conv_transpose
167 .broadcast_add(
168 &bias
169 .tensor
170 .unsqueeze(0)
171 .unwrap()
172 .unsqueeze(2)
173 .unwrap()
174 .unsqueeze(3)
175 .unwrap(),
176 )
177 .unwrap(),
178 None => conv_transpose,
179 })
180 }
181
182 fn conv_transpose3d(
183 x: FloatTensor<Self>,
184 weight: FloatTensor<Self>,
185 bias: Option<FloatTensor<Self>>,
186 options: ConvTransposeOptions<3>,
187 ) -> FloatTensor<Self> {
188 panic!("Candle does not support 3D transposed convolutions");
189 }
190
191 fn avg_pool2d(
192 x: FloatTensor<Self>,
193 kernel_size: [usize; 2],
194 stride: [usize; 2],
195 padding: [usize; 2],
196 count_include_pad: bool,
197 ) -> FloatTensor<Self> {
198 assert!(
199 padding[0] == 0 && padding[1] == 0,
200 "Candle does not support padding in pooling"
201 );
202 assert!(
203 count_include_pad,
204 "Candle does not support excluding pad count in pooling"
205 );
206 CandleTensor::new(
207 x.tensor
208 .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
209 .unwrap(),
210 )
211 }
212
213 fn avg_pool2d_backward(
214 x: FloatTensor<Self>,
215 grad: FloatTensor<Self>,
216 kernel_size: [usize; 2],
217 stride: [usize; 2],
218 padding: [usize; 2],
219 count_include_pad: bool,
220 ) -> FloatTensor<Self> {
221 panic!("avg_pool2d_backward is not supported by Candle")
222 }
223
224 fn max_pool2d(
225 x: FloatTensor<Self>,
226 kernel_size: [usize; 2],
227 stride: [usize; 2],
228 padding: [usize; 2],
229 dilation: [usize; 2],
230 ) -> FloatTensor<Self> {
231 assert!(
232 padding[0] == 0 && padding[1] == 0,
233 "Candle does not support padding in pooling"
234 );
235 assert!(
236 dilation[0] == 1 && dilation[1] == 1,
237 "Candle does not support dilation in pooling"
238 );
239 CandleTensor::new(
240 x.tensor
241 .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))
242 .unwrap(),
243 )
244 }
245
246 fn max_pool2d_with_indices(
247 x: FloatTensor<Self>,
248 kernel_size: [usize; 2],
249 stride: [usize; 2],
250 padding: [usize; 2],
251 dilation: [usize; 2],
252 ) -> MaxPool2dWithIndices<Candle<F, I>> {
253 panic!("max_pool2d_with_indices is not supported by Candle")
254 }
255
256 fn max_pool2d_with_indices_backward(
257 x: FloatTensor<Self>,
258 kernel_size: [usize; 2],
259 stride: [usize; 2],
260 padding: [usize; 2],
261 dilation: [usize; 2],
262 output_grad: FloatTensor<Self>,
263 indices: IntTensor<Self>,
264 ) -> MaxPool2dBackward<Candle<F, I>> {
265 panic!("max_pool2d_with_indices_backward is not supported by Candle")
266 }
267
268 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
269 panic!("adaptive_avg_pool2 is not supported by Candle")
270 }
271
272 fn adaptive_avg_pool2d_backward(
273 x: FloatTensor<Self>,
274 grad: FloatTensor<Self>,
275 ) -> FloatTensor<Self> {
276 panic!("adaptive_avg_pool2d_backward is not supported by Candle")
277 }
278
279 fn interpolate(
280 x: FloatTensor<Self>,
281 output_size: [usize; 2],
282 options: InterpolateOptions,
283 ) -> FloatTensor<Self> {
284 let tensor = match options.mode {
285 InterpolateMode::Nearest => x
286 .tensor
287 .upsample_nearest2d(output_size[0], output_size[1])
288 .unwrap(),
289 InterpolateMode::Bilinear => {
290 panic!("bilinear interpolation is not supported by Candle")
291 }
292 InterpolateMode::Bicubic => {
293 panic!("bicubic interpolation is not supported by Candle")
294 }
295 };
296
297 CandleTensor::new(tensor)
298 }
299
300 fn interpolate_backward(
301 x: FloatTensor<Self>,
302 grad: FloatTensor<Self>,
303 output_size: [usize; 2],
304 options: InterpolateOptions,
305 ) -> FloatTensor<Self> {
306 panic!("interpolate_backward is not supported by Candle")
307 }
308}