Skip to main content

ndarray_conv/padding/
mod.rs

1//! Provides padding functionality for ndarray arrays.
2//!
3//! This module defines the `PaddingExt` trait, which extends the `ArrayBase`
4//! struct from the `ndarray` crate with methods for padding arrays using
5//! different padding modes. It also provides helper functions for
6//! applying specific types of padding.
7
8use super::{BorderType, PaddingMode};
9
10use ndarray::{
11    Array, ArrayBase, Data, DataMut, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo,
12    SliceInfoElem,
13};
14use num::traits::NumAssign;
15
16pub(crate) mod dim;
17mod half_dim;
18
/// Per-axis padding amounts, laid out as `[[front, back]; N]`.
///
/// For axis `i`, `[i][0]` is the number of elements placed before the data
/// and `[i][1]` the number placed after it.
pub type ExplicitPadding<const N: usize> = [[usize; 2]; N];
21
22/// Extends `ndarray`'s `ArrayBase` with padding operations.
23///
24/// This trait provides the `padding` and `padding_in` methods for adding
25/// padding to an array using various modes, like zero padding, constant
26/// padding, replication, reflection, and circular padding.
27///
28/// # Type Parameters
29///
30/// *   `N`: The number of dimensions of the array.
31/// *   `T`: The numeric type of the array elements.
32/// *   `Output`: The type of the padded array returned by `padding`, typically an `Array<T, Dim<[Ix; N]>>`.
33pub trait PaddingExt<const N: usize, T: num::traits::NumAssign + Copy, Output> {
34    /// Returns a new array with the specified padding applied.
35    ///
36    /// This method creates a new array with the dimensions and padding specified by
37    /// `mode` and `padding_size`. It calls the `padding_in` method internally to handle the padding itself.
38    ///
39    /// # Arguments
40    ///
41    /// * `mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
42    /// * `padding_size`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
43    ///
44    /// # Returns
45    /// A new `Array` with the padded data.
46    fn padding(&self, mode: PaddingMode<N, T>, padding_size: ExplicitPadding<N>) -> Output;
47
48    /// Modifies the buffer in-place by applying padding using the specified mode.
49    ///
50    /// This method directly modifies the provided buffer by adding padding to its content.
51    ///
52    /// # Type Parameters
53    ///
54    /// *   `SO`: The data storage type of the output buffer.
55    /// *   `DO`: The dimension type of the output buffer.
56    ///
57    /// # Arguments
58    ///
59    /// * `buffer`: A mutable reference to the array to be padded.
60    /// * `mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
61    /// * `padding_size`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
62    fn padding_in<SO: DataMut<Elem = T>, DO: RemoveAxis>(
63        &self,
64        buffer: &mut ArrayBase<SO, DO>,
65        mode: PaddingMode<N, T>,
66        padding_size: ExplicitPadding<N>,
67    ) where
68        T: NumAssign + Copy,
69        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
70        SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
71        Dim<[Ix; N]>: RemoveAxis,
72        SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg<DO>;
73}
74
75impl<const N: usize, T, S, D> PaddingExt<N, T, Array<T, Dim<[Ix; N]>>> for ArrayBase<S, D>
76where
77    T: NumAssign + Copy,
78    S: Data<Elem = T>,
79    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
80    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
81    Dim<[Ix; N]>: RemoveAxis,
82    D: RemoveAxis + IntoDimension,
83{
84    fn padding(
85        &self,
86        mode: PaddingMode<N, T>,
87        explicit_padding: ExplicitPadding<N>,
88    ) -> Array<T, Dim<[Ix; N]>> {
89        let c = match mode {
90            PaddingMode::Const(c) => c,
91            _ => T::zero(),
92        };
93
94        let raw_dim = self.raw_dim();
95
96        let output_dim =
97            std::array::from_fn(|i| raw_dim[i] + explicit_padding[i][0] + explicit_padding[i][1]);
98
99        let mut output: Array<T, Dim<[Ix; N]>> = Array::from_elem(output_dim, c);
100
101        padding_const(self, &mut output, explicit_padding);
102
103        match mode {
104            PaddingMode::Replicate => padding_replicate(self, &mut output, explicit_padding),
105            PaddingMode::Reflect => padding_reflect(self, &mut output, explicit_padding),
106            PaddingMode::Circular => padding_circular(self, &mut output, explicit_padding),
107            PaddingMode::Custom(borders) => {
108                padding_custom(self, &mut output, explicit_padding, borders)
109            }
110            PaddingMode::Explicit(borders) => {
111                padding_explicit(self, &mut output, explicit_padding, borders)
112            }
113            _ => {}
114        };
115
116        output
117    }
118
119    fn padding_in<SO, DO>(
120        &self,
121        buffer: &mut ArrayBase<SO, DO>,
122        mode: PaddingMode<N, T>,
123        explicit_padding: ExplicitPadding<N>,
124    ) where
125        T: NumAssign + Copy,
126        S: Data<Elem = T>,
127        SO: DataMut<Elem = T>,
128        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
129        SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg<DO>,
130        Dim<[Ix; N]>: RemoveAxis,
131        DO: RemoveAxis,
132    {
133        padding_const(self, buffer, explicit_padding);
134
135        match mode {
136            PaddingMode::Const(c) => {
137                explicit_padding
138                    .iter()
139                    .enumerate()
140                    .for_each(|(dim, &explicit_padding)| {
141                        dim::constant(self.raw_dim(), buffer, dim, explicit_padding, c);
142                    })
143            }
144            PaddingMode::Replicate => padding_replicate(self, buffer, explicit_padding),
145            PaddingMode::Reflect => padding_reflect(self, buffer, explicit_padding),
146            PaddingMode::Circular => padding_circular(self, buffer, explicit_padding),
147            PaddingMode::Custom(borders) => padding_custom(self, buffer, explicit_padding, borders),
148            PaddingMode::Explicit(borders) => {
149                padding_explicit(self, buffer, explicit_padding, borders)
150            }
151            _ => {}
152        };
153    }
154}
155
156/// Applies padding using a constant value to the specified slice of the output array.
157///
158/// This function copies the input array to a specific slice of the output array, leaving the rest of the
159/// output array with the default padding value, which is typically a zero or a constant, depending on the padding mode.
160///
161/// # Type Parameters
162///
163/// *   `N`: The number of dimensions of the array.
164/// *   `T`: The numeric type of the array elements.
165/// *   `S`: The data storage type of the input array.
166/// *   `D`: The dimension type of the input array.
167/// *   `SO`: The data storage type of the output array.
168/// *   `DO`: The dimension type of the output array.
169///
170/// # Arguments
171///
172/// * `input`: The input array to pad.
173/// * `output`: A mutable reference to the array where the padded result will be stored.
174/// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
175pub(crate) fn padding_const<const N: usize, T, S, D, SO, DO>(
176    input: &ArrayBase<S, D>,
177    output: &mut ArrayBase<SO, DO>,
178    explicit_padding: ExplicitPadding<N>,
179) where
180    T: NumAssign + Copy,
181    S: Data<Elem = T>,
182    SO: DataMut<Elem = T>,
183    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
184    // SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
185    SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg<DO>,
186    Dim<[Ix; N]>: RemoveAxis,
187    D: RemoveAxis,
188    DO: RemoveAxis,
189{
190    let mut output_slice = output.slice_mut(unsafe {
191        SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
192            start: explicit_padding[i][0] as isize,
193            end: Some((explicit_padding[i][0] + input.raw_dim()[i]) as isize),
194            step: 1,
195        }))
196        .unwrap()
197    });
198
199    output_slice.assign(input);
200}
201
202/// Applies replicate padding to the specified slice of the output array.
203///
204/// This function uses the `dim::replicate` function to add replicate padding
205/// to each dimension of the output array.
206///
207/// # Type Parameters
208///
209/// *   `N`: The number of dimensions of the array.
210/// *   `T`: The numeric type of the array elements.
211/// *   `S`: The data storage type of the input array.
212/// *   `D`: The dimension type of the input array.
213/// *   `SO`: The data storage type of the output array.
214/// *   `DO`: The dimension type of the output array.
215///
216/// # Arguments
217///
218/// * `input`: The input array to pad.
219/// * `output`: A mutable reference to the array where the padded result will be stored.
220/// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
221fn padding_replicate<const N: usize, T, S, D, SO, DO>(
222    input: &ArrayBase<S, D>,
223    output: &mut ArrayBase<SO, DO>,
224    explicit_padding: ExplicitPadding<N>,
225) where
226    T: NumAssign + Copy,
227    S: Data<Elem = T>,
228    SO: DataMut<Elem = T>,
229    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
230    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
231    Dim<[Ix; N]>: RemoveAxis,
232    D: RemoveAxis + IntoDimension,
233    DO: RemoveAxis,
234{
235    explicit_padding
236        .iter()
237        .enumerate()
238        .for_each(|(dim, &explicit_padding)| {
239            dim::replicate(input.raw_dim(), output, dim, explicit_padding);
240        });
241}
242
243/// Applies reflect padding to the specified slice of the output array.
244///
245/// This function uses the `dim::reflect` function to add reflect padding
246/// to each dimension of the output array.
247///
248/// # Type Parameters
249///
250/// *   `N`: The number of dimensions of the array.
251/// *   `T`: The numeric type of the array elements.
252/// *   `S`: The data storage type of the input array.
253/// *   `D`: The dimension type of the input array.
254/// *   `SO`: The data storage type of the output array.
255/// *   `DO`: The dimension type of the output array.
256///
257/// # Arguments
258///
259/// * `input`: The input array to pad.
260/// * `output`: A mutable reference to the array where the padded result will be stored.
261/// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
262fn padding_reflect<const N: usize, T, S, D, SO, DO>(
263    input: &ArrayBase<S, D>,
264    output: &mut ArrayBase<SO, DO>,
265    explicit_padding: ExplicitPadding<N>,
266) where
267    T: NumAssign + Copy,
268    S: Data<Elem = T>,
269    SO: DataMut<Elem = T>,
270    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
271    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
272    Dim<[Ix; N]>: RemoveAxis,
273    D: RemoveAxis,
274    DO: RemoveAxis,
275{
276    explicit_padding
277        .iter()
278        .enumerate()
279        .for_each(|(dim, &explicit_padding)| {
280            dim::reflect(input.raw_dim(), output, dim, explicit_padding);
281        });
282}
283
284/// Applies circular padding to the specified slice of the output array.
285///
286/// This function uses the `dim::circular` function to add circular padding
287/// to each dimension of the output array.
288///
289/// # Type Parameters
290///
291/// *   `N`: The number of dimensions of the array.
292/// *   `T`: The numeric type of the array elements.
293/// *   `S`: The data storage type of the input array.
294/// *   `D`: The dimension type of the input array.
295/// *   `SO`: The data storage type of the output array.
296/// *   `DO`: The dimension type of the output array.
297///
298/// # Arguments
299///
300/// * `input`: The input array to pad.
301/// * `output`: A mutable reference to the array where the padded result will be stored.
302/// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
303fn padding_circular<const N: usize, T, S, D, SO, DO>(
304    input: &ArrayBase<S, D>,
305    output: &mut ArrayBase<SO, DO>,
306    explicit_padding: ExplicitPadding<N>,
307) where
308    T: NumAssign + Copy,
309    S: Data<Elem = T>,
310    SO: DataMut<Elem = T>,
311    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
312    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
313    Dim<[Ix; N]>: RemoveAxis,
314    D: RemoveAxis,
315    DO: RemoveAxis,
316{
317    explicit_padding
318        .iter()
319        .enumerate()
320        .for_each(|(dim, &explicit_padding)| {
321            dim::circular(input.raw_dim(), output, dim, explicit_padding);
322        });
323}
324
325/// Applies custom padding to the specified slice of the output array using `BorderType` for each dimension.
326///
327/// This function uses the `dim::constant`, `dim::reflect`, `dim::replicate`,
328/// or `dim::circular` function based on the corresponding `BorderType` specified in the `borders` argument,
329/// to add padding to each dimension of the output array.
330///
331/// # Type Parameters
332///
333/// *   `N`: The number of dimensions of the array.
334/// *   `T`: The numeric type of the array elements.
335/// *   `S`: The data storage type of the input array.
336/// *   `D`: The dimension type of the input array.
337/// *   `SO`: The data storage type of the output array.
338/// *   `DO`: The dimension type of the output array.
339///
340/// # Arguments
341///
342/// * `input`: The input array to pad.
343/// * `output`: A mutable reference to the array where the padded result will be stored.
344/// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
345/// * `borders`: An array containing a `BorderType` enum for each dimension.
346fn padding_custom<const N: usize, T, S, D, SO, DO>(
347    input: &ArrayBase<S, D>,
348    output: &mut ArrayBase<SO, DO>,
349    explicit_padding: ExplicitPadding<N>,
350    borders: [BorderType<T>; N],
351) where
352    T: NumAssign + Copy,
353    S: Data<Elem = T>,
354    SO: DataMut<Elem = T>,
355    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
356    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
357    Dim<[Ix; N]>: RemoveAxis,
358    D: RemoveAxis,
359    DO: RemoveAxis,
360{
361    explicit_padding
362        .iter()
363        .zip(borders.iter())
364        .enumerate()
365        .for_each(|(dim, (&explicit_padding, border))| match border {
366            BorderType::Zeros => {
367                dim::constant(input.raw_dim(), output, dim, explicit_padding, T::zero())
368            }
369            BorderType::Const(c) => {
370                dim::constant(input.raw_dim(), output, dim, explicit_padding, *c)
371            }
372            BorderType::Reflect => dim::reflect(input.raw_dim(), output, dim, explicit_padding),
373            BorderType::Replicate => dim::replicate(input.raw_dim(), output, dim, explicit_padding),
374            BorderType::Circular => dim::circular(input.raw_dim(), output, dim, explicit_padding),
375        });
376}
377
378/// Applies explicit padding to the specified slice of the output array using `BorderType` for each side of each dimension.
379///
380/// This function uses the `half_dim::constant_front`, `half_dim::constant_back`,
381/// `half_dim::reflect_front`, `half_dim::reflect_back`, `half_dim::replicate_front`,
382/// `half_dim::replicate_back`, `half_dim::circular_front`, and `half_dim::circular_back`
383/// functions based on the corresponding `BorderType` specified in the `borders` argument,
384/// to add padding to each dimension of the output array.
385///
386/// # Type Parameters
387///
388/// *   `N`: The number of dimensions of the array.
389/// *   `T`: The numeric type of the array elements.
390/// *   `S`: The data storage type of the input array.
391/// *   `D`: The dimension type of the input array.
392/// *   `SO`: The data storage type of the output array.
393/// *   `DO`: The dimension type of the output array.
394///
395/// # Arguments
396///
397/// * `input`: The input array to pad.
398/// * `output`: A mutable reference to the array where the padded result will be stored.
399/// * `explicit_padding`: An array representing the padding sizes for each dimension in the form `[[front, back]; N]`.
400/// * `borders`: An array containing an array of two `BorderType` enums for each dimension.
401fn padding_explicit<const N: usize, T, S, D, SO, DO>(
402    input: &ArrayBase<S, D>,
403    output: &mut ArrayBase<SO, DO>,
404    explicit_padding: ExplicitPadding<N>,
405    borders: [[BorderType<T>; 2]; N],
406) where
407    T: NumAssign + Copy,
408    S: Data<Elem = T>,
409    SO: DataMut<Elem = T>,
410    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
411    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
412    Dim<[Ix; N]>: RemoveAxis,
413    D: RemoveAxis,
414    DO: RemoveAxis,
415{
416    explicit_padding
417        .iter()
418        .zip(borders.iter())
419        .enumerate()
420        .for_each(|(dim, (&explicit_padding, border))| {
421            match border[0] {
422                BorderType::Zeros => {
423                    half_dim::constant_front(output, dim, explicit_padding, T::zero())
424                }
425                BorderType::Const(c) => half_dim::constant_front(output, dim, explicit_padding, c),
426                BorderType::Reflect => half_dim::reflect_front(output, dim, explicit_padding),
427                BorderType::Replicate => half_dim::replicate_front(output, dim, explicit_padding),
428                BorderType::Circular => half_dim::circular_front(output, dim, explicit_padding),
429            }
430            match border[1] {
431                BorderType::Zeros => half_dim::constant_back(
432                    input.raw_dim(),
433                    output,
434                    dim,
435                    explicit_padding,
436                    T::zero(),
437                ),
438                BorderType::Const(c) => {
439                    half_dim::constant_back(input.raw_dim(), output, dim, explicit_padding, c)
440                }
441                BorderType::Reflect => {
442                    half_dim::reflect_back(input.raw_dim(), output, dim, explicit_padding)
443                }
444                BorderType::Replicate => {
445                    half_dim::replicate_back(input.raw_dim(), output, dim, explicit_padding)
446                }
447                BorderType::Circular => {
448                    half_dim::circular_back(input.raw_dim(), output, dim, explicit_padding)
449                }
450            }
451        });
452}
453
454#[cfg(test)]
455mod tests {
456    use ndarray::prelude::*;
457
458    use super::*;
459    use crate::dilation::IntoKernelWithDilation;
460    use crate::ConvMode;
461
462    // ===== Basic Padding Tests =====
463
464    mod zeros_padding {
465        use super::*;
466
467        #[test]
468        fn test_1d() {
469            let arr = array![1, 2, 3];
470            let explicit_padding = [[1, 1]];
471            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
472            assert_eq!(padded, array![0, 1, 2, 3, 0]);
473        }
474
475        #[test]
476        fn test_2d() {
477            let arr = array![[1, 2], [3, 4]];
478            let explicit_padding = [[1, 1], [1, 1]];
479            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
480            assert_eq!(
481                padded,
482                array![[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
483            );
484        }
485
486        #[test]
487        fn test_3d() {
488            let arr = array![[[1, 2]], [[3, 4]]];
489            let explicit_padding = [[1, 0], [0, 1], [1, 0]];
490            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
491            // Shape: [2, 1, 2] -> [3, 2, 3]
492            // dim 0: padding [1, 0] => add 1 layer before
493            // dim 1: padding [0, 1] => add 1 layer after
494            // dim 2: padding [1, 0] => add 1 column before each row
495            assert_eq!(
496                padded,
497                array![
498                    [[0, 0, 0], [0, 0, 0]], // padded layer at front (dim 0)
499                    [[0, 1, 2], [0, 0, 0]], // original [[[1, 2]]] with padding
500                    [[0, 3, 4], [0, 0, 0]]  // original [[[3, 4]]] with padding
501                ]
502            );
503        }
504
505        #[test]
506        fn test_asymmetric_padding() {
507            let arr = array![1, 2, 3];
508            let explicit_padding = [[2, 1]];
509            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
510            assert_eq!(padded, array![0, 0, 1, 2, 3, 0]);
511        }
512    }
513
514    mod const_padding {
515        use super::*;
516
517        #[test]
518        fn test_1d() {
519            let arr = array![1, 2, 3];
520            let explicit_padding = [[1, 1]];
521            let padded = arr.padding(PaddingMode::Const(7), explicit_padding);
522            assert_eq!(padded, array![7, 1, 2, 3, 7]);
523        }
524
525        #[test]
526        fn test_2d() {
527            let arr = array![[1, 2], [3, 4]];
528            let explicit_padding = [[1, 1], [1, 1]];
529            let padded = arr.padding(PaddingMode::Const(9), explicit_padding);
530            assert_eq!(
531                padded,
532                array![[9, 9, 9, 9], [9, 1, 2, 9], [9, 3, 4, 9], [9, 9, 9, 9]]
533            );
534        }
535    }
536
537    mod replicate_padding {
538        use super::*;
539
540        #[test]
541        fn test_1d() {
542            let arr = array![1, 2, 3];
543            let explicit_padding = [[1, 2]];
544            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
545            assert_eq!(padded, array![1, 1, 2, 3, 3, 3]);
546        }
547
548        #[test]
549        fn test_2d() {
550            let arr = array![[1, 2], [3, 4]];
551            let explicit_padding = [[1, 1], [1, 1]];
552            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
553            assert_eq!(
554                padded,
555                array![[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
556            );
557        }
558
559        #[test]
560        fn test_large_padding() {
561            let arr = array![1, 2];
562            let explicit_padding = [[3, 3]];
563            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
564            assert_eq!(padded, array![1, 1, 1, 1, 2, 2, 2, 2]);
565        }
566    }
567
568    mod reflect_padding {
569        use super::*;
570
571        #[test]
572        fn test_1d() {
573            let arr = array![1, 2, 3, 4];
574            let explicit_padding = [[2, 2]];
575            let padded = arr.padding(PaddingMode::Reflect, explicit_padding);
576            assert_eq!(padded, array![3, 2, 1, 2, 3, 4, 3, 2]);
577        }
578
579        #[test]
580        fn test_2d() {
581            let arr = array![[1, 2, 3], [4, 5, 6]];
582            let explicit_padding = [[1, 1], [1, 1]];
583            let padded = arr.padding(PaddingMode::Reflect, explicit_padding);
584            assert_eq!(
585                padded,
586                array![
587                    [5, 4, 5, 6, 5],
588                    [2, 1, 2, 3, 2],
589                    [5, 4, 5, 6, 5],
590                    [2, 1, 2, 3, 2]
591                ]
592            );
593        }
594    }
595
596    mod circular_padding {
597        use super::*;
598
599        #[test]
600        fn test_1d() {
601            let arr = array![1, 2, 3, 4];
602            let explicit_padding = [[2, 2]];
603            let padded = arr.padding(PaddingMode::Circular, explicit_padding);
604            assert_eq!(padded, array![3, 4, 1, 2, 3, 4, 1, 2]);
605        }
606
607        #[test]
608        fn test_2d() {
609            let arr = array![[1, 2], [3, 4]];
610            let explicit_padding = [[1, 1], [1, 1]];
611            let padded = arr.padding(PaddingMode::Circular, explicit_padding);
612            assert_eq!(
613                padded,
614                array![[4, 3, 4, 3], [2, 1, 2, 1], [4, 3, 4, 3], [2, 1, 2, 1]]
615            );
616        }
617
618        #[test]
619        fn test_type_cast_safety() {
620            // Regression test for issue with type casting in circular padding
621            let arr = array![1u8, 2, 3];
622            let explicit_padding = [[1, 1]];
623            let padded = arr.padding(PaddingMode::Circular, explicit_padding);
624            assert_eq!(padded, array![3u8, 1, 2, 3, 1]);
625        }
626    }
627
628    mod custom_padding {
629        use super::*;
630
631        #[test]
632        fn test_per_dimension() {
633            let arr = array![[1, 2], [3, 4]];
634            let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]];
635            let kernel = kernel.into_kernel_with_dilation();
636
637            let explicit_conv = ConvMode::Full.unfold(&kernel);
638            let explicit_padding = explicit_conv.padding;
639
640            let arr_padded = arr.padding(
641                PaddingMode::Custom([BorderType::Replicate, BorderType::Circular]),
642                explicit_padding,
643            );
644            assert_eq!(
645                arr_padded,
646                array![
647                    [1, 2, 1, 2, 1, 2],
648                    [1, 2, 1, 2, 1, 2],
649                    [1, 2, 1, 2, 1, 2],
650                    [3, 4, 3, 4, 3, 4],
651                    [3, 4, 3, 4, 3, 4],
652                    [3, 4, 3, 4, 3, 4]
653                ]
654            );
655        }
656
657        #[test]
658        fn test_mixed_types() {
659            let arr = array![[1, 2], [3, 4]];
660            let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]];
661            let kernel = kernel.into_kernel_with_dilation();
662
663            let explicit_conv = ConvMode::Full.unfold(&kernel);
664            let explicit_padding = explicit_conv.padding;
665
666            let arr_padded = arr.padding(
667                PaddingMode::Custom([BorderType::Reflect, BorderType::Const(7)]),
668                explicit_padding,
669            );
670            assert_eq!(
671                arr_padded,
672                array![
673                    [7, 7, 0, 0, 7, 7],
674                    [7, 7, 3, 4, 7, 7],
675                    [7, 7, 1, 2, 7, 7],
676                    [7, 7, 3, 4, 7, 7],
677                    [7, 7, 1, 2, 7, 7],
678                    [7, 7, 3, 4, 7, 7]
679                ]
680            );
681        }
682    }
683
684    mod explicit_padding {
685        use super::*;
686
687        #[test]
688        fn test_per_side() {
689            let arr = array![1, 2, 3];
690            let explicit_padding = [[1, 2]];
691
692            // Use different BorderType for each side
693            let padded = arr.padding(
694                PaddingMode::Explicit([[BorderType::Const(7), BorderType::Const(9)]]),
695                explicit_padding,
696            );
697            assert_eq!(padded, array![7, 1, 2, 3, 9, 9]);
698        }
699    }
700
701    // ===== Edge Cases =====
702
703    mod edge_cases {
704        use super::*;
705
706        #[test]
707        fn test_zero_padding() {
708            let arr = array![1, 2, 3];
709            let explicit_padding = [[0, 0]];
710            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
711            assert_eq!(padded, arr);
712        }
713
714        #[test]
715        fn test_single_element() {
716            let arr = array![42];
717            let explicit_padding = [[2, 2]];
718            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
719            assert_eq!(padded, array![42, 42, 42, 42, 42]);
720        }
721
722        #[test]
723        fn test_large_array() {
724            let arr = Array::from_shape_fn((100, 100), |(i, j)| (i + j) as i32);
725            let explicit_padding = [[5, 5], [5, 5]];
726            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
727
728            // Verify shape
729            assert_eq!(padded.shape(), &[110, 110]);
730
731            // Verify padding is zeros
732            // Top padding
733            for i in 0..5 {
734                for j in 0..110 {
735                    assert_eq!(padded[[i, j]], 0);
736                }
737            }
738            // Bottom padding
739            for i in 105..110 {
740                for j in 0..110 {
741                    assert_eq!(padded[[i, j]], 0);
742                }
743            }
744            // Left and right padding (middle rows)
745            for i in 5..105 {
746                for j in 0..5 {
747                    assert_eq!(padded[[i, j]], 0);
748                }
749                for j in 105..110 {
750                    assert_eq!(padded[[i, j]], 0);
751                }
752            }
753
754            // Verify original data is preserved
755            assert_eq!(padded[[5, 5]], arr[[0, 0]]); // top-left
756            assert_eq!(padded[[54, 54]], arr[[49, 49]]); // middle
757            assert_eq!(padded[[104, 104]], arr[[99, 99]]); // bottom-right
758        }
759    }
760
761    // ===== Torch Verification Tests =====
762
763    #[test]
764    fn aligned_with_libtorch() {
765        // Test all padding modes against torch for 3D
766        let arr = array![[[1, 2, 3], [3, 4, 5]], [[5, 6, 7], [7, 8, 9]]];
767        let kernel = array![
768            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
769            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
770            [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
771        ];
772        let explicit_conv = ConvMode::Same.unfold(&kernel.into_kernel_with_dilation());
773        let explicit_padding = explicit_conv.padding;
774        check(&arr, PaddingMode::Zeros, explicit_padding);
775        check(&arr, PaddingMode::Const(7), explicit_padding);
776        check(&arr, PaddingMode::Replicate, explicit_padding);
777        check(&arr, PaddingMode::Reflect, explicit_padding);
778        check(&arr, PaddingMode::Circular, explicit_padding);
779
780        // Test all padding modes against torch for 2D
781        let arr = array![[1, 2], [3, 4]];
782        let kernel = array![[1, 1], [1, 1]];
783        let explicit_conv = ConvMode::Full.unfold(&kernel.into_kernel_with_dilation());
784        let explicit_padding = explicit_conv.padding;
785        check(&arr, PaddingMode::Zeros, explicit_padding);
786        check(&arr, PaddingMode::Const(7), explicit_padding);
787        check(&arr, PaddingMode::Replicate, explicit_padding);
788        check(&arr, PaddingMode::Reflect, explicit_padding);
789        check(&arr, PaddingMode::Circular, explicit_padding);
790
791        // Test all padding modes against torch for 1D
792        let arr = array![1, 2, 3];
793        let kernel = array![1, 1, 1, 1];
794        let explicit_conv = ConvMode::Same.unfold(&kernel.into_kernel_with_dilation());
795        let explicit_padding = explicit_conv.padding;
796        check(&arr, PaddingMode::Zeros, explicit_padding);
797        check(&arr, PaddingMode::Const(7), explicit_padding);
798        check(&arr, PaddingMode::Replicate, explicit_padding);
799        check(&arr, PaddingMode::Reflect, explicit_padding);
800        check(&arr, PaddingMode::Circular, explicit_padding);
801    }
802
803    fn check<T, const N: usize>(
804        arr: &Array<T, Dim<[Ix; N]>>,
805        padding_mode: PaddingMode<N, T>,
806        explicit_padding: ExplicitPadding<N>,
807    ) where
808        T: num::traits::NumAssign + Copy + tch::kind::Element + std::fmt::Debug,
809        Dim<[Ix; N]>: Dimension,
810        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
811        SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
812        Dim<[Ix; N]>: RemoveAxis,
813        f64: std::convert::From<T>,
814        T: num::traits::FromPrimitive,
815    {
816        let ndarray_result = arr.padding(padding_mode, explicit_padding);
817        dbg!(&ndarray_result);
818
819        let shape = [1, 1]
820            .iter()
821            .chain(arr.shape())
822            .map(|s| *s as i64)
823            .collect::<Vec<_>>();
824        let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap())
825            .reshape(shape)
826            .totype(tch::Kind::Float);
827
828        let (mode, value) = match padding_mode {
829            PaddingMode::Zeros => ("constant", Some(0.0)),
830            PaddingMode::Const(c) => ("constant", Some(f64::from(c))),
831            PaddingMode::Replicate => ("replicate", None),
832            PaddingMode::Reflect => ("reflect", None),
833            PaddingMode::Circular => ("circular", None),
834            _ => unreachable!(),
835        };
836
837        let tensor_result = tensor
838            .f_pad(
839                explicit_padding
840                    .into_iter()
841                    .flatten()
842                    .map(|p| p as i64)
843                    .collect::<Vec<_>>(),
844                mode,
845                value,
846            )
847            .unwrap();
848
849        dbg!(&tensor_result);
850        tensor_result.print();
851
852        assert_eq!(
853            ndarray_result.into_raw_vec_and_offset().0,
854            tensor_result
855                .reshape(tensor_result.size().iter().product::<i64>())
856                .iter::<f64>()
857                .unwrap()
858                .map(|v| T::from_f64(v).unwrap())
859                .collect::<Vec<T>>()
860        );
861    }
862}