burn_ndarray/ops/module.rs

use super::{
    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
    avgpool::{avg_pool2d, avg_pool2d_backward},
    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},
    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
#[cfg(feature = "simd")]
use crate::ops::simd::{
    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,
};
use crate::{NdArray, NdArrayTensorFloat, element::FloatNdArrayElement, tensor::NdArrayTensor};
use crate::{
    element::{IntNdArrayElement, QuantElement},
    ops::interpolate::nearest_interpolate_backward,
};
use burn_tensor::ops::*;

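// `module_op!` dispatches a module op over the `NdArrayTensorFloat` dtype enum.
// An invocation like `module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| ...)`
// matches every input on the same variant, aliases the element type inside the
// arm (`type E = f32;` or `f64`) so the closure can call kernels as `op::<E>`,
// and unwraps each optional input to the matching variant. Inputs with mixed
// dtypes fall through to the catch-all `panic!("Data type mismatch")`.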
macro_rules! module_op {
    // Module op with inputs (inp), optional inputs (opt), the element type alias and the op closure.
    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{
        #[allow(unused_parens, unreachable_patterns)]
        match ($($x),+) {
            ($(NdArrayTensorFloat::F32($x)),+) => {
                type $element = f32;
                $op(
                    $($x),+
                    $(, $opt.map(|o| match o { NdArrayTensorFloat::F32(val) => val, _ => panic!("Optional argument type mismatch") }))*
                )
            }
            ($(NdArrayTensorFloat::F64($x)),+) => {
                type $element = f64;
                $op(
                    $($x),+
                    $(, $opt.map(|o| match o { NdArrayTensorFloat::F64(val) => val, _ => panic!("Optional argument type mismatch") }))*
                )
            }
            _ => panic!("Data type mismatch"),
        }
    }};
}

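// Neural-network module primitives for the ndarray backend: convolution,
// pooling and interpolation kernels, with optional SIMD fast paths.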
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
    for NdArray<E, I, Q>
{
    fn conv2d(
        x: NdArrayTensorFloat,
        weight: NdArrayTensorFloat,
        bias: Option<NdArrayTensorFloat>,
        options: ConvOptions<2>,
    ) -> NdArrayTensorFloat {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
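            // Try the vectorized kernel first; on failure `try_conv2d_simd`
            // hands the untouched inputs back so the generic path can run.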
            #[cfg(feature = "simd")]
            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {
                Ok(out) => return out.into(),
                Err(args) => args,
            };
            conv2d::<E>(x, weight, bias, options).into()
        })
    }

    fn deform_conv2d(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self> {
        module_op!(
            inp(x, offset, weight),
            opt(mask, bias),
            E,
            |x, offset, weight, mask, bias| deform_conv2d::<E>(
                x, offset, weight, mask, bias, options
            )
            .into()
        )
    }

    fn deform_conv2d_backward(
        x: FloatTensor<Self>,
        offset: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        mask: Option<FloatTensor<Self>>,
        bias: Option<FloatTensor<Self>>,
        output_grad: FloatTensor<Self>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        module_op!(
            inp(x, offset, weight, output_grad),
            opt(mask, bias),
            E,
            |x, offset, weight, output_grad, mask, bias| {
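                // The kernel returns one gradient per input; re-wrap each
                // concrete array into the dtype-erased float tensor enum.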
96                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(
97                    x,
98                    offset,
99                    weight,
100                    mask,
101                    bias,
102                    output_grad,
103                    options,
104                );
105                DeformConv2dBackward::new(
106                    x.into(),
107                    offset.into(),
108                    weight.into(),
109                    mask.map(|m| m.into()),
110                    bias.map(|b| b.into()),
111                )
112            }
113        )
114    }
115
116    fn conv_transpose2d(
117        x: FloatTensor<Self>,
118        weight: FloatTensor<Self>,
119        bias: Option<FloatTensor<Self>>,
120        options: ConvTransposeOptions<2>,
121    ) -> FloatTensor<Self> {
122        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
123            conv_transpose2d::<E>(x, weight, bias, options).into()
124        })
125    }
126
127    fn avg_pool2d(
128        x: FloatTensor<Self>,
129        kernel_size: [usize; 2],
130        stride: [usize; 2],
131        padding: [usize; 2],
132        count_include_pad: bool,
133    ) -> FloatTensor<Self> {
134        module_op!(inp(x), opt(), E, |x| {
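            // SIMD fast path; `Err` returns ownership of `x` for the fallback.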
            #[cfg(feature = "simd")]
            let x = match try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            avg_pool2d::<E>(x, kernel_size, stride, padding, count_include_pad).into()
        })
    }

    fn avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad
        )
        .into())
    }

    fn max_pool2d(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| {
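            // Same pattern as `avg_pool2d`: try SIMD, fall back on `Err`.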
            #[cfg(feature = "simd")]
            let x = match try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) {
                Ok(out) => return out.into(),
                Err(x) => x,
            };
            max_pool2d::<E>(x, kernel_size, stride, padding, dilation).into()
        })
    }

    fn max_pool2d_with_indices(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
        module_op!(inp(x), opt(), E, |x| {
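            // Record each max element's location as an `I` integer tensor so
            // the backward pass can route gradients without recomputing argmax.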
            let (output, indices) =
                max_pool2d_with_indices::<E, I>(x, kernel_size, stride, padding, dilation);
            MaxPool2dWithIndices::new(output.into(), indices)
        })
    }

    fn max_pool2d_with_indices_backward(
        x: FloatTensor<Self>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        output_grad: FloatTensor<Self>,
        indices: NdArrayTensor<I>,
    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
        module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
            let output = max_pool2d_backward::<E, I>(
                x,
                kernel_size,
                stride,
                padding,
                dilation,
                output_grad,
                indices,
            );
            MaxPool2dBackward::new(output.into())
        })
    }

    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
            x,
            output_size
        )
        .into())
    }

    fn adaptive_avg_pool2d_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, grad), opt(), E, |x, grad| {
            adaptive_avg_pool2d_backward::<E>(x, grad).into()
        })
    }

    fn interpolate(
        x: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
        match options.mode {
            InterpolateMode::Nearest => {
                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bilinear => {
                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
            InterpolateMode::Bicubic => {
                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
                    x,
                    output_size
                )
                .into())
            }
        }
    }

    fn interpolate_backward(
        x: FloatTensor<Self>,
        grad: FloatTensor<Self>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<Self> {
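        // Only nearest-neighbor interpolation has a backward kernel on this
        // backend; the other modes panic below.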
        match options.mode {
            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
                nearest_interpolate_backward::<E>(x, grad, output_size).into()
            }),
            InterpolateMode::Bilinear => {
                panic!("bilinear interpolation backward is not supported for ndarray backend")
            }
            InterpolateMode::Bicubic => {
                panic!("bicubic interpolation backward is not supported for ndarray backend")
            }
        }
    }

    fn conv3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(
            x, weight, bias, options
        )
        .into())
    }

    fn conv_transpose3d(
        x: FloatTensor<Self>,
        weight: FloatTensor<Self>,
        bias: Option<FloatTensor<Self>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<Self> {
        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
            conv_transpose3d::<E>(x, weight, bias, options).into()
        })
    }
}