// burn_tensor/tensor/api/pad.rs

1use alloc::vec::Vec;
2use core::ops::Range;
3
4use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
5
6use super::Numeric;
7
/// Conversion into a per-dimension padding specification for a rank-`D` tensor.
///
/// The result always holds exactly `D` entries, one `(before, after)` pair per
/// dimension. Every implementation below right-aligns its input: when fewer than `D`
/// pairs are supplied they describe the **last** dimensions, and every leading
/// dimension receives `(0, 0)` (no padding).
pub trait IntoPadding<const D: usize> {
    /// Converts into a fixed-size array of `(before, after)` padding pairs.
    fn into_padding(self) -> [(usize, usize); D];
}

/// Right-aligns `pairs` inside a zero-initialised array of `D` padding pairs.
///
/// Shared by the array, slice and `Vec` implementations so that they all agree on
/// both the layout and the panic message.
///
/// # Panics
///
/// Panics if `pairs` yields more than `D` items.
fn right_align_padding<const D: usize>(
    pairs: impl ExactSizeIterator<Item = (usize, usize)>,
) -> [(usize, usize); D] {
    let count = pairs.len();
    assert!(
        count <= D,
        "Padding has {} pairs but tensor only has {} dimensions",
        count,
        D
    );

    let mut result = [(0usize, 0usize); D];
    // `count <= D` was checked above, so `D - count` cannot underflow.
    for (slot, pair) in result[D - count..].iter_mut().zip(pairs) {
        *slot = pair;
    }
    result
}

/// Fixed-size array of `N` pairs, applied to the last `N` dimensions.
///
/// # Panics
///
/// Panics if `N > D`.
impl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {
    fn into_padding(self) -> [(usize, usize); D] {
        // The fully-qualified call iterates the array by value.
        right_align_padding(IntoIterator::into_iter(self))
    }
}

/// Backward-compatible: `(left, right, top, bottom)` maps to last 2 dimensions.
///
/// Equivalent to `[(top, bottom), (left, right)]`.
///
/// # Panics
///
/// Panics if `D < 2`, since there are not two trailing dimensions to pad.
impl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {
    fn into_padding(self) -> [(usize, usize); D] {
        // Without this check `D - 2` would underflow for rank-0/1 tensors and fail
        // with an unrelated arithmetic/indexing panic.
        assert!(
            D >= 2,
            "(left, right, top, bottom) padding requires at least 2 dimensions but tensor only has {}",
            D
        );
        let (left, right, top, bottom) = self;
        let mut result = [(0usize, 0usize); D];
        result[D - 2] = (top, bottom);
        result[D - 1] = (left, right);
        result
    }
}

/// Borrowed slice of pairs, applied to the last `self.len()` dimensions.
///
/// # Panics
///
/// Panics if the slice holds more than `D` pairs.
impl<const D: usize> IntoPadding<D> for &[(usize, usize)] {
    fn into_padding(self) -> [(usize, usize); D] {
        right_align_padding(self.iter().copied())
    }
}

