burn_ndarray/ops/module.rs

1use super::{
2    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
3    avgpool::{avg_pool2d, avg_pool2d_backward},
4    conv::{conv2d, conv3d, conv_transpose2d, conv_transpose3d},
5    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
6    interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
7    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
8};
9use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray, NdArrayTensorFloat};
10use crate::{
11    element::{IntNdArrayElement, QuantElement},
12    ops::interpolate::nearest_interpolate_backward,
13};
14use burn_tensor::ops::*;
15
/// Dispatches a module operation on the concrete float element type of its tensor arguments.
///
/// Invocation: `module_op!(inp(a, b, ..), opt(c, ..), E, func)` where
/// * `inp(..)` lists one or more `NdArrayTensorFloat` bindings. They must all be the same
///   variant, and each identifier is rebound to the tensor wrapped inside that variant.
/// * `opt(..)` lists zero or more `Option<NdArrayTensorFloat>` bindings. Each is unwrapped
///   with `Option::map` and passed to `func` *after* every required input.
/// * `E` is an identifier that is declared as a type alias for `f32` or `f64` inside the
///   selected branch, so `func` can name the element type of the matched variant.
/// * `func` is a callable (usually a closure) invoked with the unwrapped arguments.
///
/// # Panics
///
/// * `"Data type mismatch"` when the required inputs do not all share one of the `F32` /
///   `F64` variants.
/// * `"Optional argument type mismatch"` when a present optional argument uses a different
///   variant than the required inputs.
macro_rules! module_op {
    (inp($($input:tt),+), opt($($maybe:tt),*), $elem:ident, $func:expr) => {{
        // A single input produces `match (x)`, and with only two variants the `_` arm can
        // be unreachable; both lints are expected here.
        #[allow(unused_parens, unreachable_patterns)]
        match ($($input),+) {
            ($(NdArrayTensorFloat::F32($input)),+) => {
                module_op!(@call f32, F32, $elem, $func, ($($input),+), ($($maybe),*))
            }
            ($(NdArrayTensorFloat::F64($input)),+) => {
                module_op!(@call f64, F64, $elem, $func, ($($input),+), ($($maybe),*))
            }
            _ => panic!("Data type mismatch"),
        }
    }};
    // Internal rule: bind the element-type alias, then call `func` with the unwrapped
    // required inputs followed by the optionals narrowed to the same variant.
    (@call $ty:ty, $variant:ident, $elem:ident, $func:expr, ($($input:tt),+), ($($maybe:tt),*)) => {{
        type $elem = $ty;
        $func(
            $($input),+
            $(, $maybe.map(|o| match o {
                NdArrayTensorFloat::$variant(val) => val,
                _ => panic!("Optional argument type mismatch"),
            }))*
        )
    }};
}
39
40impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>
41    for NdArray<E, I, Q>
42{
43    fn conv2d(
44        x: NdArrayTensorFloat,
45        weight: NdArrayTensorFloat,
46        bias: Option<NdArrayTensorFloat>,
47        options: ConvOptions<2>,
48    ) -> NdArrayTensorFloat {
49        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv2d::<
50            E,
51            I,
52            Q,
53        >(
54            x, weight, bias, options
55        )
56        .into())
57    }
58
59    fn deform_conv2d(
60        x: FloatTensor<Self>,
61        offset: FloatTensor<Self>,
62        weight: FloatTensor<Self>,
63        mask: Option<FloatTensor<Self>>,
64        bias: Option<FloatTensor<Self>>,
65        options: DeformConvOptions<2>,
66    ) -> FloatTensor<Self> {
67        module_op!(
68            inp(x, offset, weight),
69            opt(mask, bias),
70            E,
71            |x, offset, weight, mask, bias| deform_conv2d::<E>(
72                x, offset, weight, mask, bias, options
73            )
74            .into()
75        )
76    }
77
78    fn deform_conv2d_backward(
79        x: FloatTensor<Self>,
80        offset: FloatTensor<Self>,
81        weight: FloatTensor<Self>,
82        mask: Option<FloatTensor<Self>>,
83        bias: Option<FloatTensor<Self>>,
84        output_grad: FloatTensor<Self>,
85        options: DeformConvOptions<2>,
86    ) -> DeformConv2dBackward<Self> {
87        module_op!(
88            inp(x, offset, weight, output_grad),
89            opt(mask, bias),
90            E,
91            |x, offset, weight, output_grad, mask, bias| {
92                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E, I, Q>(
93                    x,
94                    offset,
95                    weight,
96                    mask,
97                    bias,
98                    output_grad,
99                    options,
100                );
101                DeformConv2dBackward::new(
102                    x.into(),
103                    offset.into(),
104                    weight.into(),
105                    mask.map(|m| m.into()),
106                    bias.map(|b| b.into()),
107                )
108            }
109        )
110    }
111
112    fn conv_transpose2d(
113        x: FloatTensor<Self>,
114        weight: FloatTensor<Self>,
115        bias: Option<FloatTensor<Self>>,
116        options: ConvTransposeOptions<2>,
117    ) -> FloatTensor<Self> {
118        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
119            conv_transpose2d::<E>(x, weight, bias, options).into()
120        })
121    }
122
123    fn avg_pool2d(
124        x: FloatTensor<Self>,
125        kernel_size: [usize; 2],
126        stride: [usize; 2],
127        padding: [usize; 2],
128        count_include_pad: bool,
129    ) -> FloatTensor<Self> {
130        module_op!(inp(x), opt(), E, |x| avg_pool2d::<E>(
131            x,
132            kernel_size,
133            stride,
134            padding,
135            count_include_pad
136        )
137        .into())
138    }
139
140    fn avg_pool2d_backward(
141        x: FloatTensor<Self>,
142        grad: FloatTensor<Self>,
143        kernel_size: [usize; 2],
144        stride: [usize; 2],
145        padding: [usize; 2],
146        count_include_pad: bool,
147    ) -> FloatTensor<Self> {
148        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(
149            x,
150            grad,
151            kernel_size,
152            stride,
153            padding,
154            count_include_pad
155        )
156        .into())
157    }
158
159    fn max_pool2d(
160        x: FloatTensor<Self>,
161        kernel_size: [usize; 2],
162        stride: [usize; 2],
163        padding: [usize; 2],
164        dilation: [usize; 2],
165    ) -> FloatTensor<Self> {
166        module_op!(inp(x), opt(), E, |x| max_pool2d::<E, I, Q>(
167            x,
168            kernel_size,
169            stride,
170            padding,
171            dilation
172        )
173        .into())
174    }
175
176    fn max_pool2d_with_indices(
177        x: FloatTensor<Self>,
178        kernel_size: [usize; 2],
179        stride: [usize; 2],
180        padding: [usize; 2],
181        dilation: [usize; 2],
182    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {
183        module_op!(inp(x), opt(), E, |x| {
184            let (output, indices) =
185                max_pool2d_with_indices::<E, I, Q>(x, kernel_size, stride, padding, dilation);
186            MaxPool2dWithIndices::new(output.into(), indices)
187        })
188    }
189
190    fn max_pool2d_with_indices_backward(
191        x: FloatTensor<Self>,
192        kernel_size: [usize; 2],
193        stride: [usize; 2],
194        padding: [usize; 2],
195        dilation: [usize; 2],
196        output_grad: FloatTensor<Self>,
197        indices: NdArrayTensor<I>,
198    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {
199        module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {
200            let output = max_pool2d_backward::<E, I>(
201                x,
202                kernel_size,
203                stride,
204                padding,
205                dilation,
206                output_grad,
207                indices,
208            );
209            MaxPool2dBackward::new(output.into())
210        })
211    }
212
213    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
214        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(
215            x,
216            output_size
217        )
218        .into())
219    }
220
221    fn adaptive_avg_pool2d_backward(
222        x: FloatTensor<Self>,
223        grad: FloatTensor<Self>,
224    ) -> FloatTensor<Self> {
225        module_op!(inp(x, grad), opt(), E, |x, grad| {
226            adaptive_avg_pool2d_backward::<E>(x, grad).into()
227        })
228    }
229
230    fn interpolate(
231        x: FloatTensor<Self>,
232        output_size: [usize; 2],
233        options: InterpolateOptions,
234    ) -> FloatTensor<Self> {
235        match options.mode {
236            InterpolateMode::Nearest => {
237                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(
238                    x,
239                    output_size
240                )
241                .into())
242            }
243            InterpolateMode::Bilinear => {
244                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
245                    x,
246                    output_size
247                )
248                .into())
249            }
250            InterpolateMode::Bicubic => {
251                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(
252                    x,
253                    output_size
254                )
255                .into())
256            }
257        }
258    }
259
260    fn interpolate_backward(
261        x: FloatTensor<Self>,
262        grad: FloatTensor<Self>,
263        output_size: [usize; 2],
264        options: InterpolateOptions,
265    ) -> FloatTensor<Self> {
266        match options.mode {
267            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
268                nearest_interpolate_backward::<E>(x, grad, output_size).into()
269            }),
270            InterpolateMode::Bilinear => {
271                panic!("bilinear interpolation backward is not supported for ndarray backend")
272            }
273            InterpolateMode::Bicubic => {
274                panic!("bicubic interpolation backward is not supported for ndarray backend")
275            }
276        }
277    }
278
279    fn conv3d(
280        x: FloatTensor<Self>,
281        weight: FloatTensor<Self>,
282        bias: Option<FloatTensor<Self>>,
283        options: ConvOptions<3>,
284    ) -> FloatTensor<Self> {
285        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<
286            E,
287            I,
288            Q,
289        >(
290            x, weight, bias, options
291        )
292        .into())
293    }
294
295    fn conv_transpose3d(
296        x: FloatTensor<Self>,
297        weight: FloatTensor<Self>,
298        bias: Option<FloatTensor<Self>>,
299        options: ConvTransposeOptions<3>,
300    ) -> FloatTensor<Self> {
301        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {
302            conv_transpose3d::<E>(x, weight, bias, options).into()
303        })
304    }
305}