use super::{
    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
    avgpool::{avg_pool2d, avg_pool2d_backward},
    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
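// Vectorized kernels, compiled only with the `simd` feature. Each `try_*`
// function returns its inputs in `Err` when it cannot handle the request, so
// callers can fall back to the scalar kernels from `super`.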
#[cfg(feature = "simd")]
use crate::ops::simd::{
    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
};
use crate::{
    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
    tensor::NdArrayTensor,
};
use crate::{
    element::{IntNdArrayElement, QuantElement},
    ops::interpolate::nearest_interpolate_backward,
};
use burn_tensor::{TensorMetadata, ops::*};

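/// Dispatches `$op` over the float dtype variants of the input tensors.
///
/// All tensors listed in `inp(...)` must hold the same float dtype (`F32` or
/// `F64`); optional tensors in `opt(...)` are unwrapped against that same
/// variant. `$element` is bound to the matching Rust type (`f32` or `f64`) so
/// the body can call dtype-generic kernels. Mixed dtypes panic at runtime.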
macro_rules! module_op {
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val, _ => panic!("Optional argument type mismatch") }))*
                )
            }
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val, _ => panic!("Optional argument type mismatch") }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}
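// A hand-written sketch (not the literal expansion) of what
// `module_op!(inp(x), opt(bias), E, $op)` evaluates to:
//
// match x {
//     NdArrayTensor::F32(x) => {
//         type E = f32;
//         $op(x, bias.map(|o| /* unwrap the F32 variant or panic */))
//     }
//     NdArrayTensor::F64(x) => {
//         type E = f64;
//         $op(x, bias.map(|o| /* unwrap the F64 variant or panic */))
//     }
//     _ => panic!("Data type mismatch"),
// }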

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
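            // Try the SIMD fast path first; `try_conv2d_simd` hands the
            // inputs back unchanged in `Err` so the scalar kernel below can
            // run. The same pattern appears in `avg_pool2d` and `max_pool2d`.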
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(x, kernel_size, stride, padding, count_include_pad).into()
        })
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad
        )
        .into())
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation).into()
        })
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) =
                max_pool2d_with_indices::<E, I>(x, kernel_size, stride, padding, dilation);
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
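        // `indices` is an integer tensor, so dispatch happens twice: over the
        // int dtype via `execute_with_int_dtype!`, then over the float dtype
        // via `module_op!`.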
        execute_with_int_dtype!(indices, I, |indices| {
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
        }
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
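        // Only nearest interpolation has a backward kernel in this backend;
        // the other modes panic below.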
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for the ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for the ndarray backend")
            }
        }
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }
}
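
// Usage sketch (hedged): these impls are normally reached through the public
// `burn_tensor::module` functions. Kept as a comment since it is illustrative
// rather than a compiled test; shapes are hypothetical.
//
// use burn_ndarray::NdArray;
// use burn_tensor::{Tensor, module::max_pool2d};
//
// let x: Tensor<NdArray, 4> = Tensor::zeros([1, 3, 8, 8], &Default::default());
// // 2x2 kernel, stride 2, no padding, dilation 1 -> output shape [1, 3, 4, 4]
// let y = max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1]);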