Skip to main content

burn_ndarray/ops/
padding.rs

1use crate::{NdArrayElement, SharedArray};
2use ndarray::{Array4, Array5};
3
4use super::NdArrayOps;
5
6pub(crate) fn apply_padding_4d<E: NdArrayElement>(
7    x: SharedArray<E>,
8    padding: [usize; 2],
9    elem: E,
10) -> SharedArray<E> {
11    let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap();
12    let [padding_height, padding_width] = padding;
13    let padded_height = height + 2 * padding_height;
14    let padded_width = width + 2 * padding_width;
15
16    let x_new = Array4::from_elem(
17        (batch_size, input_channels, padded_height, padded_width),
18        elem,
19    );
20    let mut x_new = x_new.into_shared().into_dyn();
21
22    x_new = NdArrayOps::slice_assign(
23        x_new,
24        &[
25            burn_backend::Slice::from(0..batch_size),
26            burn_backend::Slice::from(0..input_channels),
27            burn_backend::Slice::from(padding_height..height + padding_height),
28            burn_backend::Slice::from(padding_width..width + padding_width),
29        ],
30        x,
31    );
32
33    x_new
34}
35
36pub(crate) fn apply_padding_5d<E: NdArrayElement>(
37    x: SharedArray<E>,
38    padding: [usize; 3],
39    elem: E,
40) -> SharedArray<E> {
41    let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap();
42    let [padding_depth, padding_height, padding_width] = padding;
43    let padded_depth = depth + 2 * padding_depth;
44    let padded_height = height + 2 * padding_height;
45    let padded_width = width + 2 * padding_width;
46
47    let x_new = Array5::from_elem(
48        (
49            batch_size,
50            input_channels,
51            padded_depth,
52            padded_height,
53            padded_width,
54        ),
55        elem,
56    );
57    let mut x_new = x_new.into_shared().into_dyn();
58
59    x_new = NdArrayOps::slice_assign(
60        x_new,
61        &[
62            burn_backend::Slice::from(0..batch_size),
63            burn_backend::Slice::from(0..input_channels),
64            burn_backend::Slice::from(padding_depth..depth + padding_depth),
65            burn_backend::Slice::from(padding_height..height + padding_height),
66            burn_backend::Slice::from(padding_width..width + padding_width),
67        ],
68        x,
69    );
70
71    x_new
72}