/// Owned vector of pairs, applied to the last `self.len()` dimensions.
///
/// # Panics
///
/// Panics if the vector holds more than `D` pairs.
impl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {
    fn into_padding(self) -> [(usize, usize); D] {
        right_align_padding(self.into_iter())
    }
}
81
/// Builds the per-dimension ranges that select a window along a single dimension.
///
/// Every dimension `i` is covered in full (`0..dims[i]`) except `target_dim`, which is
/// restricted to `start..start + len`. The result is meant to be passed to
/// `slice_assign`.
///
/// If `target_dim` is not a valid index into `dims`, every dimension is covered in full.
///
/// # Arguments
///
/// * `dims` - Size of each dimension of the tensor being sliced.
/// * `target_dim` - Index of the dimension to restrict.
/// * `start` - First selected index along `target_dim`.
/// * `len` - Number of selected indices along `target_dim`.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    // `from_fn` fills the fixed-size array directly: no intermediate `Vec` and no
    // fallible `Vec -> array` conversion.
    core::array::from_fn(|axis| {
        if axis == target_dim {
            start..start + len
        } else {
            0..dims[axis]
        }
    })
}
102
103impl<B, const D: usize, K> Tensor<B, D, K>
104where
105    B: Backend,
106    K: Numeric<B>,
107    K::Elem: Element,
108{
109    /// Pads the tensor using the specified padding mode.
110    ///
111    /// Padding is specified as `(before, after)` pairs. If fewer pairs than tensor dimensions
112    /// are provided, they apply to the **last** N dimensions (unspecified leading dimensions
113    /// are left unpadded).
114    ///
115    /// For backward compatibility, a `(left, right, top, bottom)` tuple is also accepted,
116    /// which pads the last two dimensions.
117    ///
118    /// # Arguments
119    ///
120    /// * `padding` - Padding specification. Accepts:
121    ///   - `[(before, after); N]` fixed-size array of pairs (N <= D)
122    ///   - `&[(before, after)]` slice of pairs per dimension
123    ///   - `Vec<(before, after)>` vector of pairs
124    ///   - `(left, right, top, bottom)` tuple for last-2-dim backward compatibility
125    /// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`.
126    ///
127    /// # Returns
128    ///
129    /// A new tensor with the specified padding applied.
130    ///
131    /// # Panics
132    ///
133    /// - Panics if more padding pairs are provided than tensor dimensions.
134    /// - `Reflect` mode panics if padding exceeds `dimension_size - 1`.
135    /// - `Edge` mode panics if padding is applied to a zero-sized dimension.
136    ///
137    /// # Example
138    ///
139    /// ```rust
140    /// use burn_tensor::backend::Backend;
141    /// use burn_tensor::{Tensor, Shape};
142    /// use burn_tensor::ops::PadMode;
143    ///
144    /// fn example<B: Backend<FloatElem: From<f32>>>() {
145    ///    let device = B::Device::default();
146    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);
147    ///
148    ///    // Constant padding with value 0.0 (backward-compatible tuple)
149    ///    let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0));
150    ///
151    ///    // Pad arbitrary dimensions with slice of (before, after) pairs
152    ///    let padded = tensor.clone().pad([(1, 1), (2, 2)], PadMode::Constant(0.0));
153    ///
154    ///    // Pad only the last dimension
155    ///    let padded = tensor.pad([(1, 1)], PadMode::Reflect);
156    /// }
157    /// ```
158    pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {
159        let pairs = padding.into_padding();
160        match mode.into() {
161            PadMode::Constant(value) => pad_constant(self, &pairs, value),
162            PadMode::Reflect => pad_reflect(self, &pairs),
163            PadMode::Edge => pad_edge(self, &pairs),
164        }
165    }
166}
167
168/// Pad with a constant value.
169fn pad_constant<B, const D: usize, K, E>(
170    tensor: Tensor<B, D, K>,
171    padding: &[(usize, usize); D],
172    value: E,
173) -> Tensor<B, D, K>
174where
175    B: Backend,
176    K: Numeric<B>,
177    K::Elem: Element,
178    E: ElementConversion,
179{
180    let mut padded_dims: [usize; D] = tensor.dims();
181
182    for (i, &(before, after)) in padding.iter().enumerate() {
183        padded_dims[i] += before + after;
184    }
185
186    let ranges: [Range<usize>; D] = padded_dims
187        .iter()
188        .enumerate()
189        .map(|(i, &dim)| {
190            let (before, after) = padding[i];
191            before..dim - after
192        })
193        .collect::<Vec<Range<usize>>>()
194        .try_into()
195        .unwrap();
196
197    let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
198
199    padded_tensor.slice_assign(ranges, tensor)
200}
201
202/// Pad using reflection at the boundaries (excluding edge values).
203///
204/// For ONNX "reflect" mode: mirrors from index 1, not index 0.
205/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]`
206fn pad_reflect<B, const D: usize, K>(
207    tensor: Tensor<B, D, K>,
208    padding: &[(usize, usize); D],
209) -> Tensor<B, D, K>
210where
211    B: Backend,
212    K: Numeric<B>,
213    K::Elem: Element,
214{
215    let dims = tensor.dims();
216
217    for (i, &(before, after)) in padding.iter().enumerate() {
218        if before > 0 || after > 0 {
219            assert!(
220                before < dims[i] && after < dims[i],
221                "Reflect padding ({}, {}) must be less than dimension {} size ({})",
222                before,
223                after,
224                i,
225                dims[i]
226            );
227        }
228    }
229
230    let mut result = tensor;
231
232    for (i, &(before, after)) in padding.iter().enumerate() {
233        if before > 0 || after > 0 {
234            result = pad_reflect_dim(result, i, before, after);
235        }
236    }
237
238    result
239}
240
241/// Helper to pad a single dimension using reflection.
242fn pad_reflect_dim<B, const D: usize, K>(
243    tensor: Tensor<B, D, K>,
244    dim: usize,
245    pad_before: usize,
246    pad_after: usize,
247) -> Tensor<B, D, K>
248where
249    B: Backend,
250    K: Numeric<B>,
251    K::Elem: Element,
252{
253    let dims = tensor.dims();
254    let dim_size = dims[dim];
255
256    // Calculate output dimensions
257    let mut output_dims = dims;
258    output_dims[dim] += pad_before + pad_after;
259
260    // Create output tensor and place original in the center
261    let output = Tensor::zeros(output_dims, &tensor.device());
262    let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
263    let mut output = output.slice_assign(original_range, tensor.clone());
264
265    // Assign reflected "before" padding (e.g., top or left)
266    // Reflect excludes the edge, so we take indices [1..pad_before+1] and flip
267    if pad_before > 0 {
268        let before_slice = tensor.clone().narrow(dim, 1, pad_before);
269        let before_flipped = before_slice.flip([dim as isize]);
270        let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
271        output = output.slice_assign(before_range, before_flipped);
272    }
273
274    // Assign reflected "after" padding (e.g., bottom or right)
275    // Take indices [dim_size - pad_after - 1..dim_size - 1] and flip
276    if pad_after > 0 {
277        let start = dim_size - pad_after - 1;
278        let after_slice = tensor.narrow(dim, start, pad_after);
279        let after_flipped = after_slice.flip([dim as isize]);
280        let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
281        output = output.slice_assign(after_range, after_flipped);
282    }
283
284    output
285}
286
287/// Pad by replicating edge values.
288///
289/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]`
290fn pad_edge<B, const D: usize, K>(
291    tensor: Tensor<B, D, K>,
292    padding: &[(usize, usize); D],
293) -> Tensor<B, D, K>
294where
295    B: Backend,
296    K: Numeric<B>,
297    K::Elem: Element,
298{
299    let dims = tensor.dims();
300
301    for (i, &(before, after)) in padding.iter().enumerate() {
302        if before > 0 || after > 0 {
303            assert!(
304                dims[i] > 0,
305                "Cannot apply edge padding to zero-sized dimension {}",
306                i
307            );
308        }
309    }
310
311    let mut result = tensor;
312
313    for (i, &(before, after)) in padding.iter().enumerate() {
314        if before > 0 || after > 0 {
315            result = pad_edge_dim(result, i, before, after);
316        }
317    }
318
319    result
320}
321
322/// Helper to pad a single dimension by replicating edge values.
323fn pad_edge_dim<B, const D: usize, K>(
324    tensor: Tensor<B, D, K>,
325    dim: usize,
326    pad_before: usize,
327    pad_after: usize,
328) -> Tensor<B, D, K>
329where
330    B: Backend,
331    K: Numeric<B>,
332    K::Elem: Element,
333{
334    let dims = tensor.dims();
335    let dim_size = dims[dim];
336
337    // Calculate output dimensions
338    let mut output_dims = dims;
339    output_dims[dim] += pad_before + pad_after;
340
341    // Create output tensor and place original in the center
342    let output = Tensor::zeros(output_dims, &tensor.device());
343    let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
344    let mut output = output.slice_assign(original_range, tensor.clone());
345
346    // Assign "before" padding by repeating the first element
347    if pad_before > 0 {
348        let first_slice = tensor.clone().narrow(dim, 0, 1);
349        let before_pad = first_slice.repeat_dim(dim, pad_before);
350        let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
351        output = output.slice_assign(before_range, before_pad);
352    }
353
354    // Assign "after" padding by repeating the last element
355    if pad_after > 0 {
356        let last_slice = tensor.narrow(dim, dim_size - 1, 1);
357        let after_pad = last_slice.repeat_dim(dim, pad_after);
358        let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
359        output = output.slice_assign(after_range, after_pad);
360    }
361
362    output
363}