1use super::{
2 adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3 avgpool::{avg_pool2d, avg_pool2d_backward},
4 conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
5 deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6 interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
7 maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
8};
9#[cfg(feature = "simd")]
10use crate::ops::simd::{
11 avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
12};
13use crate::{
14 NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
15 tensor::NdArrayTensor,
16};
17use crate::{
18 element::{IntNdArrayElement, QuantElement},
19 ops::interpolate::nearest_interpolate_backward,
20};
21use burn_backend::{
22 ElementConversion, TensorMetadata,
23 ops::{attention::attention_fallback, *},
24 tensor::FloatTensor,
25};
26
/// Dispatches a float module operation over the dynamically-typed
/// `NdArrayTensor` enum, monomorphizing the body for the concrete element
/// type (`f32` or `f64`).
///
/// - `inp(...)`  — required tensor arguments; all must carry the same float
///   dtype or the final arm panics with "Data type mismatch".
/// - `opt(...)`  — `Option<NdArrayTensor>` arguments; when `Some`, each must
///   match the dtype selected by the required inputs, otherwise the inner
///   match panics with "Optional argument type mismatch".
/// - `$element`  — identifier bound as a type alias to the selected element
///   type inside `$op`, so the body can write e.g. `conv2d::<E>(...)`.
/// - `$op`       — a closure receiving the unwrapped shared arrays (required
///   inputs first, then the mapped optional inputs, in declaration order).
macro_rules! module_op {
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        // `unused_parens` covers the single-input case (a one-element "tuple");
        // `unreachable_patterns` silences overlap warnings the expansion can
        // produce depending on the number of inputs.
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            // Every required input is F32: bind $element = f32 and unwrap.
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    // Optional tensors, when present, must also be F32.
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            // Every required input is F64: bind $element = f64 and unwrap.
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            // Required inputs disagree on dtype (e.g. F32 input with F64 weight).
            _ => panic!("Data type mismatch"),
        }
    }};
}
51
/// Neural-network module operations (convolutions, pooling, interpolation,
/// attention) for the ndarray backend. Each method unwraps the dynamically
/// typed tensors via `module_op!` and forwards to the concrete kernels in
/// the sibling modules; some ops try a SIMD-specialized kernel first when
/// the `simd` feature is enabled.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    // NOTE(review): this signature spells `NdArrayTensor` directly while the
    // sibling methods use the `FloatTensor<Self>` alias — presumably the alias
    // resolves to the same type; confirm and unify for consistency.
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            // Try the SIMD kernel first; on Err the arguments are handed back
            // untouched so the generic fallback below can consume them.
            // `return` exits the closure passed to `module_op!`, not `conv2d`.
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    /// Deformable 2D convolution; `mask` and `bias` are optional and must
    /// match the float dtype of the required inputs.
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    /// Backward pass of the deformable 2D convolution. Returns gradients for
    /// every input; the `mask`/`bias` gradients are `Some` only when the
    /// corresponding optional input was provided (mirrored via `.map`).
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                // Kernel returns the five gradients in input-declaration order.
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    /// Transposed 2D convolution (a.k.a. deconvolution); no SIMD fast path.
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    /// 2D average pooling. The SIMD kernel does not take `ceil_mode`, so the
    /// SIMD attempt is skipped (forced `Err`) whenever `ceil_mode` is set.
    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path unsupported with ceil_mode: hand `x` straight back.
                Err(x)
            } else {
                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
            } {
                // `return` exits the `module_op!` closure with the SIMD result.
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )
            .into()
        })
    }

    /// Backward pass of 2D average pooling (no SIMD variant).
    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode
        )
        .into())
    }

    /// 2D max pooling. Like `avg_pool2d`, the SIMD kernel has no `ceil_mode`
    /// parameter, so it is only attempted when `ceil_mode` is false.
    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                Err(x)
            } else {
                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
        })
    }

    /// 2D max pooling that also returns the argmax indices (int tensor),
    /// as consumed later by `max_pool2d_with_indices_backward`.
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) = max_pool2d_with_indices::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            );
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    /// Backward pass of max pooling, routing gradients through the saved
    /// argmax `indices`.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        // `indices` may arrive with any int dtype; convert it element-wise to
        // the backend's int element type `I` before calling the kernel.
        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    /// Adaptive 2D average pooling to a fixed `output_size`.
    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    /// Backward pass of adaptive 2D average pooling.
    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    /// 2D interpolation (upsampling/downsampling) to `output_size`,
    /// dispatched on the requested mode. `align_corners` is only consulted
    /// by the bilinear and bicubic kernels.
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                // Copied out first so the closure captures a plain bool.
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
        }
    }

    /// Backward pass of interpolation. Only nearest mode is implemented for
    /// this backend; bilinear/bicubic deliberately panic.
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
        }
    }

    /// 3D convolution (no SIMD fast path, unlike `conv2d`).
    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    /// Transposed 3D convolution.
    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }

    /// Attention: no backend-specific kernel here; delegates entirely to the
    /// generic fallback implementation from `burn_backend`.
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<burn_backend::tensor::BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
    }
}