1use alloc::vec::Vec;
2use core::ops::Range;
3
4use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
5
6use super::Numeric;
7
/// Conversion of a padding specification into one `(before, after)` pair per dimension.
///
/// Specifications that describe fewer than `D` dimensions are right-aligned: they apply
/// to the trailing dimensions, and every leading dimension receives `(0, 0)`.
pub trait IntoPadding<const D: usize> {
    /// Returns the `(before, after)` padding amount for each of the `D` dimensions.
    ///
    /// # Panics
    ///
    /// Implementations panic when the specification cannot be mapped onto `D`
    /// dimensions (for example, more pairs than dimensions).
    fn into_padding(self) -> [(usize, usize); D];
}

/// A fixed-size array of `(before, after)` pairs, applied to the last `N` dimensions.
///
/// # Panics
///
/// Panics if `N > D`.
impl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {
    fn into_padding(self) -> [(usize, usize); D] {
        // Share the validation and placement logic with the slice impl.
        IntoPadding::<D>::into_padding(self.as_slice())
    }
}

/// A `(left, right, top, bottom)` tuple applied to the last two dimensions.
///
/// `(top, bottom)` pads dimension `D - 2` and `(left, right)` pads dimension `D - 1`.
/// All other dimensions receive `(0, 0)`.
///
/// # Panics
///
/// Panics if `D < 2`.
impl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {
    fn into_padding(self) -> [(usize, usize); D] {
        // Without this check, `D - 2` underflows and panics with an unhelpful
        // overflow / index-out-of-bounds message.
        assert!(
            D >= 2,
            "4-tuple padding (left, right, top, bottom) requires at least 2 dimensions, but tensor has {}",
            D
        );
        let (left, right, top, bottom) = self;
        let mut result = [(0usize, 0usize); D];
        result[D - 2] = (top, bottom);
        result[D - 1] = (left, right);
        result
    }
}

/// A slice of `(before, after)` pairs, applied in order to the last `self.len()`
/// dimensions.
///
/// # Panics
///
/// Panics if `self.len() > D`.
impl<const D: usize> IntoPadding<D> for &[(usize, usize)] {
    fn into_padding(self) -> [(usize, usize); D] {
        assert!(
            self.len() <= D,
            "Padding has {} pairs but tensor only has {} dimensions",
            self.len(),
            D
        );
        let mut result = [(0usize, 0usize); D];
        // Right-align: the leading dimensions keep the `(0, 0)` default.
        result[D - self.len()..].copy_from_slice(self);
        result
    }
}

