1use super::{
2 adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3 avgpool::{avg_pool2d, avg_pool2d_backward},
4 conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
5 deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6 interpolate::{
7 bicubic_interpolate, bilinear_interpolate, lanczos3_interpolate, nearest_interpolate,
8 },
9 maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
10};
11#[cfg(feature = "simd")]
12use crate::ops::simd::{
13 avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
14};
15use crate::{
16 NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
17 tensor::NdArrayTensor,
18};
19use crate::{
20 element::{IntNdArrayElement, QuantElement},
21 ops::interpolate::nearest_interpolate_backward,
22};
23use burn_backend::{
24 ElementConversion, TensorMetadata,
25 ops::{attention::attention_fallback, *},
26 tensor::FloatTensor,
27};
28
/// Dispatches a module operation on the concrete float dtype of its inputs.
///
/// - `inp(...)` lists the required `NdArrayTensor` inputs,
/// - `opt(...)` lists optional (`Option<NdArrayTensor>`) inputs,
/// - `$element` is bound (via a local `type` alias) to the matched float
///   element type (`f32` or `f64`) so `$op` can name it in turbofish calls,
/// - `$op` is a closure receiving the unwrapped shared arrays — required
///   inputs first, then the optionals, in declaration order.
///
/// Panics at runtime if the required inputs are not all the same float dtype,
/// or if an optional input's dtype differs from the required inputs' dtype.
macro_rules! module_op {
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        // `unreachable_patterns` is allowed because with a single input the
        // F32/F64 arms can make the trailing `_` arm appear redundant.
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            // All required inputs are f32; optionals must also be f32.
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            // All required inputs are f64; optionals must also be f64.
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            // Required inputs disagree on dtype (e.g. f32 input with f64 weight).
            _ => panic!("Data type mismatch"),
        }
    }};
}
53
// Neural-network module operations for the ndarray backend.
//
// Each method unwraps the dtype-erased tensor inputs via `module_op!`,
// dispatching to an `f32` or `f64` kernel. When the `simd` feature is
// enabled, some ops first attempt a SIMD fast path (`try_*_simd`) and fall
// back to the scalar kernel when the fast path declines (`Err` returns the
// inputs unchanged).
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // 2D convolution with optional bias.
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            // SIMD fast path: `options` is cloned because the scalar fallback
            // below still needs to consume it. On `Err` the inputs are handed
            // back untouched for the fallback.
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    // Deformable 2D convolution (forward) with optional mask and bias.
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    // Backward pass of the deformable 2D convolution: returns gradients for
    // input, offset, weight, and (when present) mask and bias.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                // Re-wrap the raw shared arrays into dtype-erased tensors.
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    // Transposed 2D convolution with optional bias (no SIMD fast path).
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    // 2D average pooling.
    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            // SIMD path does not support `ceil_mode`, so it is bypassed with
            // `Err(x)` in that case and the scalar kernel handles it below.
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                Err(x)
            } else {
                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )
            .into()
        })
    }

    // Backward pass of 2D average pooling.
    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode
        )
        .into())
    }

    // 2D max pooling (values only; see `max_pool2d_with_indices` for indices).
    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            // SIMD path does not support `ceil_mode`; bypass it in that case.
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                Err(x)
            } else {
                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
        })
    }

    // 2D max pooling that also returns the argmax indices (int tensor),
    // needed by the backward pass.
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) = max_pool2d_with_indices::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            );
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    // Backward pass of 2D max pooling, routing the gradient through the
    // argmax indices produced by `max_pool2d_with_indices`.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        // Unwrap the int-dtype indices tensor, then convert element-wise into
        // the backend's int element type `I` expected by the kernel.
        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    // Adaptive 2D average pooling to a fixed output size.
    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    // Backward pass of adaptive 2D average pooling.
    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    // 2D interpolation (resampling) to `output_size`, dispatched on mode.
    // `align_corners` is only consulted by the non-nearest modes.
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
            InterpolateMode::Lanczos3 => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| lanczos3_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
        }
    }

    // Backward pass of interpolation. Only the nearest mode has a backward
    // kernel on this backend; the other modes panic.
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Lanczos3 => {
                panic!("lanczos3 interpolation backward is not supported for ndarray backend")
            }
        }
    }

    // 3D convolution with optional bias.
    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    // Transposed 3D convolution with optional bias.
    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }

    // Scaled dot-product attention: no specialized kernel here, so this
    // delegates to the backend-agnostic fallback implementation.
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<burn_backend::tensor::BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
    }

    // Real-input FFT: not implemented for this backend.
    fn rfft(_signal: FloatTensor<Self>, _dim: usize) -> (FloatTensor<Self>, FloatTensor<Self>) {
        todo!("rfft is not supported for ndarray")
    }

    // Inverse real FFT: not implemented for this backend.
    fn irfft(
        _spectrum_re: FloatTensor<Self>,
        _spectrum_im: FloatTensor<Self>,
        _dim: usize,
    ) -> FloatTensor<Self> {
        todo!("irfft is not supported for ndarray")
    }
}