//! `burn_ndarray/ops/module.rs` — neural-network module operations
//! (convolution, pooling, interpolation, attention) for the ndarray backend.
1use super::{
2    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3    avgpool::{avg_pool2d, avg_pool2d_backward},
4    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
5    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
7    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
8};
9#[cfg(feature = "simd")]
10use crate::ops::simd::{
11    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
12};
13use crate::{
14    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
15    tensor::NdArrayTensor,
16};
17use crate::{
18    element::{IntNdArrayElement, QuantElement},
19    ops::interpolate::nearest_interpolate_backward,
20};
21use burn_backend::{
22    ElementConversion, TensorMetadata,
23    ops::{attention::attention_fallback, *},
24    tensor::FloatTensor,
25};
26
// Dispatches a module op over the runtime float dtype (f32 or f64) of the inputs.
//
// `inp(...)` lists required `NdArrayTensor` inputs, `opt(...)` lists
// `Option<NdArrayTensor>` inputs, `$element` is bound (via a `type` alias) to the
// matching Rust float type inside `$op`, and `$op` is the expression/closure that
// receives the unwrapped arrays. All required inputs must share one float dtype,
// and every present optional input must match it, otherwise this panics.
macro_rules! module_op {
    // Module op with inputs (inp), optional (opt) and arguments (args).
    // Converts NdArrayStorage to SharedArray for compatibility with existing operations.
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            // Every required input is F32: run `$op` with `$element = f32`.
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    // Optional inputs must carry the same dtype as the required ones.
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            // Every required input is F64: run `$op` with `$element = f64`.
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            // Required inputs have mixed (or unsupported) dtypes.
            _ => panic!("Data type mismatch"),
        }
    }};
}
51
// `ModuleOps` implementation for the ndarray backend.
//
// Each method unwraps the runtime float dtype (f32/f64) through `module_op!` and
// delegates to the corresponding free function imported at the top of this file.
// With the `simd` feature enabled, some ops first attempt a SIMD fast path and
// fall back to the scalar implementation when the fast path declines.
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    // Needed to convert the unwrapped `SharedArray` results back into tensors.
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
    /// 2D convolution. Tries `try_conv2d_simd` first (feature = "simd");
    /// otherwise runs the scalar `conv2d` kernel.
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                // SIMD path succeeded: `return` exits the closure with the result.
                Ok(out) => return out.into(),
                // SIMD path declined: it hands the inputs back so the scalar
                // implementation below can consume them.
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    /// Deformable 2D convolution (forward pass), delegating to `deform_conv2d`.
    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    /// Backward pass of deformable 2D convolution. Returns gradients for the
    /// input, offsets, weights, and (when present) mask and bias.
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                // The kernel returns one gradient per (present) input, in the
                // same order the inputs were supplied.
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    // Mask/bias gradients exist only if mask/bias were provided.
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    /// Transposed 2D convolution, delegating to the scalar `conv_transpose2d`.
    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    /// 2D average pooling. The SIMD fast path is only attempted when
    /// `ceil_mode` is false (the SIMD kernel does not support it yet).
    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path doesn't support ceil_mode yet, skip it
                Err(x)
            } else {
                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
            } {
                // SIMD path succeeded: return from the closure immediately.
                Ok(out) => return out.into(),
                // Otherwise reclaim `x` and fall through to the scalar kernel.
                Err(x) => x,
            };
            avg_pool2d::<E>(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )
            .into()
        })
    }

    /// Backward pass of 2D average pooling (scalar implementation only).
    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode
        )
        .into())
    }

    /// 2D max pooling. Mirrors `avg_pool2d`: SIMD fast path unless
    /// `ceil_mode` is requested, scalar fallback otherwise.
    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path doesn't support ceil_mode yet, skip it
                Err(x)
            } else {
                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
        })
    }

    /// 2D max pooling that also returns the argmax indices (as int tensor `I`),
    /// needed later by the backward pass.
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) = max_pool2d_with_indices::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            );
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    /// Backward pass of max pooling using the previously computed indices.
    /// The indices tensor may arrive with any runtime int dtype, so it is
    /// converted element-wise to the backend's int element type `I` first.
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
            // Convert indices from runtime dtype to the expected I type
            // (pool indices are bounded by tensor dimensions, so conversion is safe)
            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    /// Adaptive 2D average pooling to a fixed `output_size`.
    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    /// Backward pass of adaptive 2D average pooling.
    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    /// 2D interpolation (upsampling/downsampling) dispatched on the requested
    /// mode. `align_corners` only applies to the bilinear/bicubic kernels.
    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                // Captured outside the closure so the macro body stays a plain
                // value capture rather than borrowing `options`.
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                let align_corners = options.align_corners;
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size,
                    align_corners
                )
                .into())
            }
        }
    }

    /// Backward pass of interpolation. Only nearest-neighbor is implemented
    /// for this backend; bilinear/bicubic deliberately panic.
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
        }
    }

    /// 3D convolution, delegating to the scalar `conv3d` (no SIMD path).
    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    /// Transposed 3D convolution, delegating to `conv_transpose3d`.
    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }

    /// Scaled dot-product attention. No specialized kernel here: this uses the
    /// generic `attention_fallback` composed from other backend ops.
    fn attention(
        query: FloatTensor<Self>,
        key: FloatTensor<Self>,
        value: FloatTensor<Self>,
        mask: Option<burn_backend::tensor::BoolTensor<Self>>,
        attn_bias: Option<FloatTensor<Self>>,
        options: AttentionModuleOptions,
    ) -> FloatTensor<Self> {
        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)
    }
}