burn_ndarray/ops/module.rs

use super::{
    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
    avgpool::{avg_pool2d, avg_pool2d_backward},
    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
#[cfg(feature = "simd")]
use crate::ops::simd::{
    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
};
use crate::{
    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
    tensor::NdArrayTensor,
};
use crate::{
    element::{IntNdArrayElement, QuantElement},
    ops::interpolate::nearest_interpolate_backward,
};
use burn_backend::{ElementConversion, TensorMetadata, ops::*, tensor::FloatTensor};

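// Roughly, an invocation such as `module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| ...)`
// expands to a match over the float dtype of every required input:
//
//     match (x, weight) {
//         (NdArrayTensor::F32(x), NdArrayTensor::F32(weight)) => { /* closure with E = f32 */ }
//         (NdArrayTensor::F64(x), NdArrayTensor::F64(weight)) => { /* closure with E = f64 */ }
//         _ => panic!("Data type mismatch"),
//     }
//
// Optional inputs are unwrapped to the same dtype, panicking on a mismatch.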
macro_rules! module_op {
    // Module op with required inputs (inp), optional inputs (opt), an element type alias,
    // and the operation to run.
    // Converts NdArrayStorage to SharedArray for compatibility with existing operations.
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensor::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            ($(NdArrayTensor::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x.into_shared()),+
                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}

impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
where
    NdArrayTensor: From<SharedArray<E>>,
    NdArrayTensor: From<SharedArray<I>>,
{
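    // Where a SIMD kernel exists (conv2d, avg_pool2d, max_pool2d), the `try_*_simd`
    // helpers are attempted first under the "simd" feature: they return `Ok(output)`
    // when the vectorized path applies, or hand the inputs back in `Err` so the
    // scalar implementation can run instead.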
    fn conv2d(
        x: NdArrayTensor,
        weight: NdArrayTensor,
        bias: Option<NdArrayTensor>,
        options: ConvOptions<2>,
    ) -> NdArrayTensor {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

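    // The backward pass returns the input, offset, weight, mask and bias gradients as a
    // tuple; the mask and bias gradients are optional, mirroring the optional inputs.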
    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
                    x,
                    offset,
                    weight,
                    mask,
                    bias,
                    output_grad,
                    options,
                );
                DeformConv2dBackward::new(
                    x.into(),
                    offset.into(),
                    weight.into(),
                    mask.map(|m| m.into()),
                    bias.map(|b| b.into()),
                )
            }
        )
    }

    fn conv_transpose2d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose2d::<E>(x, weight, bias, options).into()
        })
    }

    fn avg_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path doesn't support ceil_mode yet, skip it
                Err(x)
            } else {
                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(
                x,
                kernel_size,
                stride,
                padding,
                count_include_pad,
                ceil_mode,
            )
            .into()
        })
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode
        )
        .into())
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
            #[cfg(feature = "simd")]
            let x = match if ceil_mode {
                // SIMD path doesn't support ceil_mode yet, skip it
                Err(x)
            } else {
                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)
            } {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()
        })
    }

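    // Variant of max_pool2d that also returns the indices of the selected elements;
    // max_pool2d_with_indices_backward uses these indices to route the output gradient.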
    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
            let (output, indices) = max_pool2d_with_indices::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode,
            );
            MaxPool2dWithIndices::new(output.into(), indices.into())
        })
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {
            // Convert indices from runtime dtype to the expected I type
            // (pool indices are bounded by tensor dimensions, so conversion is safe)
            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();
            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
                let output = max_pool2d_backward::<E, I>(
                    x,
                    kernel_size,
                    stride,
                    padding,
                    dilation,
                    ceil_mode,
                    output_grad,
                    indices,
                );
                MaxPool2dBackward::new(output.into())
            })
        })
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
        }
    }

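    // Only nearest-neighbour interpolation has a backward implementation on this backend;
    // the bilinear and bicubic modes panic below.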
    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
        }
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }
}