1use alloc::vec::Vec;
2use core::ops::Range;
3
4use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
5
6use super::Numeric;
7
/// Builds one slice range per dimension, restricting a single dimension to a sub-range.
///
/// Every dimension other than `target_dim` is covered in full (`0..size`), while the
/// dimension at index `target_dim` is limited to `start..start + len`.
///
/// # Arguments
///
/// * `dims` - Size of each dimension of the shape being sliced.
/// * `target_dim` - Index of the dimension to restrict.
/// * `start` - First index of the restricted range.
/// * `len` - Number of elements in the restricted range.
///
/// # Returns
///
/// An array of `D` ranges. If `target_dim >= D`, no dimension is restricted and every
/// range spans its whole dimension.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    // Construct the array in place: this avoids an intermediate heap allocation and a
    // fallible collection-to-array conversion that would need an `unwrap`.
    core::array::from_fn(|axis| {
        if axis == target_dim {
            start..start + len
        } else {
            0..dims[axis]
        }
    })
}
28
29impl<B, const D: usize, K> Tensor<B, D, K>
30where
31 B: Backend,
32 K: Numeric<B>,
33 K::Elem: Element,
34{
35 pub fn pad(self, padding: (usize, usize, usize, usize), mode: PadMode) -> Self {
84 match mode {
85 PadMode::Constant(value) => pad_constant(self, padding, value),
86 PadMode::Reflect => pad_reflect(self, padding),
87 PadMode::Edge => pad_edge(self, padding),
88 }
89 }
90}
91
92pub fn pad_constant<B, const D: usize, K, E>(
94 tensor: Tensor<B, D, K>,
95 padding: (usize, usize, usize, usize),
96 value: E,
97) -> Tensor<B, D, K>
98where
99 B: Backend,
100 K: Numeric<B>,
101 K::Elem: Element,
102 E: ElementConversion,
103{
104 let (left, right, top, bottom) = padding;
105
106 let mut padded_dims: [usize; D] = tensor.dims();
107
108 padded_dims[D - 2] += top + bottom;
110 padded_dims[D - 1] += left + right;
111
112 let ranges: [core::ops::Range<usize>; D] = padded_dims
114 .iter()
115 .enumerate()
116 .map(|(i, &dim)| {
117 if i == D - 2 {
118 top..dim - bottom
119 } else if i == D - 1 {
120 left..dim - right
121 } else {
122 0..dim
123 }
124 })
125 .collect::<Vec<core::ops::Range<usize>>>()
126 .try_into()
127 .unwrap();
128
129 let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
131
132 padded_tensor.slice_assign(ranges, tensor)
134}
135
136pub fn pad_reflect<B, const D: usize, K>(
141 tensor: Tensor<B, D, K>,
142 padding: (usize, usize, usize, usize),
143) -> Tensor<B, D, K>
144where
145 B: Backend,
146 K: Numeric<B>,
147 K::Elem: Element,
148{
149 let (left, right, top, bottom) = padding;
150 let dims = tensor.dims();
151
152 assert!(
155 top < dims[D - 2] && bottom < dims[D - 2],
156 "Reflect padding on height ({}, {}) must be less than height dimension ({})",
157 top,
158 bottom,
159 dims[D - 2]
160 );
161 assert!(
162 left < dims[D - 1] && right < dims[D - 1],
163 "Reflect padding on width ({}, {}) must be less than width dimension ({})",
164 left,
165 right,
166 dims[D - 1]
167 );
168
169 let mut result = tensor;
170
171 if top > 0 || bottom > 0 {
173 result = pad_reflect_dim(result, D - 2, top, bottom);
174 }
175
176 if left > 0 || right > 0 {
178 result = pad_reflect_dim(result, D - 1, left, right);
179 }
180
181 result
182}
183
184fn pad_reflect_dim<B, const D: usize, K>(
186 tensor: Tensor<B, D, K>,
187 dim: usize,
188 pad_before: usize,
189 pad_after: usize,
190) -> Tensor<B, D, K>
191where
192 B: Backend,
193 K: Numeric<B>,
194 K::Elem: Element,
195{
196 let dims = tensor.dims();
197 let dim_size = dims[dim];
198
199 let mut output_dims = dims;
201 output_dims[dim] += pad_before + pad_after;
202
203 let output = Tensor::zeros(output_dims, &tensor.device());
205 let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
206 let mut output = output.slice_assign(original_range, tensor.clone());
207
208 if pad_before > 0 {
211 let before_slice = tensor.clone().narrow(dim, 1, pad_before);
212 let before_flipped = before_slice.flip([dim as isize]);
213 let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
214 output = output.slice_assign(before_range, before_flipped);
215 }
216
217 if pad_after > 0 {
220 let start = dim_size - pad_after - 1;
221 let after_slice = tensor.narrow(dim, start, pad_after);
222 let after_flipped = after_slice.flip([dim as isize]);
223 let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
224 output = output.slice_assign(after_range, after_flipped);
225 }
226
227 output
228}
229
230pub fn pad_edge<B, const D: usize, K>(
234 tensor: Tensor<B, D, K>,
235 padding: (usize, usize, usize, usize),
236) -> Tensor<B, D, K>
237where
238 B: Backend,
239 K: Numeric<B>,
240 K::Elem: Element,
241{
242 let (left, right, top, bottom) = padding;
243 let dims = tensor.dims();
244
245 if top > 0 || bottom > 0 {
247 assert!(
248 dims[D - 2] > 0,
249 "Cannot apply edge padding to zero-sized height dimension"
250 );
251 }
252 if left > 0 || right > 0 {
253 assert!(
254 dims[D - 1] > 0,
255 "Cannot apply edge padding to zero-sized width dimension"
256 );
257 }
258
259 let mut result = tensor;
260
261 if top > 0 || bottom > 0 {
263 result = pad_edge_dim(result, D - 2, top, bottom);
264 }
265
266 if left > 0 || right > 0 {
268 result = pad_edge_dim(result, D - 1, left, right);
269 }
270
271 result
272}
273
274fn pad_edge_dim<B, const D: usize, K>(
276 tensor: Tensor<B, D, K>,
277 dim: usize,
278 pad_before: usize,
279 pad_after: usize,
280) -> Tensor<B, D, K>
281where
282 B: Backend,
283 K: Numeric<B>,
284 K::Elem: Element,
285{
286 let dims = tensor.dims();
287 let dim_size = dims[dim];
288
289 let mut output_dims = dims;
291 output_dims[dim] += pad_before + pad_after;
292
293 let output = Tensor::zeros(output_dims, &tensor.device());
295 let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);
296 let mut output = output.slice_assign(original_range, tensor.clone());
297
298 if pad_before > 0 {
300 let first_slice = tensor.clone().narrow(dim, 0, 1);
301 let before_pad = first_slice.repeat_dim(dim, pad_before);
302 let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);
303 output = output.slice_assign(before_range, before_pad);
304 }
305
306 if pad_after > 0 {
308 let last_slice = tensor.narrow(dim, dim_size - 1, 1);
309 let after_pad = last_slice.repeat_dim(dim, pad_after);
310 let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);
311 output = output.slice_assign(after_range, after_pad);
312 }
313
314 output
315}