1use super::{BorderType, PaddingMode};
9
10use ndarray::{
11 Array, ArrayBase, Data, DataMut, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo,
12 SliceInfoElem,
13};
14use num::traits::NumAssign;
15
16pub(crate) mod dim;
17mod half_dim;
18
/// Per-axis padding amounts.
///
/// `padding[i] = [before, after]` is the number of elements added at the start and at
/// the end of axis `i`, so the padded extent of axis `i` is
/// `len_i + padding[i][0] + padding[i][1]`.
pub type ExplicitPadding<const N: usize> = [[usize; 2]; N];
21
/// Extension trait that pads an `N`-dimensional array.
///
/// `T` is the element type and `Output` is the type produced by
/// [`PaddingExt::padding`].
pub trait PaddingExt<const N: usize, T: num::traits::NumAssign + Copy, Output> {
    /// Returns a padded copy of `self`.
    ///
    /// `padding_size[i] = [before, after]` is the number of elements added at the start
    /// and end of axis `i`; `mode` selects how the added border elements are filled.
    fn padding(&self, mode: PaddingMode<N, T>, padding_size: ExplicitPadding<N>) -> Output;

    /// Pads `self` into an existing `buffer` instead of allocating a new array.
    ///
    /// The implementation in this module copies `self` into the region of `buffer` that
    /// starts at `padding_size[i][0]` along each axis and then applies the border
    /// handling selected by `mode`. `buffer` is never resized.
    ///
    /// NOTE(review): `buffer` presumably must already have the padded shape
    /// (`len_i + padding_size[i][0] + padding_size[i][1]` per axis) — confirm with callers.
    fn padding_in<SO: DataMut<Elem = T>, DO: RemoveAxis>(
        &self,
        buffer: &mut ArrayBase<SO, DO>,
        mode: PaddingMode<N, T>,
        padding_size: ExplicitPadding<N>,
    ) where
        T: NumAssign + Copy,
        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
        SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
        Dim<[Ix; N]>: RemoveAxis,
        SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg<DO>;
}
74
75impl<const N: usize, T, S, D> PaddingExt<N, T, Array<T, Dim<[Ix; N]>>> for ArrayBase<S, D>
76where
77 T: NumAssign + Copy,
78 S: Data<Elem = T>,
79 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
80 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
81 Dim<[Ix; N]>: RemoveAxis,
82 D: RemoveAxis + IntoDimension,
83{
84 fn padding(
85 &self,
86 mode: PaddingMode<N, T>,
87 explicit_padding: ExplicitPadding<N>,
88 ) -> Array<T, Dim<[Ix; N]>> {
89 let c = match mode {
90 PaddingMode::Const(c) => c,
91 _ => T::zero(),
92 };
93
94 let raw_dim = self.raw_dim();
95
96 let output_dim =
97 std::array::from_fn(|i| raw_dim[i] + explicit_padding[i][0] + explicit_padding[i][1]);
98
99 let mut output: Array<T, Dim<[Ix; N]>> = Array::from_elem(output_dim, c);
100
101 padding_const(self, &mut output, explicit_padding);
102
103 match mode {
104 PaddingMode::Replicate => padding_replicate(self, &mut output, explicit_padding),
105 PaddingMode::Reflect => padding_reflect(self, &mut output, explicit_padding),
106 PaddingMode::Circular => padding_circular(self, &mut output, explicit_padding),
107 PaddingMode::Custom(borders) => {
108 padding_custom(self, &mut output, explicit_padding, borders)
109 }
110 PaddingMode::Explicit(borders) => {
111 padding_explicit(self, &mut output, explicit_padding, borders)
112 }
113 _ => {}
114 };
115
116 output
117 }
118
119 fn padding_in<SO, DO>(
120 &self,
121 buffer: &mut ArrayBase<SO, DO>,
122 mode: PaddingMode<N, T>,
123 explicit_padding: ExplicitPadding<N>,
124 ) where
125 T: NumAssign + Copy,
126 S: Data<Elem = T>,
127 SO: DataMut<Elem = T>,
128 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
129 SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg<DO>,
130 Dim<[Ix; N]>: RemoveAxis,
131 DO: RemoveAxis,
132 {
133 padding_const(self, buffer, explicit_padding);
134
135 match mode {
136 PaddingMode::Const(c) => {
137 explicit_padding
138 .iter()
139 .enumerate()
140 .for_each(|(dim, &explicit_padding)| {
141 dim::constant(self.raw_dim(), buffer, dim, explicit_padding, c);
142 })
143 }
144 PaddingMode::Replicate => padding_replicate(self, buffer, explicit_padding),
145 PaddingMode::Reflect => padding_reflect(self, buffer, explicit_padding),
146 PaddingMode::Circular => padding_circular(self, buffer, explicit_padding),
147 PaddingMode::Custom(borders) => padding_custom(self, buffer, explicit_padding, borders),
148 PaddingMode::Explicit(borders) => {
149 padding_explicit(self, buffer, explicit_padding, borders)
150 }
151 _ => {}
152 };
153 }
154}
155
156pub(crate) fn padding_const<const N: usize, T, S, D, SO, DO>(
176 input: &ArrayBase<S, D>,
177 output: &mut ArrayBase<SO, DO>,
178 explicit_padding: ExplicitPadding<N>,
179) where
180 T: NumAssign + Copy,
181 S: Data<Elem = T>,
182 SO: DataMut<Elem = T>,
183 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
184 SliceInfo<[SliceInfoElem; N], DO, DO>: SliceArg<DO>,
186 Dim<[Ix; N]>: RemoveAxis,
187 D: RemoveAxis,
188 DO: RemoveAxis,
189{
190 let mut output_slice = output.slice_mut(unsafe {
191 SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
192 start: explicit_padding[i][0] as isize,
193 end: Some((explicit_padding[i][0] + input.raw_dim()[i]) as isize),
194 step: 1,
195 }))
196 .unwrap()
197 });
198
199 output_slice.assign(input);
200}
201
202fn padding_replicate<const N: usize, T, S, D, SO, DO>(
222 input: &ArrayBase<S, D>,
223 output: &mut ArrayBase<SO, DO>,
224 explicit_padding: ExplicitPadding<N>,
225) where
226 T: NumAssign + Copy,
227 S: Data<Elem = T>,
228 SO: DataMut<Elem = T>,
229 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
230 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
231 Dim<[Ix; N]>: RemoveAxis,
232 D: RemoveAxis + IntoDimension,
233 DO: RemoveAxis,
234{
235 explicit_padding
236 .iter()
237 .enumerate()
238 .for_each(|(dim, &explicit_padding)| {
239 dim::replicate(input.raw_dim(), output, dim, explicit_padding);
240 });
241}
242
243fn padding_reflect<const N: usize, T, S, D, SO, DO>(
263 input: &ArrayBase<S, D>,
264 output: &mut ArrayBase<SO, DO>,
265 explicit_padding: ExplicitPadding<N>,
266) where
267 T: NumAssign + Copy,
268 S: Data<Elem = T>,
269 SO: DataMut<Elem = T>,
270 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
271 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
272 Dim<[Ix; N]>: RemoveAxis,
273 D: RemoveAxis,
274 DO: RemoveAxis,
275{
276 explicit_padding
277 .iter()
278 .enumerate()
279 .for_each(|(dim, &explicit_padding)| {
280 dim::reflect(input.raw_dim(), output, dim, explicit_padding);
281 });
282}
283
284fn padding_circular<const N: usize, T, S, D, SO, DO>(
304 input: &ArrayBase<S, D>,
305 output: &mut ArrayBase<SO, DO>,
306 explicit_padding: ExplicitPadding<N>,
307) where
308 T: NumAssign + Copy,
309 S: Data<Elem = T>,
310 SO: DataMut<Elem = T>,
311 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
312 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
313 Dim<[Ix; N]>: RemoveAxis,
314 D: RemoveAxis,
315 DO: RemoveAxis,
316{
317 explicit_padding
318 .iter()
319 .enumerate()
320 .for_each(|(dim, &explicit_padding)| {
321 dim::circular(input.raw_dim(), output, dim, explicit_padding);
322 });
323}
324
325fn padding_custom<const N: usize, T, S, D, SO, DO>(
347 input: &ArrayBase<S, D>,
348 output: &mut ArrayBase<SO, DO>,
349 explicit_padding: ExplicitPadding<N>,
350 borders: [BorderType<T>; N],
351) where
352 T: NumAssign + Copy,
353 S: Data<Elem = T>,
354 SO: DataMut<Elem = T>,
355 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
356 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
357 Dim<[Ix; N]>: RemoveAxis,
358 D: RemoveAxis,
359 DO: RemoveAxis,
360{
361 explicit_padding
362 .iter()
363 .zip(borders.iter())
364 .enumerate()
365 .for_each(|(dim, (&explicit_padding, border))| match border {
366 BorderType::Zeros => {
367 dim::constant(input.raw_dim(), output, dim, explicit_padding, T::zero())
368 }
369 BorderType::Const(c) => {
370 dim::constant(input.raw_dim(), output, dim, explicit_padding, *c)
371 }
372 BorderType::Reflect => dim::reflect(input.raw_dim(), output, dim, explicit_padding),
373 BorderType::Replicate => dim::replicate(input.raw_dim(), output, dim, explicit_padding),
374 BorderType::Circular => dim::circular(input.raw_dim(), output, dim, explicit_padding),
375 });
376}
377
378fn padding_explicit<const N: usize, T, S, D, SO, DO>(
402 input: &ArrayBase<S, D>,
403 output: &mut ArrayBase<SO, DO>,
404 explicit_padding: ExplicitPadding<N>,
405 borders: [[BorderType<T>; 2]; N],
406) where
407 T: NumAssign + Copy,
408 S: Data<Elem = T>,
409 SO: DataMut<Elem = T>,
410 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
411 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
412 Dim<[Ix; N]>: RemoveAxis,
413 D: RemoveAxis,
414 DO: RemoveAxis,
415{
416 explicit_padding
417 .iter()
418 .zip(borders.iter())
419 .enumerate()
420 .for_each(|(dim, (&explicit_padding, border))| {
421 match border[0] {
422 BorderType::Zeros => {
423 half_dim::constant_front(output, dim, explicit_padding, T::zero())
424 }
425 BorderType::Const(c) => half_dim::constant_front(output, dim, explicit_padding, c),
426 BorderType::Reflect => half_dim::reflect_front(output, dim, explicit_padding),
427 BorderType::Replicate => half_dim::replicate_front(output, dim, explicit_padding),
428 BorderType::Circular => half_dim::circular_front(output, dim, explicit_padding),
429 }
430 match border[1] {
431 BorderType::Zeros => half_dim::constant_back(
432 input.raw_dim(),
433 output,
434 dim,
435 explicit_padding,
436 T::zero(),
437 ),
438 BorderType::Const(c) => {
439 half_dim::constant_back(input.raw_dim(), output, dim, explicit_padding, c)
440 }
441 BorderType::Reflect => {
442 half_dim::reflect_back(input.raw_dim(), output, dim, explicit_padding)
443 }
444 BorderType::Replicate => {
445 half_dim::replicate_back(input.raw_dim(), output, dim, explicit_padding)
446 }
447 BorderType::Circular => {
448 half_dim::circular_back(input.raw_dim(), output, dim, explicit_padding)
449 }
450 }
451 });
452}
453
#[cfg(test)]
mod tests {
    use ndarray::prelude::*;

