use super::{
    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
    avgpool::{avg_pool2d, avg_pool2d_backward},
    conv::{conv2d, conv3d, conv_transpose2d, conv_transpose3d},
    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray, NdArrayTensorFloat};
use crate::{
    element::{IntNdArrayElement, QuantElement},
    ops::interpolate::nearest_interpolate_backward,
};
use burn_tensor::ops::*;

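/// Dispatches a module operation on the concrete float storage of its inputs.
///
/// Every `inp(...)` tensor is matched against the `NdArrayTensorFloat::F32` and
/// `NdArrayTensorFloat::F64` variants, the matching primitive type is bound to
/// the `$element` alias, optional `opt(...)` tensors are unwrapped to the same
/// variant, and `$op` is called with the unwrapped values. Mixing `F32` and
/// `F64` inputs (or passing an optional argument of the other variant) panics.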
macro_rules! module_op {
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensorFloat::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x),+
                    $(, $opt.map(|o| match o { NdArrayTensorFloat::F32(val) => val, _ => panic!("Optional argument type mismatch") }))*
                )
            }
            ($(NdArrayTensorFloat::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x),+
                    $(, $opt.map(|o| match o { NdArrayTensorFloat::F64(val) => val, _ => panic!("Optional argument type mismatch") }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}

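// `ModuleOps` for the ndarray backend: convolution, pooling, and interpolation
// kernels, generic over the float element `E`, int element `I`, and quantized
// element `Q`. Each method uses `module_op!` to dispatch on the concrete float
// storage (f32 or f64) before calling the typed kernel.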
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
{
    fn conv2d(
        x: NdArrayTensorFloat,
        weight: NdArrayTensorFloat,
        bias: Option<NdArrayTensorFloat>,
        options: ConvOptions<2>,
    ) -> NdArrayTensorFloat {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv2d::<
            E,
            I,
            Q,
        >(
            x, weight, bias, options
        )
        .into())
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E, I, Q>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| avg_pool2d::<E>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad
        )
        .into())
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad
        )
        .into())
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| max_pool2d::<E, I, Q>(
            x,
            kernel_size,
            stride,
            padding,
            dilation
        )
        .into())
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) =
                max_pool2d_with_indices::<E, I, Q>(x, kernel_size, stride, padding, dilation);
            MaxPool2dWithIndices::new(output.into(), indices)
        })
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
            let output = max_pool2d_backward::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                output_grad,
                indices,
            );
            MaxPool2dBackward::new(output.into())
        })
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
        }
    }

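    // Only nearest-neighbor interpolation has a backward pass on this backend;
    // the bilinear and bicubic modes panic below.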
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
        }
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<
            E,
            I,
            Q,
        >(
            x, weight, bias, options
        )
        .into())
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }
}