/// A vector of `(before, after)` pairs, applied in order to the last `self.len()`
/// dimensions.
///
/// # Panics
///
/// Panics if `self.len() > D`.
impl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {
    fn into_padding(self) -> [(usize, usize); D] {
        // Share the validation and placement logic with the slice impl.
        IntoPadding::<D>::into_padding(self.as_slice())
    }
}
81
/// Builds one range per dimension of `dims`, selecting everything except along
/// `target_dim`.
///
/// Along `target_dim` the range is `start..start + len`; every other dimension `i`
/// gets `0..dims[i]`. If `target_dim >= D`, every dimension gets its full range.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    // `from_fn` fills the fixed-size array directly. It needs no intermediate heap
    // `Vec` and no fallible length conversion to unwrap.
    core::array::from_fn(|i| {
        if i == target_dim {
            start..start + len
        } else {
            0..dims[i]
        }
    })
}
102
103impl<B, const D: usize, K> Tensor<B, D, K>
104where
105 B: Backend,
106 K: Numeric<B>,
107 K::Elem: Element,
108{
109 pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {
159 let pairs = padding.into_padding();
160 match mode.into() {
161 PadMode::Constant(value) => pad_constant(self, &pairs, value),
162 PadMode::Reflect => pad_reflect(self, &pairs),
163 PadMode::Edge => pad_edge(self, &pairs),
164 }
165 }
166}
167
168fn pad_constant<B, const D: usize, K, E>(
170 tensor: Tensor<B, D, K>,
171 padding: &[(usize, usize); D],
172 value: E,
173) -> Tensor<B, D, K>
174where
175 B: Backend,
176 K: Numeric<B>,
177 K::Elem: Element,
178 E: ElementConversion,
179{
180 let mut padded_dims: [usize; D] = tensor.dims();
181
182 for (i, &(before, after)) in padding.iter().enumerate() {
183 padded_dims[i] += before + after;
184 }
185
186 let ranges: [Range<usize>; D] = padded_dims
187 .iter()
188 .enumerate()
189 .map(|(i, &dim)| {
190 let (before, after) = padding[i];
191 before..dim - after
192 })
193 .collect::<Vec<Range<usize>>>()
194 .try_into()
195 .unwrap();
196
197 let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
198
199 padded_tensor.slice_assign(ranges, tensor)
200}
201
202fn pad_reflect<B, const D: usize, K>(
207 tensor: Tensor<B, D, K>,
208 padding: &[(usize, usize); D],
209) -> Tensor<B, D, K>
210where
211 B: Backend,
212 K: Numeric<B>,
213 K::Elem: Element,
214{
215 let dims = tensor.dims();
216
217 for (i, &(before, after)) in padding.iter().enumerate() {
218 if before > 0 || after > 0 {
219 assert!(
220 before < dims[i] && after < dims[i],
221 "Reflect padding ({}, {}) must be less than dimension {} size ({})",
222 before,
223 after,
224 i,
225 dims[i]
226 );
227 }
228 }
229
230 let mut result = tensor;
231
232 for (i, &(before, after)) in padding.iter().enumerate() {
233 if before > 0 || after > 0 {
234 result = pad_reflect_dim(result, i, before, after);
235 }
236 }
237
238 result
239}
240
241fn pad_reflect_dim<B, const D: usize, K>(
243 tensor: Tensor<B, D, K>,
244 dim: usize,
245 pad_before: usize,
246 pad_after: usize,
247) -> Tensor<B, D, K>
248where
249 B: Backend,
250 K: Numeric<B>,
251 K::Elem: Element,
252{
253 let dims = tensor.dims();
254 let dim_size = dims[dim];
255
256 let mut output_dims = dims;
258 output_dims[dim] += pad_before + pad_after;
259
260 let output = Tensor::zeros(output_dims, &tensor.device());
262 let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
263 let mut output = output.slice_assign(original_range, tensor.clone());
264
265 if pad_before > 0 {
268 let before_slice = tensor.clone().narrow(dim, 1, pad_before);
269 let before_flipped = before_slice.flip([dim as isize]);
270 let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
271 output = output.slice_assign(before_range, before_flipped);
272 }
273
274 if pad_after > 0 {
277 let start = dim_size - pad_after - 1;
278 let after_slice = tensor.narrow(dim, start, pad_after);
279 let after_flipped = after_slice.flip([dim as isize]);
280 let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
281 output = output.slice_assign(after_range, after_flipped);
282 }
283
284 output
285}
286
287fn pad_edge<B, const D: usize, K>(
291 tensor: Tensor<B, D, K>,
292 padding: &[(usize, usize); D],
293) -> Tensor<B, D, K>
294where
295 B: Backend,
296 K: Numeric<B>,
297 K::Elem: Element,
298{
299 let dims = tensor.dims();
300
301 for (i, &(before, after)) in padding.iter().enumerate() {
302 if before > 0 || after > 0 {
303 assert!(
304 dims[i] > 0,
305 "Cannot apply edge padding to zero-sized dimension {}",
306 i
307 );
308 }
309 }
310
311 let mut result = tensor;
312
313 for (i, &(before, after)) in padding.iter().enumerate() {
314 if before > 0 || after > 0 {
315 result = pad_edge_dim(result, i, before, after);
316 }
317 }
318
319 result
320}
321
322fn pad_edge_dim<B, const D: usize, K>(
324 tensor: Tensor<B, D, K>,
325 dim: usize,
326 pad_before: usize,
327 pad_after: usize,
328) -> Tensor<B, D, K>
329where
330 B: Backend,
331 K: Numeric<B>,
332 K::Elem: Element,
333{
334 let dims = tensor.dims();
335 let dim_size = dims[dim];
336
337 let mut output_dims = dims;
339 output_dims[dim] += pad_before + pad_after;
340
341 let output = Tensor::zeros(output_dims, &tensor.device());
343 let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
344 let mut output = output.slice_assign(original_range, tensor.clone());
345
346 if pad_before > 0 {
348 let first_slice = tensor.clone().narrow(dim, 0, 1);
349 let before_pad = first_slice.repeat_dim(dim, pad_before);
350 let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
351 output = output.slice_assign(before_range, before_pad);
352 }
353
354 if pad_after > 0 {
356 let last_slice = tensor.narrow(dim, dim_size - 1, 1);
357 let after_pad = last_slice.repeat_dim(dim, pad_after);
358 let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
359 output = output.slice_assign(after_range, after_pad);
360 }
361
362 output
363}