    use super::*;
    use crate::dilation::IntoKernelWithDilation;
    use crate::ConvMode;

    mod zeros_padding {
        use super::*;

        #[test]
        fn test_1d() {
            let arr = array![1, 2, 3];
            let explicit_padding = [[1, 1]];
            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
            assert_eq!(padded, array![0, 1, 2, 3, 0]);
        }

        #[test]
        fn test_2d() {
            let arr = array![[1, 2], [3, 4]];
            let explicit_padding = [[1, 1], [1, 1]];
            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
            assert_eq!(
                padded,
                array![[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
            );
        }

        #[test]
        fn test_3d() {
            let arr = array![[[1, 2]], [[3, 4]]];
            let explicit_padding = [[1, 0], [0, 1], [1, 0]];
            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
            assert_eq!(
                padded,
                array![
                    [[0, 0, 0], [0, 0, 0]],
                    [[0, 1, 2], [0, 0, 0]],
                    [[0, 3, 4], [0, 0, 0]]
                ]
            );
        }

        #[test]
        fn test_asymmetric_padding() {
            let arr = array![1, 2, 3];
            let explicit_padding = [[2, 1]];
            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
            assert_eq!(padded, array![0, 0, 1, 2, 3, 0]);
        }
    }

    mod const_padding {
        use super::*;

        #[test]
        fn test_1d() {
            let arr = array![1, 2, 3];
            let explicit_padding = [[1, 1]];
            let padded = arr.padding(PaddingMode::Const(7), explicit_padding);
            assert_eq!(padded, array![7, 1, 2, 3, 7]);
        }

        #[test]
        fn test_2d() {
            let arr = array![[1, 2], [3, 4]];
            let explicit_padding = [[1, 1], [1, 1]];
            let padded = arr.padding(PaddingMode::Const(9), explicit_padding);
            assert_eq!(
                padded,
                array![[9, 9, 9, 9], [9, 1, 2, 9], [9, 3, 4, 9], [9, 9, 9, 9]]
            );
        }
    }

    mod replicate_padding {
        use super::*;

        #[test]
        fn test_1d() {
            let arr = array![1, 2, 3];
            let explicit_padding = [[1, 2]];
            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
            assert_eq!(padded, array![1, 1, 2, 3, 3, 3]);
        }

        #[test]
        fn test_2d() {
            let arr = array![[1, 2], [3, 4]];
            let explicit_padding = [[1, 1], [1, 1]];
            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
            assert_eq!(
                padded,
                array![[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
            );
        }

        #[test]
        fn test_large_padding() {
            let arr = array![1, 2];
            let explicit_padding = [[3, 3]];
            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
            assert_eq!(padded, array![1, 1, 1, 1, 2, 2, 2, 2]);
        }
    }

    mod reflect_padding {
        use super::*;

        #[test]
        fn test_1d() {
            let arr = array![1, 2, 3, 4];
            let explicit_padding = [[2, 2]];
            let padded = arr.padding(PaddingMode::Reflect, explicit_padding);
            assert_eq!(padded, array![3, 2, 1, 2, 3, 4, 3, 2]);
        }

        #[test]
        fn test_2d() {
            let arr = array![[1, 2, 3], [4, 5, 6]];
            let explicit_padding = [[1, 1], [1, 1]];
            let padded = arr.padding(PaddingMode::Reflect, explicit_padding);
            assert_eq!(
                padded,
                array![
                    [5, 4, 5, 6, 5],
                    [2, 1, 2, 3, 2],
                    [5, 4, 5, 6, 5],
                    [2, 1, 2, 3, 2]
                ]
            );
        }
    }

    mod circular_padding {
        use super::*;

        #[test]
        fn test_1d() {
            let arr = array![1, 2, 3, 4];
            let explicit_padding = [[2, 2]];
            let padded = arr.padding(PaddingMode::Circular, explicit_padding);
            assert_eq!(padded, array![3, 4, 1, 2, 3, 4, 1, 2]);
        }

        #[test]
        fn test_2d() {
            let arr = array![[1, 2], [3, 4]];
            let explicit_padding = [[1, 1], [1, 1]];
            let padded = arr.padding(PaddingMode::Circular, explicit_padding);
            assert_eq!(
                padded,
                array![[4, 3, 4, 3], [2, 1, 2, 1], [4, 3, 4, 3], [2, 1, 2, 1]]
            );
        }

        #[test]
        fn test_type_cast_safety() {
            let arr = array![1u8, 2, 3];
            let explicit_padding = [[1, 1]];
            let padded = arr.padding(PaddingMode::Circular, explicit_padding);
            assert_eq!(padded, array![3u8, 1, 2, 3, 1]);
        }
    }

    mod custom_padding {
        use super::*;

        #[test]
        fn test_per_dimension() {
            let arr = array![[1, 2], [3, 4]];
            let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]];
            let kernel = kernel.into_kernel_with_dilation();

            let explicit_conv = ConvMode::Full.unfold(&kernel);
            let explicit_padding = explicit_conv.padding;

            let arr_padded = arr.padding(
                PaddingMode::Custom([BorderType::Replicate, BorderType::Circular]),
                explicit_padding,
            );
            assert_eq!(
                arr_padded,
                array![
                    [1, 2, 1, 2, 1, 2],
                    [1, 2, 1, 2, 1, 2],
                    [1, 2, 1, 2, 1, 2],
                    [3, 4, 3, 4, 3, 4],
                    [3, 4, 3, 4, 3, 4],
                    [3, 4, 3, 4, 3, 4]
                ]
            );
        }

        #[test]
        fn test_mixed_types() {
            let arr = array![[1, 2], [3, 4]];
            let kernel = array![[1, 1, 1], [1, 1, 1], [1, 1, 1]];
            let kernel = kernel.into_kernel_with_dilation();

            let explicit_conv = ConvMode::Full.unfold(&kernel);
            let explicit_padding = explicit_conv.padding;

            let arr_padded = arr.padding(
                PaddingMode::Custom([BorderType::Reflect, BorderType::Const(7)]),
                explicit_padding,
            );
            // NOTE(review): axis 0 has length 2 but is reflect-padded by 2, which
            // PyTorch rejects; these expected values pin the current out-of-range
            // behaviour (e.g. the first row keeps the zero prefill).
            assert_eq!(
                arr_padded,
                array![
                    [7, 7, 0, 0, 7, 7],
                    [7, 7, 3, 4, 7, 7],
                    [7, 7, 1, 2, 7, 7],
                    [7, 7, 3, 4, 7, 7],
                    [7, 7, 1, 2, 7, 7],
                    [7, 7, 3, 4, 7, 7]
                ]
            );
        }
    }

    mod explicit_padding {
        use super::*;

        #[test]
        fn test_per_side() {
            let arr = array![1, 2, 3];
            let explicit_padding = [[1, 2]];

            let padded = arr.padding(
                PaddingMode::Explicit([[BorderType::Const(7), BorderType::Const(9)]]),
                explicit_padding,
            );
            assert_eq!(padded, array![7, 1, 2, 3, 9, 9]);
        }
    }

    mod edge_cases {
        use super::*;

        #[test]
        fn test_zero_padding() {
            let arr = array![1, 2, 3];
            let explicit_padding = [[0, 0]];
            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);
            assert_eq!(padded, arr);
        }

        #[test]
        fn test_single_element() {
            let arr = array![42];
            let explicit_padding = [[2, 2]];
            let padded = arr.padding(PaddingMode::Replicate, explicit_padding);
            assert_eq!(padded, array![42, 42, 42, 42, 42]);
        }

        #[test]
        fn test_large_array() {
            let arr = Array::from_shape_fn((100, 100), |(i, j)| (i + j) as i32);
            let explicit_padding = [[5, 5], [5, 5]];
            let padded = arr.padding(PaddingMode::Zeros, explicit_padding);

            assert_eq!(padded.shape(), &[110, 110]);

            // Top border rows.
            for i in 0..5 {
                for j in 0..110 {
                    assert_eq!(padded[[i, j]], 0);
                }
            }
            // Bottom border rows.
            for i in 105..110 {
                for j in 0..110 {
                    assert_eq!(padded[[i, j]], 0);
                }
            }
            // Left and right border columns of the interior rows.
            for i in 5..105 {
                for j in 0..5 {
                    assert_eq!(padded[[i, j]], 0);
                }
                for j in 105..110 {
                    assert_eq!(padded[[i, j]], 0);
                }
            }

            // Interior is the input shifted by the front padding.
            assert_eq!(padded[[5, 5]], arr[[0, 0]]);
            assert_eq!(padded[[54, 54]], arr[[49, 49]]);
            assert_eq!(padded[[104, 104]], arr[[99, 99]]);
        }
    }

