use super::{
    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
    avgpool::{avg_pool2d, avg_pool2d_backward},
    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
#[cfg(feature = "simd")]
use crate::ops::simd::{
    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
};
use crate::{
    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
    tensor::NdArrayTensor,
};
use crate::{
    element::{IntNdArrayElement, QuantElement},
    ops::interpolate::nearest_interpolate_backward,
};
use burn_backend::{ElementConversion, TensorMetadata, ops::*, tensor::FloatTensor};

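/// Dispatches a module op over the float dtype of its inputs: every tensor in
/// `inp(..)` must be the same variant (`F32` or `F64`), `$element` is bound to
/// the matching Rust float type, and each `Option` in `opt(..)` is unwrapped
/// against that same variant. Mismatched dtypes panic.
///
/// Usage sketch (mirrors the call sites below):
///
/// ```ignore
/// module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(x, output_size).into())
/// ```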
macro_rules! module_op {
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
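            // SIMD fast path: `try_conv2d_simd` returns `Ok(out)` when a
            // vectorized kernel applies, and hands the inputs back in
            // `Err(args)` so the generic `conv2d` below can run instead.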
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
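                // The SIMD kernel does not handle `ceil_mode`; return the input
                // unchanged so the generic fallback below runs instead.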
                Err(x)
            } else {
                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )
            .into()
        })
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode
        )
        .into())
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
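                // As with `avg_pool2d`: skip the SIMD fast path when
                // `ceil_mode` is set and fall through to the generic kernel.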
                Err(x)
            } else {
                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
        })
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) = max_pool2d_with_indices::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            );
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
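            // Map the runtime integer dtype of `indices` onto the backend's
            // `I` element type expected by `max_pool2d_backward`.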
            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
        }
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
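        // Only nearest-neighbor interpolation has a backward pass on this
        // backend; the bilinear and bicubic arms panic below.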
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
        }
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }
}