burn_ndarray/ops/module.rs

1use super::{
2    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3    avgpool::{avg_pool2d, avg_pool2d_backward},
4    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
5    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
7    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
8};
9#[cfg(feature = "simd")]
10use crate::ops::simd::{
11    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
12};
13use crate::{
14    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,
15    tensor::NdArrayTensor,
16};
17use crate::{
18    element::{IntNdArrayElement, QuantElement},
19    ops::interpolate::nearest_interpolate_backward,
20};
21use burn_tensor::{TensorMetadata, ops::*};
22
/// Runs a float module operation on whichever float variant its tensor arguments hold.
///
/// Invocation shape: `module_op!(inp(a, b, ..), opt(o1, ..), Elem, op)`.
///
/// * `inp(..)`: one or more `NdArrayTensor` values. Either every one is `F32` or every one is
///   `F64`; their inner values are passed to `op` first, in the listed order.
/// * `opt(..)`: zero or more optional tensors. Each one is `.map`ped to the inner value of the
///   variant selected by the required inputs and passed to `op` after them.
/// * `Elem`: an identifier declared as a type alias (`f32` or `f64`) inside the selected arm,
///   so `op` can name the concrete element type, e.g. `kernel::<Elem>(..)`.
/// * `op`: a callable (typically a closure) whose return value becomes the macro's value.
///
/// # Panics
///
/// * `"Data type mismatch"` when the required inputs are neither all `F32` nor all `F64`.
/// * `"Optional argument type mismatch"` when a present optional holds a variant other than
///   the one selected by the required inputs.
macro_rules! module_op {
    (inp($($input:tt),+), opt($($optional:tt),*), $elem:ident, $body:expr) => {{
        // `unused_parens`: a single required input expands to `match (x) { (pat) => .. }`.
        // `unreachable_patterns`: NOTE(review): presumably for expansions where the trailing
        // catch-all arm can never match — confirm which call sites trigger it.
        #[allow(unused_parens, unreachable_patterns)]
        match ($($input),+) {
            // All required inputs are single precision.
            ($(NdArrayTensor::F32($input)),+) => {
                type $elem = f32;
                $body(
                    $($input),+
                    $(, $optional.map(|tensor| match tensor {
                        NdArrayTensor::F32(inner) => inner,
                        _ => panic!("Optional argument type mismatch"),
                    }))*
                )
            }
            // All required inputs are double precision.
            ($(NdArrayTensor::F64($input)),+) => {
                type $elem = f64;
                $body(
                    $($input),+
                    $(, $optional.map(|tensor| match tensor {
                        NdArrayTensor::F64(inner) => inner,
                        _ => panic!("Optional argument type mismatch"),
                    }))*
                )
            }
            // Anything else: mixed variants, or a variant other than F32/F64.
            _ => panic!("Data type mismatch"),
        }
    }};
}
46
47impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
48    for NdArray<E, I, Q>
49where
50    NdArrayTensor: From<SharedArray<E>>,
51    NdArrayTensor: From<SharedArray<I>>,
52{
53    fn conv2d(
54        x: NdArrayTensor,
55        weight: NdArrayTensor,
56        bias: Option<NdArrayTensor>,
57        options: ConvOptions<2>,
58    ) -> NdArrayTensor {
59        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
60            #[cfg(feature = "simd")]
61            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
62                Ok(out) => return out.into(),
63                Err(args) => args,
64            };
65            conv2d::<E>(x, weight, bias, options).into()
66        })
67    }
68
69    fn deform_conv2d(
70        x: FloatTensor<Self>,
71        offset: FloatTensor<Self>,
72        weight: FloatTensor<Self>,
73        mask: Option<FloatTensor<Self>>,
74        bias: Option<FloatTensor<Self>>,
75        options: DeformConvOptions<2>,
76    ) -> FloatTensor<Self> {
77        module_op!(
78            inp(x, offset, weight),
79            opt(mask, bias),
80            E,
81            |x, offset, weight, mask, bias| deform_conv2d::<E>(
82                x, offset, weight, mask, bias, options
83            )
84            .into()
85        )
86    }
87
88    fn deform_conv2d_backward(
89        x: FloatTensor<Self>,
90        offset: FloatTensor<Self>,
91        weight: FloatTensor<Self>,
92        mask: Option<FloatTensor<Self>>,
93        bias: Option<FloatTensor<Self>>,
94        output_grad: FloatTensor<Self>,
95        options: DeformConvOptions<2>,
96    ) -> DeformConv2dBackward<Self> {
97        module_op!(
98            inp(x, offset, weight, output_grad),
99            opt(mask, bias),
100            E,
101            |x, offset, weight, output_grad, mask, bias| {
102                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
103                    x,
104                    offset,
105                    weight,
106                    mask,
107                    bias,
108                    output_grad,
109                    options,
110                );
111                DeformConv2dBackward::new(
112                    x.into(),
113                    offset.into(),
114                    weight.into(),
115                    mask.map(|m| m.into()),
116                    bias.map(|b| b.into()),
117                )
118            }
119        )
120    }
121
122    fn conv_transpose2d(
123        x: FloatTensor<Self>,
124        weight: FloatTensor<Self>,
125        bias: Option<FloatTensor<Self>>,
126        options: ConvTransposeOptions<2>,
127    ) -> FloatTensor<Self> {
128        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
129            conv_transpose2d::<E>(x, weight, bias, options).into()
130        })
131    }
132
133    fn avg_pool2d(
134        x: FloatTensor<Self>,
135        kernel_size: [usize; 2],
136        stride: [usize; 2],
137        padding: [usize; 2],
138        count_include_pad: bool,
139    ) -> FloatTensor<Self> {
140        module_op!(inp(x), opt(), E, |x| {
141            #[cfg(feature = "simd")]
142            let x = match try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) {
143                Ok(out) => return out.into(),
144                Err(x) => x,
145            };
146            avg_pool2d::<E>(x, kernel_size, stride, padding, count_include_pad).into()
147        })
148    }
149
150    fn avg_pool2d_backward(
151        x: FloatTensor<Self>,
152        grad: FloatTensor<Self>,
153        kernel_size: [usize; 2],
154        stride: [usize; 2],
155        padding: [usize; 2],
156        count_include_pad: bool,
157    ) -> FloatTensor<Self> {
158        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
159            x,
160            grad,
161            kernel_size,
162            stride,
163            padding,
164            count_include_pad
165        )
166        .into())
167    }
168
169    fn max_pool2d(
170        x: FloatTensor<Self>,
171        kernel_size: [usize; 2],
172        stride: [usize; 2],
173        padding: [usize; 2],
174        dilation: [usize; 2],
175    ) -> FloatTensor<Self> {
176        module_op!(inp(x), opt(), E, |x| {
177            #[cfg(feature = "simd")]
178            let x = match try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) {
179                Ok(out) => return out.into(),
180                Err(x) => x,
181            };
182            max_pool2d::<E>(x, kernel_size, stride, padding, dilation).into()
183        })
184    }
185
186    fn max_pool2d_with_indices(
187        x: FloatTensor<Self>,
188        kernel_size: [usize; 2],
189        stride: [usize; 2],
190        padding: [usize; 2],
191        dilation: [usize; 2],
192    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
193        module_op!(inp(x), opt(), E, |x| {
194            let (output, indices) =
195                max_pool2d_with_indices::<E, I>(x, kernel_size, stride, padding, dilation);
196            MaxPool2dWithIndices::new(output.into(), indices.into())
197        })
198    }
199
200    fn max_pool2d_with_indices_backward(
201        x: FloatTensor<Self>,
202        kernel_size: [usize; 2],
203        stride: [usize; 2],
204        padding: [usize; 2],
205        dilation: [usize; 2],
206        output_grad: FloatTensor<Self>,
207        indices: NdArrayTensor,
208    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
209        execute_with_int_dtype!(indices, I, |indices| {
210            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
211                let output = max_pool2d_backward::<E, I>(
212                    x,
213                    kernel_size,
214                    stride,
215                    padding,
216                    dilation,
217                    output_grad,
218                    indices,
219                );
220                MaxPool2dBackward::new(output.into())
221            })
222        })
223    }
224
225    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
226        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
227            x,
228            output_size
229        )
230        .into())
231    }
232
233    fn adaptive_avg_pool2d_backward(
234        x: FloatTensor<Self>,
235        grad: FloatTensor<Self>,
236    ) -> FloatTensor<Self> {
237        module_op!(inp(x, grad), opt(), E, |x, grad| {
238            adaptive_avg_pool2d_backward::<E>(x, grad).into()
239        })
240    }
241
242    fn interpolate(
243        x: FloatTensor<Self>,
244        output_size: [usize; 2],
245        options: InterpolateOptions,
246    ) -> FloatTensor<Self> {
247        match options.mode {
248            InterpolateMode::Nearest => {
249                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
250                    x,
251                    output_size
252                )
253                .into())
254            }
255            InterpolateMode::Bilinear => {
256                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
257                    x,
258                    output_size
259                )
260                .into())
261            }
262            InterpolateMode::Bicubic => {
263                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
264                    x,
265                    output_size
266                )
267                .into())
268            }
269        }
270    }
271
272    fn interpolate_backward(
273        x: FloatTensor<Self>,
274        grad: FloatTensor<Self>,
275        output_size: [usize; 2],
276        options: InterpolateOptions,
277    ) -> FloatTensor<Self> {
278        match options.mode {
279            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
280                nearest_interpolate_backward::<E>(x, grad, output_size).into()
281            }),
282            InterpolateMode::Bilinear => {
283                panic!("bilinear interpolation backward is not supported for ndarray backend")
284            }
285            InterpolateMode::Bicubic => {
286                panic!("bicubic interpolation backward is not supported for ndarray backend")
287            }
288        }
289    }
290
291    fn conv3d(
292        x: FloatTensor<Self>,
293        weight: FloatTensor<Self>,
294        bias: Option<FloatTensor<Self>>,
295        options: ConvOptions<3>,
296    ) -> FloatTensor<Self> {
297        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
298            x, weight, bias, options
299        )
300        .into())
301    }
302
303    fn conv_transpose3d(
304        x: FloatTensor<Self>,
305        weight: FloatTensor<Self>,
306        bias: Option<FloatTensor<Self>>,
307        options: ConvTransposeOptions<3>,
308    ) -> FloatTensor<Self> {
309        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
310            conv_transpose3d::<E>(x, weight, bias, options).into()
311        })
312    }
313}