    #[test]
    fn aligned_with_libtorch() {
        let arr = array![[[1, 2, 3], [3, 4, 5]], [[5, 6, 7], [7, 8, 9]]];
        let kernel = array![
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
        ];
        let explicit_conv = ConvMode::Same.unfold(&kernel.into_kernel_with_dilation());
        let explicit_padding = explicit_conv.padding;
        check(&arr, PaddingMode::Zeros, explicit_padding);
        check(&arr, PaddingMode::Const(7), explicit_padding);
        check(&arr, PaddingMode::Replicate, explicit_padding);
        check(&arr, PaddingMode::Reflect, explicit_padding);
        check(&arr, PaddingMode::Circular, explicit_padding);

        let arr = array![[1, 2], [3, 4]];
        let kernel = array![[1, 1], [1, 1]];
        let explicit_conv = ConvMode::Full.unfold(&kernel.into_kernel_with_dilation());
        let explicit_padding = explicit_conv.padding;
        check(&arr, PaddingMode::Zeros, explicit_padding);
        check(&arr, PaddingMode::Const(7), explicit_padding);
        check(&arr, PaddingMode::Replicate, explicit_padding);
        check(&arr, PaddingMode::Reflect, explicit_padding);
        check(&arr, PaddingMode::Circular, explicit_padding);

        let arr = array![1, 2, 3];
        let kernel = array![1, 1, 1, 1];
        let explicit_conv = ConvMode::Same.unfold(&kernel.into_kernel_with_dilation());
        let explicit_padding = explicit_conv.padding;
        check(&arr, PaddingMode::Zeros, explicit_padding);
        check(&arr, PaddingMode::Const(7), explicit_padding);
        check(&arr, PaddingMode::Replicate, explicit_padding);
        check(&arr, PaddingMode::Reflect, explicit_padding);
        check(&arr, PaddingMode::Circular, explicit_padding);

        // Padding that differs per axis and per side. The cases above pad every axis
        // identically and therefore cannot detect an axis-order mismatch with libtorch.
        // Every amount stays below the axis length so that reflect padding is valid.
        let arr = array![[1, 2, 3], [4, 5, 6]];
        let explicit_padding = [[1, 0], [2, 1]];
        check(&arr, PaddingMode::Zeros, explicit_padding);
        check(&arr, PaddingMode::Const(7), explicit_padding);
        check(&arr, PaddingMode::Replicate, explicit_padding);
        check(&arr, PaddingMode::Reflect, explicit_padding);
        check(&arr, PaddingMode::Circular, explicit_padding);
    }

