burn_ndarray/ops/
padding.rs1use crate::{NdArrayElement, SharedArray};
2use ndarray::{Array4, Array5};
3
4use super::NdArrayOps;
5
6pub(crate) fn apply_padding_4d<E: NdArrayElement>(
7 x: SharedArray<E>,
8 padding: [usize; 2],
9 elem: E,
10) -> SharedArray<E> {
11 let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap();
12 let [padding_height, padding_width] = padding;
13 let padded_height = height + 2 * padding_height;
14 let padded_width = width + 2 * padding_width;
15
16 let x_new = Array4::from_elem(
17 (batch_size, input_channels, padded_height, padded_width),
18 elem,
19 );
20 let mut x_new = x_new.into_shared().into_dyn();
21
22 x_new = NdArrayOps::slice_assign(
23 x_new,
24 &[
25 burn_backend::Slice::from(0..batch_size),
26 burn_backend::Slice::from(0..input_channels),
27 burn_backend::Slice::from(padding_height..height + padding_height),
28 burn_backend::Slice::from(padding_width..width + padding_width),
29 ],
30 x,
31 );
32
33 x_new
34}
35
36pub(crate) fn apply_padding_5d<E: NdArrayElement>(
37 x: SharedArray<E>,
38 padding: [usize; 3],
39 elem: E,
40) -> SharedArray<E> {
41 let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap();
42 let [padding_depth, padding_height, padding_width] = padding;
43 let padded_depth = depth + 2 * padding_depth;
44 let padded_height = height + 2 * padding_height;
45 let padded_width = width + 2 * padding_width;
46
47 let x_new = Array5::from_elem(
48 (
49 batch_size,
50 input_channels,
51 padded_depth,
52 padded_height,
53 padded_width,
54 ),
55 elem,
56 );
57 let mut x_new = x_new.into_shared().into_dyn();
58
59 x_new = NdArrayOps::slice_assign(
60 x_new,
61 &[
62 burn_backend::Slice::from(0..batch_size),
63 burn_backend::Slice::from(0..input_channels),
64 burn_backend::Slice::from(padding_depth..depth + padding_depth),
65 burn_backend::Slice::from(padding_height..height + padding_height),
66 burn_backend::Slice::from(padding_width..width + padding_width),
67 ],
68 x,
69 );
70
71 x_new
72}