Skip to main content

burn_ndarray/ops/
module.rs

1use super::{
2    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3    avgpool::{avg_pool2d, avg_pool2d_backward},
4    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
5    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6    interpolate::{
7        bicubic_interpolate, bilinear_interpolate, lanczos3_interpolate, nearest_interpolate,
8    },
9    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
10};
11#[cfg(feature = "simd")]
12use crate::ops::simd::{
13    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
14};
15use crate::{
16    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
17    tensor::NdArrayTensor,
18};
19use crate::{
20    element::{IntNdArrayElement, QuantElement},
21    ops::interpolate::nearest_interpolate_backward,
22};
23use burn_backend::{
24    ElementConversion, TensorMetadata,
25    ops::{attention::attention_fallback, *},
26    tensor::FloatTensor,
27};
28
macro_rules! module_op {
    // Module op with inputs (inp), optional (opt) and arguments (args).
    // Converts NdArrayStorage to SharedArray for compatibility with existing operations.
    //
    // Dispatches on the runtime float dtype of the required inputs: every
    // tensor listed in `inp(...)` must be the same variant (all `F32` or all
    // `F64`). Within the matching arm, `$element` is bound as a type alias to
    // the concrete Rust float (`f32`/`f64`) so `$op` can call kernels like
    // `conv2d::<E>`. Each required input is converted with `into_shared()`
    // before being passed to `$op`; each optional input in `opt(...)` is
    // `map`ped through the same unwrap-and-convert step.
    //
    // Panics:
    // - "Optional argument type mismatch" if an optional tensor's dtype
    //   differs from the required inputs' dtype.
    // - "Data type mismatch" (catch-all arm) if the required inputs are not
    //   all the same float variant.
    //
    // NOTE: `$op` is typically a closure; a `return` inside it (as used by
    // the SIMD fast paths below) returns from the closure, yielding the
    // macro expansion's value.
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}
53
/// Neural-network module operations for the ndarray backend.
///
/// Each op unwraps the dtype-erased tensors via `module_op!` (dispatching on
/// `F32`/`F64`) and forwards to the concrete kernels defined in the sibling
/// modules (`conv`, `maxpool`, `avgpool`, `interpolate`, ...). With the
/// `simd` feature enabled, some ops first try an explicitly vectorized path
/// and fall back to the generic kernel when the SIMD variant hands the
/// inputs back as `Err`.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// 2D convolution; tries the SIMD kernel first when available.
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            // SIMD fast path: on success we return its output directly; on
            // failure it gives the (unconsumed) inputs back so the generic
            // kernel below can use them. `options` is cloned because the
            // fallback still needs it.
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    /// Deformable 2D convolution (forward pass).
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    /// Deformable 2D convolution backward: returns gradients for the input,
    /// offsets, weights, and (when present) mask and bias.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                // The kernel returns the gradient arrays in input order;
                // wrap each back into the dtype-erased tensor type.
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    /// Transposed 2D convolution (a.k.a. deconvolution).
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    /// 2D average pooling; tries the SIMD kernel unless `ceil_mode` is set.
    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            // `match if ...` shape: when ceil_mode forces a skip, we still
            // need `x` back out of the branch, so the skip is expressed as
            // `Err(x)` feeding the same match as the SIMD attempt.
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path doesn't support ceil_mode yet, skip it
                Err(x)
            } else {
                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )
            .into()
        })
    }

    /// Backward pass of 2D average pooling.
    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode
        )
        .into())
    }

    /// 2D max pooling; tries the SIMD kernel unless `ceil_mode` is set.
    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            // Same skip-as-Err pattern as avg_pool2d above: reuse the SIMD
            // attempt's match to recover `x` when ceil_mode disables SIMD.
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path doesn't support ceil_mode yet, skip it
                Err(x)
            } else {
                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
        })
    }

    /// 2D max pooling that also returns the argmax indices (needed for the
    /// backward pass).
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) = max_pool2d_with_indices::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            );
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    /// Backward pass of 2D max pooling, scattering `output_grad` through the
    /// argmax `indices` produced by the forward pass.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        // Outer macro unwraps the int dtype of `indices`; the inner
        // `module_op!` then dispatches on the float dtype of `x`/`output_grad`.
        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
            // Convert indices from runtime dtype to the expected I type
            // (pool indices are bounded by tensor dimensions, so conversion is safe)
            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    /// Adaptive 2D average pooling to a fixed `output_size`.
    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    /// Backward pass of adaptive 2D average pooling.
    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    /// 2D interpolation (resize) dispatched on the requested mode.
    /// `align_corners` is read only for the modes whose kernels accept it.
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                // Bound outside the macro so the closure captures a plain bool.
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
            InterpolateMode::Lanczos3 => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| lanczos3_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
        }
    }

    /// Backward pass of interpolation.
    ///
    /// # Panics
    /// Only nearest-neighbor has a backward kernel on this backend; the
    /// bilinear, bicubic, and lanczos3 modes panic explicitly.
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Lanczos3 => {
                panic!("lanczos3 interpolation backward is not supported for ndarray backend")
            }
        }
    }

    /// 3D convolution (no SIMD fast path, unlike `conv2d`).
    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    /// Transposed 3D convolution.
    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }

    /// Scaled dot-product attention. This backend has no fused kernel, so it
    /// delegates to the generic fallback built from primitive tensor ops.
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<burn_backend::tensor::BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
    }

    /// Real-to-complex FFT — not implemented for this backend.
    fn rfft(_signal: FloatTensor<Self>, _dim: usize) -> (FloatTensor<Self>, FloatTensor<Self>) {
        todo!("rfft is not supported for ndarray")
    }

    /// Inverse real FFT — not implemented for this backend.
    fn irfft(
        _spectrum_re: FloatTensor<Self>,
        _spectrum_im: FloatTensor<Self>,
        _dim: usize,
    ) -> FloatTensor<Self> {
        todo!("irfft is not supported for ndarray")
    }
}
393}