    /// Compares `arr.padding(padding_mode, explicit_padding)` against libtorch's
    /// `Tensor::pad` applied to the same data (as a `[1, 1, ...]`-prefixed float tensor).
    fn check<T, const N: usize>(
        arr: &Array<T, Dim<[Ix; N]>>,
        padding_mode: PaddingMode<N, T>,
        explicit_padding: ExplicitPadding<N>,
    ) where
        T: num::traits::NumAssign + Copy + tch::kind::Element + std::fmt::Debug,
        Dim<[Ix; N]>: Dimension,
        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
        SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: SliceArg<Dim<[Ix; N]>>,
        Dim<[Ix; N]>: RemoveAxis,
        f64: std::convert::From<T>,
        T: num::traits::FromPrimitive,
    {
        let ndarray_result = arr.padding(padding_mode, explicit_padding);

        let shape = [1, 1]
            .iter()
            .chain(arr.shape())
            .map(|s| *s as i64)
            .collect::<Vec<_>>();
        let tensor = tch::Tensor::from_slice(arr.as_slice().unwrap())
            .reshape(shape)
            .totype(tch::Kind::Float);

        let (mode, value) = match padding_mode {
            PaddingMode::Zeros => ("constant", Some(0.0)),
            PaddingMode::Const(c) => ("constant", Some(f64::from(c))),
            PaddingMode::Replicate => ("replicate", None),
            PaddingMode::Reflect => ("reflect", None),
            PaddingMode::Circular => ("circular", None),
            _ => unreachable!(),
        };

        // libtorch's pad list starts with the (before, after) pair of the LAST axis,
        // whereas `ExplicitPadding` is ordered from the first axis, so reverse the pairs.
        let torch_padding = explicit_padding
            .into_iter()
            .rev()
            .flatten()
            .map(|p| p as i64)
            .collect::<Vec<_>>();

        let tensor_result = tensor.f_pad(torch_padding, mode, value).unwrap();

        assert_eq!(
            ndarray_result.into_raw_vec_and_offset().0,
            tensor_result
                .reshape(tensor_result.size().iter().product::<i64>())
                .iter::<f64>()
                .unwrap()
                .map(|v| T::from_f64(v).unwrap())
                .collect::<Vec<T>>()
        );
    }
}