burn_tensor/tensor/api/pad.rs

1use alloc::vec::Vec;
2use core::ops::Range;
3
4use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
5
6use super::Numeric;
7
/// Builds one range per axis of `dims`, suitable for `slice_assign`.
///
/// Every axis spans its full extent `0..dims[axis]`, except `target_dim`, which is
/// restricted to the `len` elements beginning at `start`.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    let mut ranges: Vec<Range<usize>> = Vec::with_capacity(D);
    for (axis, size) in dims.into_iter().enumerate() {
        let range = if axis == target_dim {
            start..start + len
        } else {
            0..size
        };
        ranges.push(range);
    }
    // Exactly `D` ranges were pushed, so the conversion cannot fail.
    ranges.try_into().unwrap()
}
28
29impl<B, const D: usize, K> Tensor<B, D, K>
30where
31    B: Backend,
32    K: Numeric<B>,
33    K::Elem: Element,
34{
35    /// Pads the tensor on the last two dimensions using the specified padding mode.
36    ///
37    /// **Note**: Currently, padding is only supported on the last two dimensions of a tensor
38    /// (typically height and width for image data in NCHW format).
39    ///
40    /// # Arguments
41    ///
42    /// * `padding` - A tuple `(left, right, top, bottom)` specifying padding for the last two dimensions.
43    /// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`.
44    ///
45    /// # Returns
46    ///
47    /// A new tensor with the specified padding applied.
48    ///
49    /// # Panics
50    ///
51    /// - `Reflect` mode panics if padding exceeds `dimension_size - 1`.
52    /// - `Edge` mode panics if padding is applied to a zero-sized dimension.
53    ///
54    /// # Example
55    ///
56    /// ```rust
57    /// use burn_tensor::backend::Backend;
58    /// use burn_tensor::{Tensor, Shape};
59    /// use burn_tensor::ops::PadMode;
60    ///
61    /// fn example<B: Backend<FloatElem: From<f32>>>() {
62    ///    let device = B::Device::default();
63    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
64    ///
65    ///    // Constant padding with value 0.0
66    ///    let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0));
67    ///    // [
68    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0],
69    ///    //   [0.0, 12.0, -2.0, 3.0, 0.0],
70    ///    //   [0.0, 5.0, 3.0, 6.0, 0.0],
71    ///    //   [0.0, 0.0, 0.0, 0.0, 0.0]
72    ///    // ]
73    ///
74    ///    // Reflect padding
75    ///    let padded = tensor.clone().pad((1, 1, 0, 0), PadMode::Reflect);
76    ///    // [[−2.0, 12.0, −2.0, 3.0, −2.0], [3.0, 5.0, 3.0, 6.0, 3.0]]
77    ///
78    ///    // Edge padding
79    ///    let padded = tensor.pad((1, 1, 0, 0), PadMode::Edge);
80    ///    // [[12.0, 12.0, −2.0, 3.0, 3.0], [5.0, 5.0, 3.0, 6.0, 6.0]]
81    /// }
82    /// ```
83    pub fn pad(self, padding: (usize, usize, usize, usize), mode: PadMode) -> Self {
84        match mode {
85            PadMode::Constant(value) => pad_constant(self, padding, value),
86            PadMode::Reflect => pad_reflect(self, padding),
87            PadMode::Edge => pad_edge(self, padding),
88        }
89    }
90}
91
92/// Pad with a constant value.
93pub fn pad_constant<B, const D: usize, K, E>(
94    tensor: Tensor<B, D, K>,
95    padding: (usize, usize, usize, usize),
96    value: E,
97) -> Tensor<B, D, K>
98where
99    B: Backend,
100    K: Numeric<B>,
101    K::Elem: Element,
102    E: ElementConversion,
103{
104    let (left, right, top, bottom) = padding;
105
106    let mut padded_dims: [usize; D] = tensor.dims();
107
108    // Update the last two dimensions with padding
109    padded_dims[D - 2] += top + bottom;
110    padded_dims[D - 1] += left + right;
111
112    // Create the ranges for the padded tensor
113    let ranges: [core::ops::Range<usize>; D] = padded_dims
114        .iter()
115        .enumerate()
116        .map(|(i, &dim)| {
117            if i == D - 2 {
118                top..dim - bottom
119            } else if i == D - 1 {
120                left..dim - right
121            } else {
122                0..dim
123            }
124        })
125        .collect::<Vec<core::ops::Range<usize>>>()
126        .try_into()
127        .unwrap();
128
129    // Create the padded tensor
130    let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
131
132    // Assign the original tensor data to the appropriate slice of the padded tensor
133    padded_tensor.slice_assign(ranges, tensor)
134}
135
136/// Pad using reflection at the boundaries (excluding edge values).
137///
138/// For ONNX "reflect" mode: mirrors from index 1, not index 0.
139/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]`
140pub fn pad_reflect<B, const D: usize, K>(
141    tensor: Tensor<B, D, K>,
142    padding: (usize, usize, usize, usize),
143) -> Tensor<B, D, K>
144where
145    B: Backend,
146    K: Numeric<B>,
147    K::Elem: Element,
148{
149    let (left, right, top, bottom) = padding;
150    let dims = tensor.dims();
151
152    // Validate padding doesn't exceed tensor dimensions
153    // For reflect mode, padding must be less than the corresponding dimension
154    assert!(
155        top < dims[D - 2] && bottom < dims[D - 2],
156        "Reflect padding on height ({}, {}) must be less than height dimension ({})",
157        top,
158        bottom,
159        dims[D - 2]
160    );
161    assert!(
162        left < dims[D - 1] && right < dims[D - 1],
163        "Reflect padding on width ({}, {}) must be less than width dimension ({})",
164        left,
165        right,
166        dims[D - 1]
167    );
168
169    let mut result = tensor;
170
171    // Pad height dimension (D - 2): top and bottom
172    if top > 0 || bottom > 0 {
173        result = pad_reflect_dim(result, D - 2, top, bottom);
174    }
175
176    // Pad width dimension (D - 1): left and right
177    if left > 0 || right > 0 {
178        result = pad_reflect_dim(result, D - 1, left, right);
179    }
180
181    result
182}
183
184/// Helper to pad a single dimension using reflection.
185fn pad_reflect_dim<B, const D: usize, K>(
186    tensor: Tensor<B, D, K>,
187    dim: usize,
188    pad_before: usize,
189    pad_after: usize,
190) -> Tensor<B, D, K>
191where
192    B: Backend,
193    K: Numeric<B>,
194    K::Elem: Element,
195{
196    let dims = tensor.dims();
197    let dim_size = dims[dim];
198
199    // Calculate output dimensions
200    let mut output_dims = dims;
201    output_dims[dim] += pad_before + pad_after;
202
203    // Create output tensor and place original in the center
204    let output = Tensor::zeros(output_dims, &tensor.device());
205    let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
206    let mut output = output.slice_assign(original_range, tensor.clone());
207
208    // Assign reflected "before" padding (e.g., top or left)
209    // Reflect excludes the edge, so we take indices [1..pad_before+1] and flip
210    if pad_before > 0 {
211        let before_slice = tensor.clone().narrow(dim, 1, pad_before);
212        let before_flipped = before_slice.flip([dim as isize]);
213        let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
214        output = output.slice_assign(before_range, before_flipped);
215    }
216
217    // Assign reflected "after" padding (e.g., bottom or right)
218    // Take indices [dim_size - pad_after - 1..dim_size - 1] and flip
219    if pad_after > 0 {
220        let start = dim_size - pad_after - 1;
221        let after_slice = tensor.narrow(dim, start, pad_after);
222        let after_flipped = after_slice.flip([dim as isize]);
223        let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
224        output = output.slice_assign(after_range, after_flipped);
225    }
226
227    output
228}
229
230/// Pad by replicating edge values.
231///
232/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]`
233pub fn pad_edge<B, const D: usize, K>(
234    tensor: Tensor<B, D, K>,
235    padding: (usize, usize, usize, usize),
236) -> Tensor<B, D, K>
237where
238    B: Backend,
239    K: Numeric<B>,
240    K::Elem: Element,
241{
242    let (left, right, top, bottom) = padding;
243    let dims = tensor.dims();
244
245    // Validate dimensions are non-zero when padding is requested
246    if top > 0 || bottom > 0 {
247        assert!(
248            dims[D - 2] > 0,
249            "Cannot apply edge padding to zero-sized height dimension"
250        );
251    }
252    if left > 0 || right > 0 {
253        assert!(
254            dims[D - 1] > 0,
255            "Cannot apply edge padding to zero-sized width dimension"
256        );
257    }
258
259    let mut result = tensor;
260
261    // Pad height dimension (D - 2): top and bottom
262    if top > 0 || bottom > 0 {
263        result = pad_edge_dim(result, D - 2, top, bottom);
264    }
265
266    // Pad width dimension (D - 1): left and right
267    if left > 0 || right > 0 {
268        result = pad_edge_dim(result, D - 1, left, right);
269    }
270
271    result
272}
273
274/// Helper to pad a single dimension by replicating edge values.
275fn pad_edge_dim<B, const D: usize, K>(
276    tensor: Tensor<B, D, K>,
277    dim: usize,
278    pad_before: usize,
279    pad_after: usize,
280) -> Tensor<B, D, K>
281where
282    B: Backend,
283    K: Numeric<B>,
284    K::Elem: Element,
285{
286    let dims = tensor.dims();
287    let dim_size = dims[dim];
288
289    // Calculate output dimensions
290    let mut output_dims = dims;
291    output_dims[dim] += pad_before + pad_after;
292
293    // Create output tensor and place original in the center
294    let output = Tensor::zeros(output_dims, &tensor.device());
295    let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
296    let mut output = output.slice_assign(original_range, tensor.clone());
297
298    // Assign "before" padding by repeating the first element
299    if pad_before > 0 {
300        let first_slice = tensor.clone().narrow(dim, 0, 1);
301        let before_pad = first_slice.repeat_dim(dim, pad_before);
302        let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
303        output = output.slice_assign(before_range, before_pad);
304    }
305
306    // Assign "after" padding by repeating the last element
307    if pad_after > 0 {
308        let last_slice = tensor.narrow(dim, dim_size - 1, 1);
309        let after_pad = last_slice.repeat_dim(dim, pad_after);
310        let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
311        output = output.slice_assign(after_range, after_pad);
312    }
313
314    output
315}