1use super::{
2 adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3 avgpool::{avg_pool2d, avg_pool2d_backward},
4 conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
5 deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6 interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
7 maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
8};
9#[cfg(feature = "simd")]
10use crate::ops::simd::{
11 avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
12};
13use crate::{NdArray, NdArrayTensorFloat, element::FloatNdArrayElement, tensor::NdArrayTensor};
14use crate::{
15 element::{IntNdArrayElement, QuantElement},
16 ops::interpolate::nearest_interpolate_backward,
17};
18use burn_tensor::ops::*;
19
/// Runs a module operation on the payloads of `NdArrayTensorFloat` arguments, dispatching on
/// the float variant they hold.
///
/// Invocation shape: `module_op!(inp(a, b), opt(c), E, |a, b, c| ...)`.
///
/// - `inp(..)`: one or more variables holding `NdArrayTensorFloat` values. They are matched
///   together and must all be `F32` or all be `F64`; each name is rebound to the payload of
///   its variant.
/// - `opt(..)`: zero or more `Option<NdArrayTensorFloat>` variables. A `Some` value is
///   unwrapped to the payload of the same variant as the inputs; `None` is passed through.
/// - The identifier (third argument) is declared as a type alias for `f32` or `f64` inside the
///   selected arm, so the operation can name the concrete element type.
/// - The operation (fourth argument) is called with the input payloads followed by the
///   optional payloads; its value is the value of the macro invocation.
///
/// # Panics
///
/// - With `"Data type mismatch"` when the inputs do not all hold the same variant.
/// - With `"Optional argument type mismatch"` when a `Some` optional holds a different variant
///   than the inputs.
macro_rules! module_op {
    (inp($($input:tt),+), opt($($optional:tt),*), $elem:ident, $body:expr) => {{
        // `unused_parens`: a single input expands to a parenthesised scrutinee and pattern.
        // `unreachable_patterns`: presumably needed because, with a single input, the two
        // variant arms already cover every case and the catch-all can never match.
        #[allow(unused_parens, unreachable_patterns)]
        match ($($input),+) {
            // Every input is single precision.
            ($(NdArrayTensorFloat::F32($input)),+) => {
                type $elem = f32;
                $body(
                    $($input),+
                    $(, $optional.map(|tensor| {
                        if let NdArrayTensorFloat::F32(inner) = tensor {
                            inner
                        } else {
                            panic!("Optional argument type mismatch")
                        }
                    }))*
                )
            }
            // Every input is double precision.
            ($(NdArrayTensorFloat::F64($input)),+) => {
                type $elem = f64;
                $body(
                    $($input),+
                    $(, $optional.map(|tensor| {
                        if let NdArrayTensorFloat::F64(inner) = tensor {
                            inner
                        } else {
                            panic!("Optional argument type mismatch")
                        }
                    }))*
                )
            }
            // The inputs hold different variants.
            _ => panic!("Data type mismatch"),
        }
    }};
}
43
44impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
45 for NdArray<E, I, Q>
46{
47 fn conv2d(
48 x: NdArrayTensorFloat,
49 weight: NdArrayTensorFloat,
50 bias: Option<NdArrayTensorFloat>,
51 options: ConvOptions<2>,
52 ) -> NdArrayTensorFloat {
53 module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
54 #[cfg(feature = "simd")]
55 let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
56 Ok(out) => return out.into(),
57 Err(args) => args,
58 };
59 conv2d::<E>(x, weight, bias, options).into()
60 })
61 }
62
63 fn deform_conv2d(
64 x: FloatTensor<Self>,
65 offset: FloatTensor<Self>,
66 weight: FloatTensor<Self>,
67 mask: Option<FloatTensor<Self>>,
68 bias: Option<FloatTensor<Self>>,
69 options: DeformConvOptions<2>,
70 ) -> FloatTensor<Self> {
71 module_op!(
72 inp(x, offset, weight),
73 opt(mask, bias),
74 E,
75 |x, offset, weight, mask, bias| deform_conv2d::<E>(
76 x, offset, weight, mask, bias, options
77 )
78 .into()
79 )
80 }
81
82 fn deform_conv2d_backward(
83 x: FloatTensor<Self>,
84 offset: FloatTensor<Self>,
85 weight: FloatTensor<Self>,
86 mask: Option<FloatTensor<Self>>,
87 bias: Option<FloatTensor<Self>>,
88 output_grad: FloatTensor<Self>,
89 options: DeformConvOptions<2>,
90 ) -> DeformConv2dBackward<Self> {
91 module_op!(
92 inp(x, offset, weight, output_grad),
93 opt(mask, bias),
94 E,
95 |x, offset, weight, output_grad, mask, bias| {
96 let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
97 x,
98 offset,
99 weight,
100 mask,
101 bias,
102 output_grad,
103 options,
104 );
105 DeformConv2dBackward::new(
106 x.into(),
107 offset.into(),
108 weight.into(),
109 mask.map(|m| m.into()),
110 bias.map(|b| b.into()),
111 )
112 }
113 )
114 }
115
116 fn conv_transpose2d(
117 x: FloatTensor<Self>,
118 weight: FloatTensor<Self>,
119 bias: Option<FloatTensor<Self>>,
120 options: ConvTransposeOptions<2>,
121 ) -> FloatTensor<Self> {
122 module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
123 conv_transpose2d::<E>(x, weight, bias, options).into()
124 })
125 }
126
127 fn avg_pool2d(
128 x: FloatTensor<Self>,
129 kernel_size: [usize; 2],
130 stride: [usize; 2],
131 padding: [usize; 2],
132 count_include_pad: bool,
133 ) -> FloatTensor<Self> {
134 module_op!(inp(x), opt(), E, |x| {
135 #[cfg(feature = "simd")]
136 let x = match try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) {
137 Ok(out) => return out.into(),
138 Err(x) => x,
139 };
140 avg_pool2d::<E>(x, kernel_size, stride, padding, count_include_pad).into()
141 })
142 }
143
144 fn avg_pool2d_backward(
145 x: FloatTensor<Self>,
146 grad: FloatTensor<Self>,
147 kernel_size: [usize; 2],
148 stride: [usize; 2],
149 padding: [usize; 2],
150 count_include_pad: bool,
151 ) -> FloatTensor<Self> {
152 module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
153 x,
154 grad,
155 kernel_size,
156 stride,
157 padding,
158 count_include_pad
159 )
160 .into())
161 }
162
163 fn max_pool2d(
164 x: FloatTensor<Self>,
165 kernel_size: [usize; 2],
166 stride: [usize; 2],
167 padding: [usize; 2],
168 dilation: [usize; 2],
169 ) -> FloatTensor<Self> {
170 module_op!(inp(x), opt(), E, |x| {
171 #[cfg(feature = "simd")]
172 let x = match try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) {
173 Ok(out) => return out.into(),
174 Err(x) => x,
175 };
176 max_pool2d::<E>(x, kernel_size, stride, padding, dilation).into()
177 })
178 }
179
180 fn max_pool2d_with_indices(
181 x: FloatTensor<Self>,
182 kernel_size: [usize; 2],
183 stride: [usize; 2],
184 padding: [usize; 2],
185 dilation: [usize; 2],
186 ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
187 module_op!(inp(x), opt(), E, |x| {
188 let (output, indices) =
189 max_pool2d_with_indices::<E, I>(x, kernel_size, stride, padding, dilation);
190 MaxPool2dWithIndices::new(output.into(), indices)
191 })
192 }
193
194 fn max_pool2d_with_indices_backward(
195 x: FloatTensor<Self>,
196 kernel_size: [usize; 2],
197 stride: [usize; 2],
198 padding: [usize; 2],
199 dilation: [usize; 2],
200 output_grad: FloatTensor<Self>,
201 indices: NdArrayTensor<I>,
202 ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
203 module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
204 let output = max_pool2d_backward::<E, I>(
205 x,
206 kernel_size,
207 stride,
208 padding,
209 dilation,
210 output_grad,
211 indices,
212 );
213 MaxPool2dBackward::new(output.into())
214 })
215 }
216
217 fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
218 module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
219 x,
220 output_size
221 )
222 .into())
223 }
224
225 fn adaptive_avg_pool2d_backward(
226 x: FloatTensor<Self>,
227 grad: FloatTensor<Self>,
228 ) -> FloatTensor<Self> {
229 module_op!(inp(x, grad), opt(), E, |x, grad| {
230 adaptive_avg_pool2d_backward::<E>(x, grad).into()
231 })
232 }
233
234 fn interpolate(
235 x: FloatTensor<Self>,
236 output_size: [usize; 2],
237 options: InterpolateOptions,
238 ) -> FloatTensor<Self> {
239 match options.mode {
240 InterpolateMode::Nearest => {
241 module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
242 x,
243 output_size
244 )
245 .into())
246 }
247 InterpolateMode::Bilinear => {
248 module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
249 x,
250 output_size
251 )
252 .into())
253 }
254 InterpolateMode::Bicubic => {
255 module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
256 x,
257 output_size
258 )
259 .into())
260 }
261 }
262 }
263
264 fn interpolate_backward(
265 x: FloatTensor<Self>,
266 grad: FloatTensor<Self>,
267 output_size: [usize; 2],
268 options: InterpolateOptions,
269 ) -> FloatTensor<Self> {
270 match options.mode {
271 InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
272 nearest_interpolate_backward::<E>(x, grad, output_size).into()
273 }),
274 InterpolateMode::Bilinear => {
275 panic!("bilinear interpolation backward is not supported for ndarray backend")
276 }
277 InterpolateMode::Bicubic => {
278 panic!("bicubic interpolation backward is not supported for ndarray backend")
279 }
280 }
281 }
282
283 fn conv3d(
284 x: FloatTensor<Self>,
285 weight: FloatTensor<Self>,
286 bias: Option<FloatTensor<Self>>,
287 options: ConvOptions<3>,
288 ) -> FloatTensor<Self> {
289 module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
290 x, weight, bias, options
291 )
292 .into())
293 }
294
295 fn conv_transpose3d(
296 x: FloatTensor<Self>,
297 weight: FloatTensor<Self>,
298 bias: Option<FloatTensor<Self>>,
299 options: ConvTransposeOptions<3>,
300 ) -> FloatTensor<Self> {
301 module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
302 conv_transpose3d::<E>(x, weight, bias, options).into()
303 })
304 }
305}