1use std::sync::Arc;
25
26use ferrotorch_core::autograd::no_grad::is_grad_enabled;
27use ferrotorch_core::storage::TensorStorage;
28use ferrotorch_core::tensor::{GradFn, Tensor};
29use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
30
31use crate::module::Module;
32use crate::parameter::Parameter;
33
/// How the border region created by padding is filled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingMode {
    /// Constant fill: the border takes the caller-supplied `value`.
    Zeros,
    /// Mirror the input about its edge; the edge element itself is not repeated.
    Reflect,
    /// Repeat the nearest edge element of the input.
    Replicate,
    /// Wrap around: the border is filled from the opposite end of the input.
    Circular,
}
50
51fn pad_1d_constant<T: Float>(
60 data: &[T],
61 shape: &[usize],
62 pad_left: usize,
63 pad_right: usize,
64 value: T,
65) -> (Vec<T>, Vec<usize>) {
66 let ndim = shape.len();
67 let inner = shape[ndim - 1];
68 let new_inner = inner + pad_left + pad_right;
69
70 let rows: usize = shape[..ndim - 1].iter().product();
72 let rows = if rows == 0 { 1 } else { rows };
73
74 let mut out = vec![value; rows * new_inner];
75 if !data.is_empty() {
82 for r in 0..rows {
83 let src_start = r * inner;
84 let dst_start = r * new_inner + pad_left;
85 out[dst_start..dst_start + inner].copy_from_slice(&data[src_start..src_start + inner]);
86 }
87 }
88
89 let mut new_shape = shape.to_vec();
90 new_shape[ndim - 1] = new_inner;
91 (out, new_shape)
92}
93
94fn pad_2d_constant<T: Float>(
96 data: &[T],
97 shape: &[usize],
98 pad_left: usize,
99 pad_right: usize,
100 pad_top: usize,
101 pad_bottom: usize,
102 value: T,
103) -> (Vec<T>, Vec<usize>) {
104 let ndim = shape.len();
105 let h = shape[ndim - 2];
106 let w = shape[ndim - 1];
107 let new_h = h + pad_top + pad_bottom;
108 let new_w = w + pad_left + pad_right;
109
110 let outer: usize = shape[..ndim - 2].iter().product();
111 let outer = if outer == 0 { 1 } else { outer };
112
113 let mut out = vec![value; outer * new_h * new_w];
114 if !data.is_empty() {
117 for o in 0..outer {
118 for row in 0..h {
119 let src_off = o * h * w + row * w;
120 let dst_off = o * new_h * new_w + (row + pad_top) * new_w + pad_left;
121 out[dst_off..dst_off + w].copy_from_slice(&data[src_off..src_off + w]);
122 }
123 }
124 }
125
126 let mut new_shape = shape.to_vec();
127 new_shape[ndim - 2] = new_h;
128 new_shape[ndim - 1] = new_w;
129 (out, new_shape)
130}
131
132#[allow(clippy::too_many_arguments)]
136fn pad_3d_constant<T: Float>(
137 data: &[T],
138 shape: &[usize],
139 pad_left: usize,
140 pad_right: usize,
141 pad_top: usize,
142 pad_bottom: usize,
143 pad_front: usize,
144 pad_back: usize,
145 value: T,
146) -> (Vec<T>, Vec<usize>) {
147 let ndim = shape.len();
148 let d = shape[ndim - 3];
149 let h = shape[ndim - 2];
150 let w = shape[ndim - 1];
151 let new_d = d + pad_front + pad_back;
152 let new_h = h + pad_top + pad_bottom;
153 let new_w = w + pad_left + pad_right;
154
155 let outer: usize = shape[..ndim - 3].iter().product();
156 let outer = if outer == 0 { 1 } else { outer };
157
158 let mut out = vec![value; outer * new_d * new_h * new_w];
159 if !data.is_empty() {
162 for o in 0..outer {
163 for dep in 0..d {
164 for row in 0..h {
165 let src_off = o * d * h * w + dep * h * w + row * w;
166 let dst_off = o * new_d * new_h * new_w
167 + (dep + pad_front) * new_h * new_w
168 + (row + pad_top) * new_w
169 + pad_left;
170 out[dst_off..dst_off + w].copy_from_slice(&data[src_off..src_off + w]);
171 }
172 }
173 }
174 }
175
176 let mut new_shape = shape.to_vec();
177 new_shape[ndim - 3] = new_d;
178 new_shape[ndim - 2] = new_h;
179 new_shape[ndim - 1] = new_w;
180 (out, new_shape)
181}
182
183fn pad_1d_reflect<T: Float>(
189 data: &[T],
190 shape: &[usize],
191 pad_left: usize,
192 pad_right: usize,
193) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
194 let ndim = shape.len();
195 let inner = shape[ndim - 1];
196 if pad_left >= inner || pad_right >= inner {
197 return Err(FerrotorchError::InvalidArgument {
198 message: format!(
199 "Reflection padding ({pad_left}, {pad_right}) must be less than input size ({inner})"
200 ),
201 });
202 }
203 let new_inner = inner + pad_left + pad_right;
204 let rows: usize = shape[..ndim - 1].iter().copied().product::<usize>().max(1);
205
206 let zero = <T as num_traits::Zero>::zero();
207 let mut out = vec![zero; rows * new_inner];
208 for r in 0..rows {
209 let src = &data[r * inner..(r + 1) * inner];
210 let dst = &mut out[r * new_inner..(r + 1) * new_inner];
211 for i in 0..pad_left {
213 dst[pad_left - 1 - i] = src[i + 1];
214 }
215 dst[pad_left..pad_left + inner].copy_from_slice(src);
217 for i in 0..pad_right {
219 dst[pad_left + inner + i] = src[inner - 2 - i];
220 }
221 }
222
223 let mut new_shape = shape.to_vec();
224 new_shape[ndim - 1] = new_inner;
225 Ok((out, new_shape))
226}
227
228fn pad_2d_reflect<T: Float>(
230 data: &[T],
231 shape: &[usize],
232 pad_left: usize,
233 pad_right: usize,
234 pad_top: usize,
235 pad_bottom: usize,
236) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
237 let ndim = shape.len();
238 let h = shape[ndim - 2];
239 let w = shape[ndim - 1];
240 if pad_left >= w || pad_right >= w || pad_top >= h || pad_bottom >= h {
241 return Err(FerrotorchError::InvalidArgument {
242 message: format!(
243 "Reflection padding ({pad_left}, {pad_right}, {pad_top}, {pad_bottom}) must be less than input size ({h}, {w})"
244 ),
245 });
246 }
247 let new_h = h + pad_top + pad_bottom;
248 let new_w = w + pad_left + pad_right;
249 let outer: usize = shape[..ndim - 2].iter().copied().product::<usize>().max(1);
250
251 let zero = <T as num_traits::Zero>::zero();
252 let mut out = vec![zero; outer * new_h * new_w];
253
254 for o in 0..outer {
255 let src_base = o * h * w;
256 let dst_base = o * new_h * new_w;
257
258 for new_row in 0..new_h {
259 let src_row = if new_row < pad_top {
261 pad_top - new_row
262 } else if new_row >= pad_top + h {
263 h - 2 - (new_row - pad_top - h)
264 } else {
265 new_row - pad_top
266 };
267
268 for new_col in 0..new_w {
269 let src_col = if new_col < pad_left {
270 pad_left - new_col
271 } else if new_col >= pad_left + w {
272 w - 2 - (new_col - pad_left - w)
273 } else {
274 new_col - pad_left
275 };
276
277 out[dst_base + new_row * new_w + new_col] = data[src_base + src_row * w + src_col];
278 }
279 }
280 }
281
282 let mut new_shape = shape.to_vec();
283 new_shape[ndim - 2] = new_h;
284 new_shape[ndim - 1] = new_w;
285 Ok((out, new_shape))
286}
287
288#[allow(clippy::too_many_arguments)]
291fn pad_3d_reflect<T: Float>(
292 data: &[T],
293 shape: &[usize],
294 pad_left: usize,
295 pad_right: usize,
296 pad_top: usize,
297 pad_bottom: usize,
298 pad_front: usize,
299 pad_back: usize,
300) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
301 let ndim = shape.len();
302 let d = shape[ndim - 3];
303 let h = shape[ndim - 2];
304 let w = shape[ndim - 1];
305 if pad_left >= w
306 || pad_right >= w
307 || pad_top >= h
308 || pad_bottom >= h
309 || pad_front >= d
310 || pad_back >= d
311 {
312 return Err(FerrotorchError::InvalidArgument {
313 message: "Reflection padding must be less than corresponding input dimension".into(),
314 });
315 }
316 let new_d = d + pad_front + pad_back;
317 let new_h = h + pad_top + pad_bottom;
318 let new_w = w + pad_left + pad_right;
319 let outer: usize = shape[..ndim - 3].iter().copied().product::<usize>().max(1);
320
321 let zero = <T as num_traits::Zero>::zero();
322 let mut out = vec![zero; outer * new_d * new_h * new_w];
323
324 for o in 0..outer {
325 let src_base = o * d * h * w;
326 let dst_base = o * new_d * new_h * new_w;
327
328 for nd in 0..new_d {
329 let sd = if nd < pad_front {
330 pad_front - nd
331 } else if nd >= pad_front + d {
332 d - 2 - (nd - pad_front - d)
333 } else {
334 nd - pad_front
335 };
336 for nh in 0..new_h {
337 let sh = if nh < pad_top {
338 pad_top - nh
339 } else if nh >= pad_top + h {
340 h - 2 - (nh - pad_top - h)
341 } else {
342 nh - pad_top
343 };
344 for nw in 0..new_w {
345 let sw = if nw < pad_left {
346 pad_left - nw
347 } else if nw >= pad_left + w {
348 w - 2 - (nw - pad_left - w)
349 } else {
350 nw - pad_left
351 };
352 out[dst_base + nd * new_h * new_w + nh * new_w + nw] =
353 data[src_base + sd * h * w + sh * w + sw];
354 }
355 }
356 }
357 }
358
359 let mut new_shape = shape.to_vec();
360 new_shape[ndim - 3] = new_d;
361 new_shape[ndim - 2] = new_h;
362 new_shape[ndim - 1] = new_w;
363 Ok((out, new_shape))
364}
365
366fn pad_1d_replicate<T: Float>(
372 data: &[T],
373 shape: &[usize],
374 pad_left: usize,
375 pad_right: usize,
376) -> (Vec<T>, Vec<usize>) {
377 let ndim = shape.len();
378 let inner = shape[ndim - 1];
379 let new_inner = inner + pad_left + pad_right;
380 let rows: usize = shape[..ndim - 1].iter().copied().product::<usize>().max(1);
381
382 let zero = <T as num_traits::Zero>::zero();
383 let mut out = vec![zero; rows * new_inner];
384 for r in 0..rows {
385 let src = &data[r * inner..(r + 1) * inner];
386 let dst = &mut out[r * new_inner..(r + 1) * new_inner];
387 for (i, d) in dst.iter_mut().enumerate() {
388 let src_idx = if i < pad_left {
389 0
390 } else if i >= pad_left + inner {
391 inner - 1
392 } else {
393 i - pad_left
394 };
395 *d = src[src_idx];
396 }
397 }
398
399 let mut new_shape = shape.to_vec();
400 new_shape[ndim - 1] = new_inner;
401 (out, new_shape)
402}
403
404fn pad_2d_replicate<T: Float>(
406 data: &[T],
407 shape: &[usize],
408 pad_left: usize,
409 pad_right: usize,
410 pad_top: usize,
411 pad_bottom: usize,
412) -> (Vec<T>, Vec<usize>) {
413 let ndim = shape.len();
414 let h = shape[ndim - 2];
415 let w = shape[ndim - 1];
416 let new_h = h + pad_top + pad_bottom;
417 let new_w = w + pad_left + pad_right;
418 let outer: usize = shape[..ndim - 2].iter().copied().product::<usize>().max(1);
419
420 let zero = <T as num_traits::Zero>::zero();
421 let mut out = vec![zero; outer * new_h * new_w];
422
423 for o in 0..outer {
424 let src_base = o * h * w;
425 let dst_base = o * new_h * new_w;
426 for nr in 0..new_h {
427 let sr = nr.saturating_sub(pad_top).min(h - 1);
428 for nc in 0..new_w {
429 let sc = nc.saturating_sub(pad_left).min(w - 1);
430 out[dst_base + nr * new_w + nc] = data[src_base + sr * w + sc];
431 }
432 }
433 }
434
435 let mut new_shape = shape.to_vec();
436 new_shape[ndim - 2] = new_h;
437 new_shape[ndim - 1] = new_w;
438 (out, new_shape)
439}
440
441#[allow(clippy::too_many_arguments)]
444fn pad_3d_replicate<T: Float>(
445 data: &[T],
446 shape: &[usize],
447 pad_left: usize,
448 pad_right: usize,
449 pad_top: usize,
450 pad_bottom: usize,
451 pad_front: usize,
452 pad_back: usize,
453) -> (Vec<T>, Vec<usize>) {
454 let ndim = shape.len();
455 let d = shape[ndim - 3];
456 let h = shape[ndim - 2];
457 let w = shape[ndim - 1];
458 let new_d = d + pad_front + pad_back;
459 let new_h = h + pad_top + pad_bottom;
460 let new_w = w + pad_left + pad_right;
461 let outer: usize = shape[..ndim - 3].iter().copied().product::<usize>().max(1);
462
463 let zero = <T as num_traits::Zero>::zero();
464 let mut out = vec![zero; outer * new_d * new_h * new_w];
465
466 for o in 0..outer {
467 let src_base = o * d * h * w;
468 let dst_base = o * new_d * new_h * new_w;
469 for nd in 0..new_d {
470 let sd = nd.saturating_sub(pad_front).min(d - 1);
471 for nh in 0..new_h {
472 let sh = nh.saturating_sub(pad_top).min(h - 1);
473 for nw in 0..new_w {
474 let sw = nw.saturating_sub(pad_left).min(w - 1);
475 out[dst_base + nd * new_h * new_w + nh * new_w + nw] =
476 data[src_base + sd * h * w + sh * w + sw];
477 }
478 }
479 }
480 }
481
482 let mut new_shape = shape.to_vec();
483 new_shape[ndim - 3] = new_d;
484 new_shape[ndim - 2] = new_h;
485 new_shape[ndim - 1] = new_w;
486 (out, new_shape)
487}
488
489fn check_circular_positive(axes: &[(usize, usize)]) -> FerrotorchResult<()> {
504 for (idx, &(size, pad)) in axes.iter().enumerate() {
505 if pad > size {
506 return Err(FerrotorchError::InvalidArgument {
507 message: format!(
508 "Circular padding {pad} on axis (size {size}, position {idx}) causes wrapping around more than once (pad must be <= size)"
509 ),
510 });
511 }
512 }
513 Ok(())
514}
515
516fn pad_1d_circular<T: Float>(
518 data: &[T],
519 shape: &[usize],
520 pad_left: usize,
521 pad_right: usize,
522) -> (Vec<T>, Vec<usize>) {
523 let ndim = shape.len();
524 let inner = shape[ndim - 1];
525 let new_inner = inner + pad_left + pad_right;
526 let rows: usize = shape[..ndim - 1].iter().copied().product::<usize>().max(1);
527
528 let zero = <T as num_traits::Zero>::zero();
529 let mut out = vec![zero; rows * new_inner];
530 for r in 0..rows {
531 let src = &data[r * inner..(r + 1) * inner];
532 let dst = &mut out[r * new_inner..(r + 1) * new_inner];
533 for (i, d) in dst.iter_mut().enumerate() {
534 let src_idx = ((i as isize - pad_left as isize).rem_euclid(inner as isize)) as usize;
536 *d = src[src_idx];
537 }
538 }
539
540 let mut new_shape = shape.to_vec();
541 new_shape[ndim - 1] = new_inner;
542 (out, new_shape)
543}
544
545fn pad_2d_circular<T: Float>(
547 data: &[T],
548 shape: &[usize],
549 pad_left: usize,
550 pad_right: usize,
551 pad_top: usize,
552 pad_bottom: usize,
553) -> (Vec<T>, Vec<usize>) {
554 let ndim = shape.len();
555 let h = shape[ndim - 2];
556 let w = shape[ndim - 1];
557 let new_h = h + pad_top + pad_bottom;
558 let new_w = w + pad_left + pad_right;
559 let outer: usize = shape[..ndim - 2].iter().copied().product::<usize>().max(1);
560
561 let zero = <T as num_traits::Zero>::zero();
562 let mut out = vec![zero; outer * new_h * new_w];
563
564 for o in 0..outer {
565 let src_base = o * h * w;
566 let dst_base = o * new_h * new_w;
567 for nr in 0..new_h {
568 let sr = ((nr as isize - pad_top as isize).rem_euclid(h as isize)) as usize;
569 for nc in 0..new_w {
570 let sc = ((nc as isize - pad_left as isize).rem_euclid(w as isize)) as usize;
571 out[dst_base + nr * new_w + nc] = data[src_base + sr * w + sc];
572 }
573 }
574 }
575
576 let mut new_shape = shape.to_vec();
577 new_shape[ndim - 2] = new_h;
578 new_shape[ndim - 1] = new_w;
579 (out, new_shape)
580}
581
582#[allow(clippy::too_many_arguments)]
585fn pad_3d_circular<T: Float>(
586 data: &[T],
587 shape: &[usize],
588 pad_left: usize,
589 pad_right: usize,
590 pad_top: usize,
591 pad_bottom: usize,
592 pad_front: usize,
593 pad_back: usize,
594) -> (Vec<T>, Vec<usize>) {
595 let ndim = shape.len();
596 let d = shape[ndim - 3];
597 let h = shape[ndim - 2];
598 let w = shape[ndim - 1];
599 let new_d = d + pad_front + pad_back;
600 let new_h = h + pad_top + pad_bottom;
601 let new_w = w + pad_left + pad_right;
602 let outer: usize = shape[..ndim - 3].iter().copied().product::<usize>().max(1);
603
604 let zero = <T as num_traits::Zero>::zero();
605 let mut out = vec![zero; outer * new_d * new_h * new_w];
606
607 for o in 0..outer {
608 let src_base = o * d * h * w;
609 let dst_base = o * new_d * new_h * new_w;
610 for nd in 0..new_d {
611 let sd = ((nd as isize - pad_front as isize).rem_euclid(d as isize)) as usize;
612 for nh in 0..new_h {
613 let sh = ((nh as isize - pad_top as isize).rem_euclid(h as isize)) as usize;
614 for nw in 0..new_w {
615 let sw = ((nw as isize - pad_left as isize).rem_euclid(w as isize)) as usize;
616 out[dst_base + nd * new_h * new_w + nh * new_w + nw] =
617 data[src_base + sd * h * w + sh * w + sw];
618 }
619 }
620 }
621 }
622
623 let mut new_shape = shape.to_vec();
624 new_shape[ndim - 3] = new_d;
625 new_shape[ndim - 2] = new_h;
626 new_shape[ndim - 1] = new_w;
627 (out, new_shape)
628}
629
630fn src_index_1d(mode: PaddingMode, new_idx: usize, inner: usize, pad_left: usize) -> Option<usize> {
649 let s: usize = match mode {
650 PaddingMode::Zeros => {
651 if new_idx < pad_left || new_idx >= pad_left + inner {
652 return None;
653 }
654 new_idx - pad_left
655 }
656 PaddingMode::Reflect => {
657 if new_idx < pad_left {
658 pad_left - new_idx
659 } else if new_idx >= pad_left + inner {
660 inner - 2 - (new_idx - pad_left - inner)
661 } else {
662 new_idx - pad_left
663 }
664 }
665 PaddingMode::Replicate => new_idx.saturating_sub(pad_left).min(inner - 1),
666 PaddingMode::Circular => {
667 ((new_idx as isize - pad_left as isize).rem_euclid(inner as isize)) as usize
668 }
669 };
670 Some(s)
671}
672
673#[derive(Debug)]
676struct Pad1dBackward<T: Float> {
677 input: Tensor<T>,
678 input_shape: Vec<usize>,
679 mode: PaddingMode,
680 pad_left: usize,
681}
682
683impl<T: Float> GradFn<T> for Pad1dBackward<T> {
684 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
685 if !self.input.requires_grad() {
686 return Ok(vec![None]);
687 }
688 let ndim = self.input_shape.len();
689 let inner = self.input_shape[ndim - 1];
690 let rows: usize = self.input_shape[..ndim - 1]
691 .iter()
692 .copied()
693 .product::<usize>()
694 .max(1);
695
696 let go_shape = grad_output.shape();
697 let new_inner = go_shape[ndim - 1];
698
699 let go = grad_output.data_vec()?;
702 let zero = <T as num_traits::Zero>::zero();
703 let mut grad_in = vec![zero; rows * inner];
704
705 for r in 0..rows {
706 let go_base = r * new_inner;
707 let gi_base = r * inner;
708 for ni in 0..new_inner {
709 if let Some(src) = src_index_1d(self.mode, ni, inner, self.pad_left) {
710 grad_in[gi_base + src] += go[go_base + ni];
711 }
712 }
713 }
714
715 let grad_input =
716 Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
717 Ok(vec![Some(grad_input)])
718 }
719
720 fn inputs(&self) -> Vec<&Tensor<T>> {
721 vec![&self.input]
722 }
723
724 fn name(&self) -> &'static str {
725 "Pad1dBackward"
726 }
727}
728
729pub fn functional_pad_1d<T: Float>(
740 input: &Tensor<T>,
741 pad_left: usize,
742 pad_right: usize,
743 mode: PaddingMode,
744 value: T,
745) -> FerrotorchResult<Tensor<T>> {
746 if mode == PaddingMode::Zeros {
754 return functional_pad_1d_signed(input, pad_left as isize, pad_right as isize, mode, value);
755 }
756
757 let data = input.data_vec()?;
758 let shape = input.shape();
759 let input_shape = shape.to_vec();
760 let (out_data, new_shape) = match mode {
764 PaddingMode::Reflect => pad_1d_reflect(&data, shape, pad_left, pad_right)?,
765 PaddingMode::Replicate => pad_1d_replicate(&data, shape, pad_left, pad_right),
766 PaddingMode::Circular => {
767 let inner = shape[shape.len() - 1];
768 check_circular_positive(&[(inner, pad_left), (inner, pad_right)])?;
769 pad_1d_circular(&data, shape, pad_left, pad_right)
770 }
771 PaddingMode::Zeros => {
772 return functional_pad_1d_signed(
773 input,
774 pad_left as isize,
775 pad_right as isize,
776 mode,
777 value,
778 );
779 }
780 };
781
782 if is_grad_enabled() && input.requires_grad() {
787 let grad_fn = Arc::new(Pad1dBackward {
788 input: input.clone(),
789 input_shape,
790 mode,
791 pad_left,
792 });
793 return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
794 }
795
796 Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
797}
798
799fn src_index_2d(
818 mode: PaddingMode,
819 new_row: usize,
820 new_col: usize,
821 h: usize,
822 w: usize,
823 pad_left: usize,
824 pad_top: usize,
825) -> Option<usize> {
826 let sr: usize = match mode {
827 PaddingMode::Zeros => {
828 if new_row < pad_top || new_row >= pad_top + h {
829 return None;
830 }
831 new_row - pad_top
832 }
833 PaddingMode::Reflect => {
834 if new_row < pad_top {
835 pad_top - new_row
836 } else if new_row >= pad_top + h {
837 h - 2 - (new_row - pad_top - h)
838 } else {
839 new_row - pad_top
840 }
841 }
842 PaddingMode::Replicate => new_row.saturating_sub(pad_top).min(h - 1),
843 PaddingMode::Circular => {
844 ((new_row as isize - pad_top as isize).rem_euclid(h as isize)) as usize
845 }
846 };
847 let sc: usize = match mode {
848 PaddingMode::Zeros => {
849 if new_col < pad_left || new_col >= pad_left + w {
850 return None;
851 }
852 new_col - pad_left
853 }
854 PaddingMode::Reflect => {
855 if new_col < pad_left {
856 pad_left - new_col
857 } else if new_col >= pad_left + w {
858 w - 2 - (new_col - pad_left - w)
859 } else {
860 new_col - pad_left
861 }
862 }
863 PaddingMode::Replicate => new_col.saturating_sub(pad_left).min(w - 1),
864 PaddingMode::Circular => {
865 ((new_col as isize - pad_left as isize).rem_euclid(w as isize)) as usize
866 }
867 };
868 Some(sr * w + sc)
869}
870
871#[derive(Debug)]
874struct Pad2dBackward<T: Float> {
875 input: Tensor<T>,
876 input_shape: Vec<usize>,
877 mode: PaddingMode,
878 pad_left: usize,
879 pad_top: usize,
880}
881
882impl<T: Float> GradFn<T> for Pad2dBackward<T> {
883 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
884 if !self.input.requires_grad() {
885 return Ok(vec![None]);
886 }
887 let ndim = self.input_shape.len();
888 let h = self.input_shape[ndim - 2];
889 let w = self.input_shape[ndim - 1];
890 let outer: usize = self.input_shape[..ndim - 2]
891 .iter()
892 .copied()
893 .product::<usize>()
894 .max(1);
895
896 let go_shape = grad_output.shape();
897 let new_h = go_shape[ndim - 2];
898 let new_w = go_shape[ndim - 1];
899
900 let go = grad_output.data_vec()?;
903 let zero = <T as num_traits::Zero>::zero();
904 let mut grad_in = vec![zero; outer * h * w];
905
906 for o in 0..outer {
907 let go_base = o * new_h * new_w;
908 let gi_base = o * h * w;
909 for nr in 0..new_h {
910 for nc in 0..new_w {
911 if let Some(src) =
912 src_index_2d(self.mode, nr, nc, h, w, self.pad_left, self.pad_top)
913 {
914 grad_in[gi_base + src] += go[go_base + nr * new_w + nc];
915 }
916 }
917 }
918 }
919
920 let grad_input =
921 Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
922 Ok(vec![Some(grad_input)])
923 }
924
925 fn inputs(&self) -> Vec<&Tensor<T>> {
926 vec![&self.input]
927 }
928
929 fn name(&self) -> &'static str {
930 "Pad2dBackward"
931 }
932}
933
934pub fn functional_pad_2d<T: Float>(
941 input: &Tensor<T>,
942 pad_left: usize,
943 pad_right: usize,
944 pad_top: usize,
945 pad_bottom: usize,
946 mode: PaddingMode,
947 value: T,
948) -> FerrotorchResult<Tensor<T>> {
949 if mode == PaddingMode::Zeros {
953 return functional_pad_2d_signed(
954 input,
955 pad_left as isize,
956 pad_right as isize,
957 pad_top as isize,
958 pad_bottom as isize,
959 mode,
960 value,
961 );
962 }
963
964 let data = input.data_vec()?;
965 let shape = input.shape();
966 let input_shape = shape.to_vec();
967 let (out_data, new_shape) = match mode {
968 PaddingMode::Reflect => {
969 pad_2d_reflect(&data, shape, pad_left, pad_right, pad_top, pad_bottom)?
970 }
971 PaddingMode::Replicate => {
972 pad_2d_replicate(&data, shape, pad_left, pad_right, pad_top, pad_bottom)
973 }
974 PaddingMode::Circular => {
975 let nd = shape.len();
976 let (h, w) = (shape[nd - 2], shape[nd - 1]);
977 check_circular_positive(&[
978 (w, pad_left),
979 (w, pad_right),
980 (h, pad_top),
981 (h, pad_bottom),
982 ])?;
983 pad_2d_circular(&data, shape, pad_left, pad_right, pad_top, pad_bottom)
984 }
985 PaddingMode::Zeros => {
986 return functional_pad_2d_signed(
987 input,
988 pad_left as isize,
989 pad_right as isize,
990 pad_top as isize,
991 pad_bottom as isize,
992 mode,
993 value,
994 );
995 }
996 };
997
998 if is_grad_enabled() && input.requires_grad() {
1001 let grad_fn = Arc::new(Pad2dBackward {
1002 input: input.clone(),
1003 input_shape,
1004 mode,
1005 pad_left,
1006 pad_top,
1007 });
1008 return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
1009 }
1010
1011 Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
1012}
1013
1014#[allow(clippy::too_many_arguments)]
1030fn src_index_3d(
1031 mode: PaddingMode,
1032 nd: usize,
1033 nh: usize,
1034 nw: usize,
1035 d: usize,
1036 h: usize,
1037 w: usize,
1038 pad_left: usize,
1039 pad_top: usize,
1040 pad_front: usize,
1041) -> Option<usize> {
1042 fn axis(mode: PaddingMode, new_idx: usize, size: usize, pad_lo: usize) -> Option<usize> {
1045 let s = match mode {
1046 PaddingMode::Zeros => {
1047 if new_idx < pad_lo || new_idx >= pad_lo + size {
1048 return None;
1049 }
1050 new_idx - pad_lo
1051 }
1052 PaddingMode::Reflect => {
1053 if new_idx < pad_lo {
1054 pad_lo - new_idx
1055 } else if new_idx >= pad_lo + size {
1056 size - 2 - (new_idx - pad_lo - size)
1057 } else {
1058 new_idx - pad_lo
1059 }
1060 }
1061 PaddingMode::Replicate => new_idx.saturating_sub(pad_lo).min(size - 1),
1062 PaddingMode::Circular => {
1063 ((new_idx as isize - pad_lo as isize).rem_euclid(size as isize)) as usize
1064 }
1065 };
1066 Some(s)
1067 }
1068 let sd = axis(mode, nd, d, pad_front)?;
1069 let sh = axis(mode, nh, h, pad_top)?;
1070 let sw = axis(mode, nw, w, pad_left)?;
1071 Some(sd * h * w + sh * w + sw)
1072}
1073
1074#[derive(Debug)]
1077struct Pad3dBackward<T: Float> {
1078 input: Tensor<T>,
1079 input_shape: Vec<usize>,
1080 mode: PaddingMode,
1081 pad_left: usize,
1082 pad_top: usize,
1083 pad_front: usize,
1084}
1085
1086impl<T: Float> GradFn<T> for Pad3dBackward<T> {
1087 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1088 if !self.input.requires_grad() {
1089 return Ok(vec![None]);
1090 }
1091 let ndim = self.input_shape.len();
1092 let d = self.input_shape[ndim - 3];
1093 let h = self.input_shape[ndim - 2];
1094 let w = self.input_shape[ndim - 1];
1095 let outer: usize = self.input_shape[..ndim - 3]
1096 .iter()
1097 .copied()
1098 .product::<usize>()
1099 .max(1);
1100
1101 let go_shape = grad_output.shape();
1102 let new_d = go_shape[ndim - 3];
1103 let new_h = go_shape[ndim - 2];
1104 let new_w = go_shape[ndim - 1];
1105
1106 let go = grad_output.data_vec()?;
1109 let zero = <T as num_traits::Zero>::zero();
1110 let mut grad_in = vec![zero; outer * d * h * w];
1111
1112 for o in 0..outer {
1113 let go_base = o * new_d * new_h * new_w;
1114 let gi_base = o * d * h * w;
1115 for ndp in 0..new_d {
1116 for nhp in 0..new_h {
1117 for nwp in 0..new_w {
1118 if let Some(src) = src_index_3d(
1119 self.mode,
1120 ndp,
1121 nhp,
1122 nwp,
1123 d,
1124 h,
1125 w,
1126 self.pad_left,
1127 self.pad_top,
1128 self.pad_front,
1129 ) {
1130 grad_in[gi_base + src] +=
1131 go[go_base + ndp * new_h * new_w + nhp * new_w + nwp];
1132 }
1133 }
1134 }
1135 }
1136 }
1137
1138 let grad_input =
1139 Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
1140 Ok(vec![Some(grad_input)])
1141 }
1142
1143 fn inputs(&self) -> Vec<&Tensor<T>> {
1144 vec![&self.input]
1145 }
1146
1147 fn name(&self) -> &'static str {
1148 "Pad3dBackward"
1149 }
1150}
1151
1152#[allow(clippy::too_many_arguments)]
1163pub fn functional_pad_3d<T: Float>(
1164 input: &Tensor<T>,
1165 pad_left: usize,
1166 pad_right: usize,
1167 pad_top: usize,
1168 pad_bottom: usize,
1169 pad_front: usize,
1170 pad_back: usize,
1171 mode: PaddingMode,
1172 value: T,
1173) -> FerrotorchResult<Tensor<T>> {
1174 if mode == PaddingMode::Zeros {
1178 return functional_pad_3d_signed(
1179 input,
1180 pad_left as isize,
1181 pad_right as isize,
1182 pad_top as isize,
1183 pad_bottom as isize,
1184 pad_front as isize,
1185 pad_back as isize,
1186 mode,
1187 value,
1188 );
1189 }
1190
1191 let data = input.data_vec()?;
1192 let shape = input.shape();
1193 let input_shape = shape.to_vec();
1194 let (out_data, new_shape) = match mode {
1195 PaddingMode::Reflect => pad_3d_reflect(
1196 &data, shape, pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back,
1197 )?,
1198 PaddingMode::Replicate => pad_3d_replicate(
1199 &data, shape, pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back,
1200 ),
1201 PaddingMode::Circular => {
1202 let nd = shape.len();
1203 let (d, h, w) = (shape[nd - 3], shape[nd - 2], shape[nd - 1]);
1204 check_circular_positive(&[
1205 (w, pad_left),
1206 (w, pad_right),
1207 (h, pad_top),
1208 (h, pad_bottom),
1209 (d, pad_front),
1210 (d, pad_back),
1211 ])?;
1212 pad_3d_circular(
1213 &data, shape, pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back,
1214 )
1215 }
1216 PaddingMode::Zeros => {
1217 return functional_pad_3d_signed(
1218 input,
1219 pad_left as isize,
1220 pad_right as isize,
1221 pad_top as isize,
1222 pad_bottom as isize,
1223 pad_front as isize,
1224 pad_back as isize,
1225 mode,
1226 value,
1227 );
1228 }
1229 };
1230
1231 if is_grad_enabled() && input.requires_grad() {
1236 let grad_fn = Arc::new(Pad3dBackward {
1237 input: input.clone(),
1238 input_shape,
1239 mode,
1240 pad_left,
1241 pad_top,
1242 pad_front,
1243 });
1244 return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
1245 }
1246
1247 Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
1248}
1249
/// Maps a coordinate on a padded or cropped axis back to the source
/// coordinate.
///
/// `lo` is the signed leading pad: positive values shift the source right,
/// negative values crop from the front. Returns `None` when the position falls
/// outside the `size`-long source axis.
#[inline]
fn signed_axis_src(new_idx: usize, size: usize, lo: isize) -> Option<usize> {
    let shifted = new_idx as isize - lo;
    if shifted.is_negative() {
        return None;
    }
    Some(shifted as usize).filter(|&src| src < size)
}
1302
1303fn signed_axis_new_size(
1308 size: usize,
1309 lo: isize,
1310 hi: isize,
1311 axis_label: &str,
1312) -> FerrotorchResult<usize> {
1313 let after_left: isize = if lo < 0 {
1316 size as isize + lo
1317 } else {
1318 size as isize
1319 };
1320 if after_left < 0 {
1321 return Err(FerrotorchError::InvalidArgument {
1322 message: format!(
1323 "constant pad: negative padding {lo} on {axis_label} crops more than the dimension size {size} (narrow length would be negative)"
1324 ),
1325 });
1326 }
1327 let after_right: isize = if hi < 0 { after_left + hi } else { after_left };
1329 if after_right < 0 {
1330 return Err(FerrotorchError::InvalidArgument {
1331 message: format!(
1332 "constant pad: negative padding ({lo}, {hi}) on {axis_label} crops more than the dimension size {size}, resulting in a negative output size"
1333 ),
1334 });
1335 }
1336 Ok((after_right + lo.max(0) + hi.max(0)) as usize)
1338}
1339
1340fn pad_nd_signed_constant<T: Float>(
1346 data: &[T],
1347 shape: &[usize],
1348 pads: &[(isize, isize)],
1349 value: T,
1350) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
1351 let ndim = shape.len();
1352 let npad = pads.len();
1353 let mut new_shape = shape.to_vec();
1356 let mut new_sizes = vec![0usize; npad]; for (k, &(lo, hi)) in pads.iter().enumerate() {
1358 let dim = ndim - 1 - k;
1359 let new_size = signed_axis_new_size(shape[dim], lo, hi, &format!("dimension {dim}"))?;
1360 new_sizes[k] = new_size;
1361 new_shape[dim] = new_size;
1362 }
1363
1364 let first_padded = ndim - npad;
1366 let outer: usize = shape[..first_padded]
1367 .iter()
1368 .copied()
1369 .product::<usize>()
1370 .max(1);
1371
1372 let new_total: usize = new_shape.iter().copied().product();
1373 let mut out = vec![value; new_total];
1374
1375 if data.is_empty() {
1382 return Ok((out, new_shape));
1383 }
1384
1385 let in_inner: usize = shape[first_padded..].iter().product();
1389 let out_inner: usize = new_shape[first_padded..].iter().product();
1390
1391 for o in 0..outer {
1393 let in_base = o * in_inner;
1394 let out_base = o * out_inner;
1395 for flat in 0..out_inner {
1396 let mut rem = flat;
1398 let mut src_lin = 0usize;
1399 let mut src_stride = 1usize;
1400 let mut missing = false;
1401 for k in 0..npad {
1403 let dim = ndim - 1 - k;
1404 let axis_new = new_shape[dim];
1405 let coord = rem % axis_new;
1406 rem /= axis_new;
1407 let lo = pads[k].0;
1408 match signed_axis_src(coord, shape[dim], lo) {
1409 Some(s) => {
1410 src_lin += s * src_stride;
1411 src_stride *= shape[dim];
1412 }
1413 None => {
1414 missing = true;
1415 break;
1416 }
1417 }
1418 }
1419 if !missing {
1420 out[out_base + flat] = data[in_base + src_lin];
1421 }
1422 }
1424 }
1425
1426 Ok((out, new_shape))
1427}
1428
1429#[derive(Debug)]
1435struct PadNdSignedBackward<T: Float> {
1436 input: Tensor<T>,
1437 input_shape: Vec<usize>,
1438 pads: Vec<(isize, isize)>,
1440}
1441
1442impl<T: Float> GradFn<T> for PadNdSignedBackward<T> {
1443 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1444 if !self.input.requires_grad() {
1445 return Ok(vec![None]);
1446 }
1447 let ndim = self.input_shape.len();
1448 let npad = self.pads.len();
1449 let first_padded = ndim - npad;
1450 let outer: usize = self.input_shape[..first_padded]
1451 .iter()
1452 .copied()
1453 .product::<usize>()
1454 .max(1);
1455 let in_inner: usize = self.input_shape[first_padded..].iter().product();
1456
1457 let go_shape = grad_output.shape();
1458 let out_inner: usize = go_shape[first_padded..].iter().product();
1459
1460 let go = grad_output.data_vec()?;
1463 let zero = <T as num_traits::Zero>::zero();
1464 let mut grad_in = vec![zero; outer * in_inner];
1465
1466 for o in 0..outer {
1467 let in_base = o * in_inner;
1468 let out_base = o * out_inner;
1469 for flat in 0..out_inner {
1470 let mut rem = flat;
1471 let mut src_lin = 0usize;
1472 let mut src_stride = 1usize;
1473 let mut missing = false;
1474 for k in 0..npad {
1475 let dim = ndim - 1 - k;
1476 let axis_new = go_shape[dim];
1477 let coord = rem % axis_new;
1478 rem /= axis_new;
1479 let lo = self.pads[k].0;
1480 match signed_axis_src(coord, self.input_shape[dim], lo) {
1481 Some(s) => {
1482 src_lin += s * src_stride;
1483 src_stride *= self.input_shape[dim];
1484 }
1485 None => {
1486 missing = true;
1487 break;
1488 }
1489 }
1490 }
1491 if !missing {
1492 grad_in[in_base + src_lin] += go[out_base + flat];
1493 }
1494 }
1495 }
1496
1497 let grad_input =
1498 Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
1499 Ok(vec![Some(grad_input)])
1500 }
1501
1502 fn inputs(&self) -> Vec<&Tensor<T>> {
1503 vec![&self.input]
1504 }
1505
1506 fn name(&self) -> &'static str {
1507 "PadNdSignedBackward"
1508 }
1509}
1510
1511fn functional_pad_nd_positive<T: Float>(
1515 input: &Tensor<T>,
1516 pads: &[(isize, isize)],
1517 mode: PaddingMode,
1518 value: T,
1519) -> FerrotorchResult<Tensor<T>> {
1520 match pads.len() {
1521 1 => functional_pad_1d(input, pads[0].0 as usize, pads[0].1 as usize, mode, value),
1522 2 => functional_pad_2d(
1523 input,
1524 pads[0].0 as usize,
1525 pads[0].1 as usize,
1526 pads[1].0 as usize,
1527 pads[1].1 as usize,
1528 mode,
1529 value,
1530 ),
1531 3 => functional_pad_3d(
1532 input,
1533 pads[0].0 as usize,
1534 pads[0].1 as usize,
1535 pads[1].0 as usize,
1536 pads[1].1 as usize,
1537 pads[2].0 as usize,
1538 pads[2].1 as usize,
1539 mode,
1540 value,
1541 ),
1542 other => Err(FerrotorchError::InvalidArgument {
1543 message: format!("functional_pad_nd_signed supports 1-3 padded dims, got {other}"),
1544 }),
1545 }
1546}
1547
/// Maps output index `j` on one axis to its reflected source index.
///
/// `size` is the input extent and `pad` the (possibly negative) leading pad.
/// The mirror does not repeat the edge element: `[a b c d]` with `pad = 2`
/// reads `c b a b c d c b`. A negative `pad` crops `-pad` leading elements.
#[inline]
fn reflect_axis_src(j: usize, size: usize, pad: isize) -> usize {
    let pos = j as isize;
    let len = size as isize;
    // Where the visible input begins in the output, and where the visible
    // part begins inside the input (non-zero only when cropping).
    let out_start = pad.max(0);
    let in_start = (-pad).max(0);
    let mirrored = if pos < pad {
        2 * pad - pos
    } else if pos < len + pad {
        pos
    } else {
        2 * (len + pad - 1) - pos
    };
    (mirrored - out_start + in_start) as usize
}
1573
/// Maps output index `j` on one axis to its edge-replicated source index.
///
/// Positions before the input repeat the first element and positions after
/// it repeat the last. `pad` is the (possibly negative) leading pad; a
/// negative value crops `-pad` leading elements.
#[inline]
fn replicate_axis_src(j: usize, size: usize, pad: isize) -> usize {
    let pos = j as isize;
    let len = size as isize;
    // Input-start minus output-start for the visible window.
    let shift = (-pad).max(0) - pad.max(0);
    let edge_clamped = match pos {
        p if p < pad => pad,
        p if p < len + pad => p,
        _ => len + pad - 1,
    };
    (edge_clamped + shift) as usize
}
1603
/// Maps output index `j` on one axis to its circular (wrap-around) source
/// index for pads `(lo, hi)`, either of which may be negative.
///
/// Wrapping is done relative to the cropped centre, so `[a b c d]` with
/// `(-1, 1)` yields `b c d b`. The result is returned as `isize`; callers
/// cast it. In this chunk the forward circular path uses
/// `circular_slicecopy_block` instead, and this map is reached through
/// `signed_mode_axis_src`.
#[inline]
fn circular_axis_src(j: usize, size: usize, lo: isize, hi: isize) -> isize {
    let pos = j as isize;
    let out_w = size as isize + lo + hi;
    let left_pad = lo.max(0);
    let right_pad = hi.max(0);
    let leading_crop = lo.min(0).abs();
    let center = match pos {
        // Left wrap region reads from the right end of the centre.
        p if p < left_pad => out_w - lo - right_pad + p,
        // Right wrap region reads from the left end of the centre.
        p if p >= out_w - right_pad => left_pad + (p - (out_w - hi)),
        p => p,
    };
    leading_crop + (center - left_pad)
}
1637
1638fn circular_axis_legality(
1659 size: usize,
1660 lo: isize,
1661 hi: isize,
1662 dim: usize,
1663) -> FerrotorchResult<usize> {
1664 let size_i = size as isize;
1665 if lo > size_i || hi > size_i {
1667 return Err(FerrotorchError::InvalidArgument {
1668 message: format!(
1669 "Circular padding ({lo}, {hi}) causes wrapping around more than once on dimension {dim} (size {size})"
1670 ),
1671 });
1672 }
1673 let out_w = size_i + lo + hi;
1675 if out_w < 0 {
1676 return Err(FerrotorchError::InvalidArgument {
1677 message: format!(
1678 "Circular padding ({lo}, {hi}) on dimension {dim} of size {size} results in a negative output size {out_w} (empty dimension)"
1679 ),
1680 });
1681 }
1682 Ok(out_w as usize)
1683}
1684
/// Normalises a Python-style slice `[start:end]` over an axis of `length`.
///
/// Negative bounds count from the end. Both bounds are clamped into
/// `[0, length]`, and an inverted range collapses to the empty range at
/// `start`.
#[inline]
fn circular_slice_range(length: isize, start: isize, end: isize) -> (usize, usize) {
    let normalise = |idx: isize| {
        let idx = if idx < 0 { idx + length } else { idx };
        idx.clamp(0, length)
    };
    let lo = normalise(start);
    let hi = normalise(end).max(lo);
    (lo as usize, hi as usize)
}
1705
/// Forward circular padding of one contiguous block, done as a sequence of
/// slice copies (mirroring the torch slice-copy implementation that the
/// error messages below refer to).
///
/// `in_block` is one row-major block with extents `in_inner_shape`; the
/// result has extents `out_inner_shape`. `pads[k]` is the `(lo, hi)` pad for
/// inner axis `ninner - 1 - k` (innermost first). Pads may be negative.
///
/// Steps:
/// 1. copy the (possibly cropped) input into the centre window of the output;
/// 2. for each padded axis, innermost first, fill the left wrap region from
///    the right end of the centre, then the right wrap region from the left
///    end, reading from the partially built output itself.
///
/// A parallel `init` mask records which output elements were written from
/// data that traces back to the input.
///
/// # Errors
/// `InvalidArgument` when a slice copy is not broadcastable, when a
/// self-copy's source and destination windows overlap over a contiguous run,
/// or when any output element is never initialized.
fn circular_slicecopy_block<T: Float>(
    in_block: &[T],
    in_inner_shape: &[usize],
    out_inner_shape: &[usize],
    pads: &[(isize, isize)],
) -> FerrotorchResult<Vec<T>> {
    let npad = pads.len();
    let ninner = in_inner_shape.len();
    let out_total: usize = out_inner_shape.iter().product();
    let zero = <T as num_traits::Zero>::zero();
    let mut out = vec![zero; out_total];
    // `false` = not yet written by any copy.
    let mut init = vec![false; out_total];

    // Row-major strides for the input and output blocks.
    let mut in_strides = vec![1usize; ninner];
    let mut out_strides = vec![1usize; ninner];
    for d in (0..ninner.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * in_inner_shape[d + 1];
        out_strides[d] = out_strides[d + 1] * out_inner_shape[d + 1];
    }

    // Inner axis `d` is padded by `pads[ninner - 1 - d]` (pads list innermost first).
    let pad_for_inner_dim = |d: usize| -> (isize, isize) {
        pads[ninner - 1 - d]
    };

    /// Copies window `src_win` of `read_data` (or of `out` itself when
    /// `read_data` is `None`) into window `dst_win` of `out`, element by
    /// element in row-major order. A source axis of extent 1 is broadcast
    /// across the destination axis. `init` is updated from `read_init`, or
    /// treated as initialized when reading external data with no mask.
    #[allow(clippy::too_many_arguments)]
    fn copy_block<T: Float>(
        out: &mut [T],
        init: &mut [bool],
        read_data: Option<&[T]>,
        read_init: Option<&[bool]>,
        ninner: usize,
        out_strides: &[usize],
        read_strides: &[usize],
        dst_win: &[(usize, usize)],
        src_win: &[(usize, usize)],
    ) -> FerrotorchResult<()> {
        let mut dst_ext = vec![0usize; ninner];
        let mut src_ext = vec![0usize; ninner];
        for d in 0..ninner {
            dst_ext[d] = dst_win[d].1 - dst_win[d].0;
            src_ext[d] = src_win[d].1 - src_win[d].0;
            // Extents must match unless the source axis broadcasts (extent 1).
            if dst_ext[d] != src_ext[d] && src_ext[d] != 1 {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "Circular padding: a slice copy of source extent {} into destination extent {} is not broadcastable on inner dim {d} (torch raises a size-mismatch here)",
                        src_ext[d], dst_ext[d]
                    ),
                });
            }
        }
        let total: usize = dst_ext.iter().product();
        if total == 0 {
            return Ok(());
        }
        // Self-copy (reading from `out`): reject overlapping source and
        // destination windows on the first axis where they differ, when all
        // outer axes have extent 1 (so the copy is one contiguous run).
        if read_data.is_none() {
            let mut wrap_dim: Option<usize> = None;
            for d in 0..ninner {
                if dst_win[d] != src_win[d] {
                    wrap_dim = Some(d);
                    break;
                }
            }
            if let Some(wd) = wrap_dim {
                let runs_contiguous = (0..wd).all(|d| dst_ext[d] == 1);
                let ds = dst_win[wd];
                let ss = src_win[wd];
                // Half-open interval intersection test.
                let overlap = ds.0 < ss.1 && ss.0 < ds.1;
                let identical = ds == ss;
                if runs_contiguous && overlap && !identical {
                    return Err(FerrotorchError::InvalidArgument {
                        message:
                            "Circular padding: torch's wrap copy_ would read and write a single memory location over a contiguous slice (RuntimeError: some elements of the input and written-to tensor refer to a single memory location); ferrotorch rejects rather than fabricate (R-DEV-6)"
                                .to_string(),
                    });
                }
            }
        }
        // Odometer over the destination window, last axis fastest.
        let mut coord = vec![0usize; ninner];
        for _ in 0..total {
            let mut dst_off = 0usize;
            let mut src_off = 0usize;
            for d in 0..ninner {
                let dc = dst_win[d].0 + coord[d];
                dst_off += dc * out_strides[d];
                // Broadcast source axes always read their single element.
                let sc = if src_ext[d] == 1 {
                    src_win[d].0
                } else {
                    src_win[d].0 + coord[d]
                };
                src_off += sc * read_strides[d];
            }
            let (v, src_inited) = match (read_data, read_init) {
                (Some(rd), ri) => (rd[src_off], ri.map(|m| m[src_off]).unwrap_or(true)),
                // Self-copy: reads see writes made earlier in this same loop.
                (None, _) => (out[src_off], init[src_off]),
            };
            out[dst_off] = v;
            init[dst_off] = src_inited;
            let mut d = ninner;
            while d > 0 {
                d -= 1;
                coord[d] += 1;
                if coord[d] < dst_ext[d] {
                    break;
                }
                coord[d] = 0;
            }
        }
        Ok(())
    }

    // Step 1: centre window of the output <- cropped input window.
    let mut dst_win = vec![(0usize, 0usize); ninner];
    let mut src_win = vec![(0usize, 0usize); ninner];
    for d in 0..ninner {
        let out_len = out_inner_shape[d] as isize;
        let in_len = in_inner_shape[d] as isize;
        if d < ninner - npad {
            // Unpadded inner axis: copy it whole.
            dst_win[d] = (0, out_inner_shape[d]);
            src_win[d] = (0, in_inner_shape[d]);
        } else {
            // Positive pads shrink the destination window; negative pads
            // shrink (crop) the source window.
            let (pl, pr) = pad_for_inner_dim(d);
            dst_win[d] = circular_slice_range(out_len, pl.max(0), out_len - pr.max(0));
            src_win[d] = circular_slice_range(in_len, (-pl).max(0), in_len - (-pr).max(0));
        }
    }
    copy_block(
        &mut out,
        &mut init,
        Some(in_block),
        None,
        ninner,
        &out_strides,
        &in_strides,
        &dst_win,
        &src_win,
    )?;

    // Step 2: wrap regions, innermost axis first, each reading the full
    // extent of every other axis of the output built so far.
    for (k, &(pl, pr)) in pads.iter().enumerate() {
        let dim = ninner - 1 - k;
        let out_len = out_inner_shape[dim] as isize;
        if pl > 0 {
            // out[.., 0:pl] <- out[.., out_len - pl - max(pr,0) : out_len - max(pr,0)]
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, 0, pl);
            swin[dim] =
                circular_slice_range(out_len, out_len - pl - pr.max(0), out_len - pr.max(0));
            copy_block(
                &mut out,
                &mut init,
                None,
                None,
                ninner,
                &out_strides,
                &out_strides,
                &dwin,
                &swin,
            )?;
        }
        if pr > 0 {
            // out[.., out_len - pr : out_len] <- out[.., max(pl,0) : max(pl,0) + pr]
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, out_len - pr, out_len);
            swin[dim] = circular_slice_range(out_len, pl.max(0), pl.max(0) + pr);
            copy_block(
                &mut out,
                &mut init,
                None,
                None,
                ninner,
                &out_strides,
                &out_strides,
                &dwin,
                &swin,
            )?;
        }
    }

    // Any element never fed from the input has no well-defined value.
    if init.iter().any(|&b| !b) {
        return Err(FerrotorchError::InvalidArgument {
            message:
                "Circular padding crops the center below the wrap width, so torch reads uninitialized memory (no byte-for-byte contract; R-DEV-6)"
                    .to_string(),
        });
    }
    Ok(out)
}
1991
/// One slice copy from the circular-pad forward, recorded for the backward pass.
struct CircularCopyOp {
    /// `true` for the initial copy that reads the input block (source offsets
    /// index the input); `false` for a wrap copy that reads the output itself
    /// (source offsets index the output).
    from_input: bool,
    /// `(dst_offset, src_offset)` element pairs in copy order; destination
    /// offsets always index the output block.
    pairs: Vec<(usize, usize)>,
}
2002
/// Gradient of `circular_slicecopy_block` for one block.
///
/// Rebuilds the same list of slice copies as the forward (the centre copy
/// from the input, then the left/right wrap copies per padded axis,
/// innermost first) and replays their adjoints in reverse order. Shapes are
/// assumed already validated by the forward; no broadcast checks are done
/// here.
///
/// `go_block` is the output gradient with extents `out_inner_shape`; the
/// returned vector is the input gradient with extents `in_inner_shape`.
fn circular_slicecopy_backward_block<T: Float>(
    go_block: &[T],
    in_inner_shape: &[usize],
    out_inner_shape: &[usize],
    pads: &[(isize, isize)],
) -> Vec<T> {
    let npad = pads.len();
    let ninner = in_inner_shape.len();
    let in_total: usize = in_inner_shape.iter().product();

    // Row-major strides for the input and output blocks.
    let mut in_strides = vec![1usize; ninner];
    let mut out_strides = vec![1usize; ninner];
    for d in (0..ninner.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * in_inner_shape[d + 1];
        out_strides[d] = out_strides[d + 1] * out_inner_shape[d + 1];
    }

    let pad_for_inner_dim = |d: usize| -> (isize, isize) { pads[ninner - 1 - d] };

    // Enumerate `(dst_off, src_off)` for copying `src_win` into `dst_win`
    // (destination in output strides, source in `src_strides`), with the same
    // row-major order and extent-1 broadcasting as the forward copy.
    let enum_pairs = |dst_win: &[(usize, usize)],
                      src_win: &[(usize, usize)],
                      src_strides: &[usize]|
     -> Vec<(usize, usize)> {
        let mut dst_ext = vec![0usize; ninner];
        let mut src_ext = vec![0usize; ninner];
        for d in 0..ninner {
            dst_ext[d] = dst_win[d].1 - dst_win[d].0;
            src_ext[d] = src_win[d].1 - src_win[d].0;
        }
        let total: usize = dst_ext.iter().product();
        let mut pairs = Vec::with_capacity(total);
        if total == 0 {
            return pairs;
        }
        let mut coord = vec![0usize; ninner];
        for _ in 0..total {
            let mut dst_off = 0usize;
            let mut src_off = 0usize;
            for d in 0..ninner {
                dst_off += (dst_win[d].0 + coord[d]) * out_strides[d];
                let sc = if src_ext[d] == 1 {
                    src_win[d].0
                } else {
                    src_win[d].0 + coord[d]
                };
                src_off += sc * src_strides[d];
            }
            pairs.push((dst_off, src_off));
            let mut d = ninner;
            while d > 0 {
                d -= 1;
                coord[d] += 1;
                if coord[d] < dst_ext[d] {
                    break;
                }
                coord[d] = 0;
            }
        }
        pairs
    };

    let mut ops: Vec<CircularCopyOp> = Vec::new();

    // Centre copy: output window <- cropped input window (same windows as forward).
    let mut dst_win = vec![(0usize, 0usize); ninner];
    let mut src_win = vec![(0usize, 0usize); ninner];
    for d in 0..ninner {
        let out_len = out_inner_shape[d] as isize;
        let in_len = in_inner_shape[d] as isize;
        if d < ninner - npad {
            dst_win[d] = (0, out_inner_shape[d]);
            src_win[d] = (0, in_inner_shape[d]);
        } else {
            let (pl, pr) = pad_for_inner_dim(d);
            dst_win[d] = circular_slice_range(out_len, pl.max(0), out_len - pr.max(0));
            src_win[d] = circular_slice_range(in_len, (-pl).max(0), in_len - (-pr).max(0));
        }
    }
    ops.push(CircularCopyOp {
        from_input: true,
        pairs: enum_pairs(&dst_win, &src_win, &in_strides),
    });

    // Wrap copies within the output, in the forward's order.
    for (k, &(pl, pr)) in pads.iter().enumerate() {
        let dim = ninner - 1 - k;
        let out_len = out_inner_shape[dim] as isize;
        if pl > 0 {
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, 0, pl);
            swin[dim] =
                circular_slice_range(out_len, out_len - pl - pr.max(0), out_len - pr.max(0));
            ops.push(CircularCopyOp {
                from_input: false,
                pairs: enum_pairs(&dwin, &swin, &out_strides),
            });
        }
        if pr > 0 {
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, out_len - pr, out_len);
            swin[dim] = circular_slice_range(out_len, pl.max(0), pl.max(0) + pr);
            ops.push(CircularCopyOp {
                from_input: false,
                pairs: enum_pairs(&dwin, &swin, &out_strides),
            });
        }
    }

    // Adjoint of `out[d] = src[s]`: add grad[d] into src[s], then clear grad[d]
    // because the forward overwrote whatever was at `d` before the copy.
    let zero = <T as num_traits::Zero>::zero();
    let mut grad_out = go_block.to_vec();
    let mut grad_in = vec![zero; in_total];
    for op in ops.iter().rev() {
        if op.from_input {
            // Broadcast pairs share `s`, so their gradients accumulate.
            for &(d, s) in &op.pairs {
                grad_in[s] += grad_out[d];
                grad_out[d] = zero;
            }
        } else {
            // Source and destination both live in the output: read every
            // contribution before clearing any destination, so a position
            // that is both a source and a destination is handled as one
            // simultaneous copy.
            let mut contrib: Vec<(usize, T)> = Vec::with_capacity(op.pairs.len());
            for &(d, s) in &op.pairs {
                contrib.push((s, grad_out[d]));
            }
            for &(d, _) in &op.pairs {
                grad_out[d] = zero;
            }
            for (s, v) in contrib {
                grad_out[s] += v;
            }
        }
    }
    grad_in
}
2181
2182#[inline]
2189fn signed_mode_axis_src(mode: PaddingMode, j: usize, size: usize, lo: isize, hi: isize) -> usize {
2190 match mode {
2191 PaddingMode::Reflect => reflect_axis_src(j, size, lo),
2192 PaddingMode::Replicate => replicate_axis_src(j, size, lo),
2193 PaddingMode::Circular => circular_axis_src(j, size, lo, hi) as usize,
2194 PaddingMode::Zeros => (j as isize - lo).clamp(0, size as isize - 1) as usize,
2199 }
2200}
2201
2202fn pad_nd_signed_reflect_circular<T: Float>(
2215 data: &[T],
2216 shape: &[usize],
2217 pads: &[(isize, isize)],
2218 mode: PaddingMode,
2219) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
2220 let ndim = shape.len();
2221 let npad = pads.len();
2222 let mut new_shape = shape.to_vec();
2223 let per_axis_min: isize = isize::from(npad == 1);
2236 for (k, &(lo, hi)) in pads.iter().enumerate() {
2237 let dim = ndim - 1 - k;
2238 let size = shape[dim] as isize;
2239 if mode == PaddingMode::Reflect && (lo >= size || hi >= size) {
2247 return Err(FerrotorchError::InvalidArgument {
2248 message: format!(
2249 "Reflection padding ({lo}, {hi}) must be less than input size ({size}) on dimension {dim}"
2250 ),
2251 });
2252 }
2253 if mode == PaddingMode::Replicate && size == 0 {
2258 return Err(FerrotorchError::InvalidArgument {
2259 message: format!(
2260 "Replication padding cannot replicate an empty input dimension {dim} (size 0)"
2261 ),
2262 });
2263 }
2264 let new_size: usize = if mode == PaddingMode::Circular {
2275 circular_axis_legality(shape[dim], lo, hi, dim)?
2276 } else {
2277 let n = size + lo + hi;
2278 if n < per_axis_min {
2279 return Err(FerrotorchError::InvalidArgument {
2280 message: format!(
2281 "padding ({lo}, {hi}) on dimension {dim} of size {size} yields output size {n} below the minimum {per_axis_min} for this rank"
2282 ),
2283 });
2284 }
2285 n as usize
2286 };
2287 new_shape[dim] = new_size;
2288 }
2289
2290 if npad >= 2
2295 && matches!(mode, PaddingMode::Reflect | PaddingMode::Replicate)
2296 && pads
2297 .iter()
2298 .enumerate()
2299 .all(|(k, _)| new_shape[ndim - 1 - k] == 0)
2300 {
2301 return Err(FerrotorchError::InvalidArgument {
2302 message: format!(
2303 "{mode:?} padding collapses every padded spatial axis to size 0 (torch requires at least one >= 1)"
2304 ),
2305 });
2306 }
2307
2308 let first_padded = ndim - npad;
2309 let outer: usize = shape[..first_padded]
2310 .iter()
2311 .copied()
2312 .product::<usize>()
2313 .max(1);
2314 let in_inner: usize = shape[first_padded..].iter().product();
2315 let out_inner: usize = new_shape[first_padded..].iter().product();
2316 let zero = <T as num_traits::Zero>::zero();
2317 let new_total: usize = new_shape.iter().copied().product();
2318 let mut out = vec![zero; new_total];
2319
2320 if mode == PaddingMode::Circular {
2330 let in_inner_shape = &shape[first_padded..];
2331 let out_inner_shape = &new_shape[first_padded..];
2332 for o in 0..outer {
2333 let in_block = &data[o * in_inner..(o + 1) * in_inner];
2334 let out_block =
2335 circular_slicecopy_block(in_block, in_inner_shape, out_inner_shape, pads)?;
2336 out[o * out_inner..(o + 1) * out_inner].copy_from_slice(&out_block);
2337 }
2338 return Ok((out, new_shape));
2339 }
2340
2341 for o in 0..outer {
2345 let in_base = o * in_inner;
2346 let out_base = o * out_inner;
2347 for flat in 0..out_inner {
2348 let mut rem = flat;
2349 let mut src_lin = 0usize;
2350 let mut src_stride = 1usize;
2351 for k in 0..npad {
2352 let dim = ndim - 1 - k;
2353 let axis_new = new_shape[dim];
2354 let coord = rem % axis_new;
2355 rem /= axis_new;
2356 let (lo, hi) = pads[k];
2357 let s = signed_mode_axis_src(mode, coord, shape[dim], lo, hi);
2358 src_lin += s * src_stride;
2359 src_stride *= shape[dim];
2360 }
2361 out[out_base + flat] = data[in_base + src_lin];
2362 }
2363 }
2364
2365 Ok((out, new_shape))
2366}
2367
/// Backward node for signed N-D padding in `Reflect`, `Replicate` or
/// `Circular` mode.
#[derive(Debug)]
struct PadNdSignedModeBackward<T: Float> {
    /// The padded input; exposed to the graph through `inputs()`.
    input: Tensor<T>,
    /// Shape of `input` at forward time.
    input_shape: Vec<usize>,
    /// Forward padding mode; `Circular` takes the slice-copy backward path.
    mode: PaddingMode,
    /// Per-axis `(lo, hi)` pads; entry `k` applies to axis `ndim - 1 - k`.
    pads: Vec<(isize, isize)>,
}
2380
2381impl<T: Float> GradFn<T> for PadNdSignedModeBackward<T> {
2382 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2383 if !self.input.requires_grad() {
2384 return Ok(vec![None]);
2385 }
2386 let ndim = self.input_shape.len();
2387 let npad = self.pads.len();
2388 let first_padded = ndim - npad;
2389 let outer: usize = self.input_shape[..first_padded]
2390 .iter()
2391 .copied()
2392 .product::<usize>()
2393 .max(1);
2394 let in_inner: usize = self.input_shape[first_padded..].iter().product();
2395
2396 let go_shape = grad_output.shape();
2397 let out_inner: usize = go_shape[first_padded..].iter().product();
2398
2399 let go = grad_output.data_vec()?;
2400 let zero = <T as num_traits::Zero>::zero();
2401 let mut grad_in = vec![zero; outer * in_inner];
2402
2403 if self.mode == PaddingMode::Circular {
2404 let in_inner_shape = &self.input_shape[first_padded..];
2418 let out_inner_shape = &go_shape[first_padded..];
2419 for o in 0..outer {
2420 let in_base = o * in_inner;
2421 let out_base = o * out_inner;
2422 let go_block = &go[out_base..out_base + out_inner];
2423 let gi_block = circular_slicecopy_backward_block(
2424 go_block,
2425 in_inner_shape,
2426 out_inner_shape,
2427 &self.pads,
2428 );
2429 for (i, &v) in gi_block.iter().enumerate() {
2430 grad_in[in_base + i] += v;
2431 }
2432 }
2433 let grad_input =
2434 Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
2435 return Ok(vec![Some(grad_input)]);
2436 }
2437
2438 for o in 0..outer {
2439 let in_base = o * in_inner;
2440 let out_base = o * out_inner;
2441 for flat in 0..out_inner {
2442 let mut rem = flat;
2443 let mut src_lin = 0usize;
2444 let mut src_stride = 1usize;
2445 for k in 0..npad {
2446 let dim = ndim - 1 - k;
2447 let axis_new = go_shape[dim];
2448 let coord = rem % axis_new;
2449 rem /= axis_new;
2450 let (lo, hi) = self.pads[k];
2451 let s = signed_mode_axis_src(self.mode, coord, self.input_shape[dim], lo, hi);
2452 src_lin += s * src_stride;
2453 src_stride *= self.input_shape[dim];
2454 }
2455 grad_in[in_base + src_lin] += go[out_base + flat];
2456 }
2457 }
2458
2459 let grad_input =
2460 Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
2461 Ok(vec![Some(grad_input)])
2462 }
2463
2464 fn inputs(&self) -> Vec<&Tensor<T>> {
2465 vec![&self.input]
2466 }
2467
2468 fn name(&self) -> &'static str {
2469 "PadNdSignedModeBackward"
2470 }
2471}
2472
2473fn functional_pad_nd_signed<T: Float>(
2496 input: &Tensor<T>,
2497 pads: &[(isize, isize)],
2498 mode: PaddingMode,
2499 value: T,
2500) -> FerrotorchResult<Tensor<T>> {
2501 let has_negative = pads.iter().any(|&(lo, hi)| lo < 0 || hi < 0);
2502
2503 if mode != PaddingMode::Zeros {
2504 if !has_negative {
2505 return functional_pad_nd_positive(input, pads, mode, value);
2507 }
2508 let data = input.data_vec()?;
2527 let shape = input.shape();
2528 if pads.len() > shape.len() {
2529 return Err(FerrotorchError::InvalidArgument {
2530 message: format!(
2531 "pad targets {} dims but input has only {} dims",
2532 pads.len(),
2533 shape.len()
2534 ),
2535 });
2536 }
2537 let input_shape = shape.to_vec();
2538 let (out_data, new_shape) = pad_nd_signed_reflect_circular(&data, shape, pads, mode)?;
2539 if is_grad_enabled() && input.requires_grad() {
2540 let grad_fn = Arc::new(PadNdSignedModeBackward {
2541 input: input.clone(),
2542 input_shape,
2543 mode,
2544 pads: pads.to_vec(),
2545 });
2546 return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
2547 }
2548 return Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false);
2549 }
2550
2551 let data = input.data_vec()?;
2552 let shape = input.shape();
2553 if pads.len() > shape.len() {
2554 return Err(FerrotorchError::InvalidArgument {
2555 message: format!(
2556 "pad targets {} dims but input has only {} dims",
2557 pads.len(),
2558 shape.len()
2559 ),
2560 });
2561 }
2562 let input_shape = shape.to_vec();
2563 let (out_data, new_shape) = pad_nd_signed_constant(&data, shape, pads, value)?;
2564
2565 if is_grad_enabled() && input.requires_grad() {
2568 let grad_fn = Arc::new(PadNdSignedBackward {
2569 input: input.clone(),
2570 input_shape,
2571 pads: pads.to_vec(),
2572 });
2573 return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
2574 }
2575
2576 Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
2577}
2578
2579pub fn functional_pad_1d_signed<T: Float>(
2593 input: &Tensor<T>,
2594 pad_left: isize,
2595 pad_right: isize,
2596 mode: PaddingMode,
2597 value: T,
2598) -> FerrotorchResult<Tensor<T>> {
2599 functional_pad_nd_signed(input, &[(pad_left, pad_right)], mode, value)
2600}
2601
2602pub fn functional_pad_2d_signed<T: Float>(
2606 input: &Tensor<T>,
2607 pad_left: isize,
2608 pad_right: isize,
2609 pad_top: isize,
2610 pad_bottom: isize,
2611 mode: PaddingMode,
2612 value: T,
2613) -> FerrotorchResult<Tensor<T>> {
2614 functional_pad_nd_signed(
2616 input,
2617 &[(pad_left, pad_right), (pad_top, pad_bottom)],
2618 mode,
2619 value,
2620 )
2621}
2622
2623#[allow(clippy::too_many_arguments)]
2629pub fn functional_pad_3d_signed<T: Float>(
2630 input: &Tensor<T>,
2631 pad_left: isize,
2632 pad_right: isize,
2633 pad_top: isize,
2634 pad_bottom: isize,
2635 pad_front: isize,
2636 pad_back: isize,
2637 mode: PaddingMode,
2638 value: T,
2639) -> FerrotorchResult<Tensor<T>> {
2640 functional_pad_nd_signed(
2642 input,
2643 &[
2644 (pad_left, pad_right),
2645 (pad_top, pad_bottom),
2646 (pad_front, pad_back),
2647 ],
2648 mode,
2649 value,
2650 )
2651}
2652
// Implements `Module<T>` for a parameter-free padding module type that has a
// `training: bool` field and an inherent `fn pad(&self, &Tensor<T>)`.
macro_rules! impl_padding_module {
    ($pad_ty:ident) => {
        impl<T: Float> Module<T> for $pad_ty<T> {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                self.pad(input)
            }

            // Padding has no learnable state.
            fn parameters(&self) -> Vec<&Parameter<T>> {
                Vec::new()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                Vec::new()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                Vec::new()
            }

            fn train(&mut self) {
                self.training = true;
            }

            fn eval(&mut self) {
                self.training = false;
            }

            fn is_training(&self) -> bool {
                self.training
            }
        }
    };
}
2690
2691#[derive(Debug)]
2701pub struct ConstantPad1d<T: Float> {
2702 pub padding: (usize, usize),
2704 pub value: T,
2706 training: bool,
2707}
2708
2709impl<T: Float> ConstantPad1d<T> {
2710 pub fn new(padding: (usize, usize), value: T) -> Self {
2711 Self {
2712 padding,
2713 value,
2714 training: true,
2715 }
2716 }
2717
2718 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2719 let data = input.data_vec()?;
2720 let (out, new_shape) = pad_1d_constant(
2721 &data,
2722 input.shape(),
2723 self.padding.0,
2724 self.padding.1,
2725 self.value,
2726 );
2727 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2728 }
2729}
2730
2731impl_padding_module!(ConstantPad1d);
2732
2733#[derive(Debug)]
2739pub struct ConstantPad2d<T: Float> {
2740 pub padding: (usize, usize, usize, usize),
2742 pub value: T,
2744 training: bool,
2745}
2746
2747impl<T: Float> ConstantPad2d<T> {
2748 pub fn new(padding: (usize, usize, usize, usize), value: T) -> Self {
2749 Self {
2750 padding,
2751 value,
2752 training: true,
2753 }
2754 }
2755
2756 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2757 if input.ndim() < 2 {
2758 return Err(FerrotorchError::InvalidArgument {
2759 message: format!(
2760 "ConstantPad2d expects at least 2-D input, got {:?}",
2761 input.shape()
2762 ),
2763 });
2764 }
2765 let data = input.data_vec()?;
2766 let (out, new_shape) = pad_2d_constant(
2767 &data,
2768 input.shape(),
2769 self.padding.0,
2770 self.padding.1,
2771 self.padding.2,
2772 self.padding.3,
2773 self.value,
2774 );
2775 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2776 }
2777}
2778
2779impl_padding_module!(ConstantPad2d);
2780
2781#[derive(Debug)]
2787pub struct ConstantPad3d<T: Float> {
2788 pub padding: (usize, usize, usize, usize, usize, usize),
2790 pub value: T,
2792 training: bool,
2793}
2794
2795impl<T: Float> ConstantPad3d<T> {
2796 pub fn new(padding: (usize, usize, usize, usize, usize, usize), value: T) -> Self {
2797 Self {
2798 padding,
2799 value,
2800 training: true,
2801 }
2802 }
2803
2804 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2805 if input.ndim() < 3 {
2806 return Err(FerrotorchError::InvalidArgument {
2807 message: format!(
2808 "ConstantPad3d expects at least 3-D input, got {:?}",
2809 input.shape()
2810 ),
2811 });
2812 }
2813 let data = input.data_vec()?;
2814 let (out, new_shape) = pad_3d_constant(
2815 &data,
2816 input.shape(),
2817 self.padding.0,
2818 self.padding.1,
2819 self.padding.2,
2820 self.padding.3,
2821 self.padding.4,
2822 self.padding.5,
2823 self.value,
2824 );
2825 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2826 }
2827}
2828
2829impl_padding_module!(ConstantPad3d);
2830
2831#[derive(Debug)]
2837pub struct ZeroPad1d<T: Float> {
2838 pub padding: (usize, usize),
2839 training: bool,
2840 _phantom: std::marker::PhantomData<T>,
2841}
2842
2843impl<T: Float> ZeroPad1d<T> {
2844 pub fn new(padding: (usize, usize)) -> Self {
2845 Self {
2846 padding,
2847 training: true,
2848 _phantom: std::marker::PhantomData,
2849 }
2850 }
2851
2852 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2853 let data = input.data_vec()?;
2854 let zero = <T as num_traits::Zero>::zero();
2855 let (out, new_shape) =
2856 pad_1d_constant(&data, input.shape(), self.padding.0, self.padding.1, zero);
2857 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2858 }
2859}
2860
2861impl_padding_module!(ZeroPad1d);
2862
2863#[derive(Debug)]
2865pub struct ZeroPad2d<T: Float> {
2866 pub padding: (usize, usize, usize, usize),
2867 training: bool,
2868 _phantom: std::marker::PhantomData<T>,
2869}
2870
2871impl<T: Float> ZeroPad2d<T> {
2872 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2873 Self {
2874 padding,
2875 training: true,
2876 _phantom: std::marker::PhantomData,
2877 }
2878 }
2879
2880 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2881 if input.ndim() < 2 {
2882 return Err(FerrotorchError::InvalidArgument {
2883 message: format!(
2884 "ZeroPad2d expects at least 2-D input, got {:?}",
2885 input.shape()
2886 ),
2887 });
2888 }
2889 let data = input.data_vec()?;
2890 let zero = <T as num_traits::Zero>::zero();
2891 let (out, new_shape) = pad_2d_constant(
2892 &data,
2893 input.shape(),
2894 self.padding.0,
2895 self.padding.1,
2896 self.padding.2,
2897 self.padding.3,
2898 zero,
2899 );
2900 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2901 }
2902}
2903
2904impl_padding_module!(ZeroPad2d);
2905
2906#[derive(Debug)]
2908pub struct ZeroPad3d<T: Float> {
2909 pub padding: (usize, usize, usize, usize, usize, usize),
2910 training: bool,
2911 _phantom: std::marker::PhantomData<T>,
2912}
2913
2914impl<T: Float> ZeroPad3d<T> {
2915 pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
2916 Self {
2917 padding,
2918 training: true,
2919 _phantom: std::marker::PhantomData,
2920 }
2921 }
2922
2923 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2924 if input.ndim() < 3 {
2925 return Err(FerrotorchError::InvalidArgument {
2926 message: format!(
2927 "ZeroPad3d expects at least 3-D input, got {:?}",
2928 input.shape()
2929 ),
2930 });
2931 }
2932 let data = input.data_vec()?;
2933 let zero = <T as num_traits::Zero>::zero();
2934 let (out, new_shape) = pad_3d_constant(
2935 &data,
2936 input.shape(),
2937 self.padding.0,
2938 self.padding.1,
2939 self.padding.2,
2940 self.padding.3,
2941 self.padding.4,
2942 self.padding.5,
2943 zero,
2944 );
2945 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2946 }
2947}
2948
2949impl_padding_module!(ZeroPad3d);
2950
2951#[derive(Debug)]
2957pub struct ReflectionPad1d<T: Float> {
2958 pub padding: (usize, usize),
2959 training: bool,
2960 _phantom: std::marker::PhantomData<T>,
2961}
2962
2963impl<T: Float> ReflectionPad1d<T> {
2964 pub fn new(padding: (usize, usize)) -> Self {
2965 Self {
2966 padding,
2967 training: true,
2968 _phantom: std::marker::PhantomData,
2969 }
2970 }
2971
2972 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2973 let data = input.data_vec()?;
2974 let (out, new_shape) =
2975 pad_1d_reflect(&data, input.shape(), self.padding.0, self.padding.1)?;
2976 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2977 }
2978}
2979
2980impl_padding_module!(ReflectionPad1d);
2981
2982#[derive(Debug)]
2984pub struct ReflectionPad2d<T: Float> {
2985 pub padding: (usize, usize, usize, usize),
2986 training: bool,
2987 _phantom: std::marker::PhantomData<T>,
2988}
2989
2990impl<T: Float> ReflectionPad2d<T> {
2991 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2992 Self {
2993 padding,
2994 training: true,
2995 _phantom: std::marker::PhantomData,
2996 }
2997 }
2998
2999 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3000 if input.ndim() < 2 {
3001 return Err(FerrotorchError::InvalidArgument {
3002 message: format!(
3003 "ReflectionPad2d expects at least 2-D input, got {:?}",
3004 input.shape()
3005 ),
3006 });
3007 }
3008 let data = input.data_vec()?;
3009 let (out, new_shape) = pad_2d_reflect(
3010 &data,
3011 input.shape(),
3012 self.padding.0,
3013 self.padding.1,
3014 self.padding.2,
3015 self.padding.3,
3016 )?;
3017 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3018 }
3019}
3020
3021impl_padding_module!(ReflectionPad2d);
3022
3023#[derive(Debug)]
3025pub struct ReflectionPad3d<T: Float> {
3026 pub padding: (usize, usize, usize, usize, usize, usize),
3027 training: bool,
3028 _phantom: std::marker::PhantomData<T>,
3029}
3030
3031impl<T: Float> ReflectionPad3d<T> {
3032 pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
3033 Self {
3034 padding,
3035 training: true,
3036 _phantom: std::marker::PhantomData,
3037 }
3038 }
3039
3040 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3041 if input.ndim() < 3 {
3042 return Err(FerrotorchError::InvalidArgument {
3043 message: format!(
3044 "ReflectionPad3d expects at least 3-D input, got {:?}",
3045 input.shape()
3046 ),
3047 });
3048 }
3049 let data = input.data_vec()?;
3050 let (out, new_shape) = pad_3d_reflect(
3051 &data,
3052 input.shape(),
3053 self.padding.0,
3054 self.padding.1,
3055 self.padding.2,
3056 self.padding.3,
3057 self.padding.4,
3058 self.padding.5,
3059 )?;
3060 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3061 }
3062}
3063
3064impl_padding_module!(ReflectionPad3d);
3065
3066#[derive(Debug)]
3072pub struct ReplicationPad1d<T: Float> {
3073 pub padding: (usize, usize),
3074 training: bool,
3075 _phantom: std::marker::PhantomData<T>,
3076}
3077
3078impl<T: Float> ReplicationPad1d<T> {
3079 pub fn new(padding: (usize, usize)) -> Self {
3080 Self {
3081 padding,
3082 training: true,
3083 _phantom: std::marker::PhantomData,
3084 }
3085 }
3086
3087 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3088 let data = input.data_vec()?;
3089 let (out, new_shape) =
3090 pad_1d_replicate(&data, input.shape(), self.padding.0, self.padding.1);
3091 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3092 }
3093}
3094
3095impl_padding_module!(ReplicationPad1d);
3096
3097#[derive(Debug)]
3099pub struct ReplicationPad2d<T: Float> {
3100 pub padding: (usize, usize, usize, usize),
3101 training: bool,
3102 _phantom: std::marker::PhantomData<T>,
3103}
3104
3105impl<T: Float> ReplicationPad2d<T> {
3106 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
3107 Self {
3108 padding,
3109 training: true,
3110 _phantom: std::marker::PhantomData,
3111 }
3112 }
3113
3114 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3115 if input.ndim() < 2 {
3116 return Err(FerrotorchError::InvalidArgument {
3117 message: format!(
3118 "ReplicationPad2d expects at least 2-D input, got {:?}",
3119 input.shape()
3120 ),
3121 });
3122 }
3123 let data = input.data_vec()?;
3124 let (out, new_shape) = pad_2d_replicate(
3125 &data,
3126 input.shape(),
3127 self.padding.0,
3128 self.padding.1,
3129 self.padding.2,
3130 self.padding.3,
3131 );
3132 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3133 }
3134}
3135
3136impl_padding_module!(ReplicationPad2d);
3137
3138#[derive(Debug)]
3140pub struct ReplicationPad3d<T: Float> {
3141 pub padding: (usize, usize, usize, usize, usize, usize),
3142 training: bool,
3143 _phantom: std::marker::PhantomData<T>,
3144}
3145
3146impl<T: Float> ReplicationPad3d<T> {
3147 pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
3148 Self {
3149 padding,
3150 training: true,
3151 _phantom: std::marker::PhantomData,
3152 }
3153 }
3154
3155 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3156 if input.ndim() < 3 {
3157 return Err(FerrotorchError::InvalidArgument {
3158 message: format!(
3159 "ReplicationPad3d expects at least 3-D input, got {:?}",
3160 input.shape()
3161 ),
3162 });
3163 }
3164 let data = input.data_vec()?;
3165 let (out, new_shape) = pad_3d_replicate(
3166 &data,
3167 input.shape(),
3168 self.padding.0,
3169 self.padding.1,
3170 self.padding.2,
3171 self.padding.3,
3172 self.padding.4,
3173 self.padding.5,
3174 );
3175 Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3176 }
3177}
3178
3179impl_padding_module!(ReplicationPad3d);
3180
3181#[derive(Debug, Clone)]
3190pub struct CircularPad1d<T: Float> {
3191 pub padding: (usize, usize),
3192 training: bool,
3193 _phantom: std::marker::PhantomData<T>,
3194}
3195
3196impl<T: Float> CircularPad1d<T> {
3197 pub fn new(padding: (usize, usize)) -> Self {
3198 Self {
3199 padding,
3200 training: true,
3201 _phantom: std::marker::PhantomData,
3202 }
3203 }
3204
3205 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3206 if input.ndim() != 3 {
3207 return Err(FerrotorchError::InvalidArgument {
3208 message: format!(
3209 "CircularPad1d: expected 3-D input [N,C,W], got {:?}",
3210 input.shape()
3211 ),
3212 });
3213 }
3214 if input.is_cuda() {
3215 return Err(FerrotorchError::NotImplementedOnCuda {
3216 op: "CircularPad1d",
3217 });
3218 }
3219 let shape = input.shape();
3220 let (n, c, w) = (shape[0], shape[1], shape[2]);
3221 let (pl, pr) = self.padding;
3222 let new_w = w + pl + pr;
3223 let data = input.data()?;
3224 let zero = <T as num_traits::Zero>::zero();
3225 let mut out = vec![zero; n * c * new_w];
3226
3227 for batch in 0..n {
3228 for ch in 0..c {
3229 for ow in 0..new_w {
3230 let iw = ((ow as isize - pl as isize).rem_euclid(w as isize)) as usize;
3231 out[batch * c * new_w + ch * new_w + ow] = data[batch * c * w + ch * w + iw];
3232 }
3233 }
3234 }
3235
3236 Tensor::from_storage(TensorStorage::cpu(out), vec![n, c, new_w], false)
3237 }
3238}
3239
3240impl<T: Float> Default for CircularPad1d<T> {
3241 fn default() -> Self {
3242 Self::new((0, 0))
3243 }
3244}
3245
3246impl_padding_module!(CircularPad1d);
3247
3248#[derive(Debug, Clone)]
3251pub struct CircularPad2d<T: Float> {
3252 pub padding: (usize, usize, usize, usize),
3253 training: bool,
3254 _phantom: std::marker::PhantomData<T>,
3255}
3256
3257impl<T: Float> CircularPad2d<T> {
3258 pub fn new(padding: (usize, usize, usize, usize)) -> Self {
3259 Self {
3260 padding,
3261 training: true,
3262 _phantom: std::marker::PhantomData,
3263 }
3264 }
3265
3266 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3267 if input.ndim() != 4 {
3268 return Err(FerrotorchError::InvalidArgument {
3269 message: format!(
3270 "CircularPad2d: expected 4-D input [N,C,H,W], got {:?}",
3271 input.shape()
3272 ),
3273 });
3274 }
3275 if input.is_cuda() {
3276 return Err(FerrotorchError::NotImplementedOnCuda {
3277 op: "CircularPad2d",
3278 });
3279 }
3280 let shape = input.shape();
3281 let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
3282 let (pl, pr, pt, pb) = self.padding;
3283 let new_h = h + pt + pb;
3284 let new_w = w + pl + pr;
3285 let data = input.data()?;
3286 let zero = <T as num_traits::Zero>::zero();
3287 let mut out = vec![zero; n * c * new_h * new_w];
3288
3289 for batch in 0..n {
3290 for ch in 0..c {
3291 for oh in 0..new_h {
3292 let ih = ((oh as isize - pt as isize).rem_euclid(h as isize)) as usize;
3293 for ow in 0..new_w {
3294 let iw = ((ow as isize - pl as isize).rem_euclid(w as isize)) as usize;
3295 out[batch * c * new_h * new_w + ch * new_h * new_w + oh * new_w + ow] =
3296 data[batch * c * h * w + ch * h * w + ih * w + iw];
3297 }
3298 }
3299 }
3300 }
3301
3302 Tensor::from_storage(TensorStorage::cpu(out), vec![n, c, new_h, new_w], false)
3303 }
3304}
3305
3306impl<T: Float> Default for CircularPad2d<T> {
3307 fn default() -> Self {
3308 Self::new((0, 0, 0, 0))
3309 }
3310}
3311
3312impl_padding_module!(CircularPad2d);
3313
3314#[derive(Debug, Clone)]
3317pub struct CircularPad3d<T: Float> {
3318 pub padding: (usize, usize, usize, usize, usize, usize),
3319 training: bool,
3320 _phantom: std::marker::PhantomData<T>,
3321}
3322
3323impl<T: Float> CircularPad3d<T> {
3324 pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
3325 Self {
3326 padding,
3327 training: true,
3328 _phantom: std::marker::PhantomData,
3329 }
3330 }
3331
3332 fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3333 if input.ndim() != 5 {
3334 return Err(FerrotorchError::InvalidArgument {
3335 message: format!(
3336 "CircularPad3d: expected 5-D input [N,C,D,H,W], got {:?}",
3337 input.shape()
3338 ),
3339 });
3340 }
3341 if input.is_cuda() {
3342 return Err(FerrotorchError::NotImplementedOnCuda {
3343 op: "CircularPad3d",
3344 });
3345 }
3346 let shape = input.shape();
3347 let (n, c, d, h, w) = (shape[0], shape[1], shape[2], shape[3], shape[4]);
3348 let (pl, pr, pt, pb, pf, pk) = self.padding;
3349 let (new_d, new_h, new_w) = (d + pf + pk, h + pt + pb, w + pl + pr);
3350 let data = input.data()?;
3351 let zero = <T as num_traits::Zero>::zero();
3352 let mut out = vec![zero; n * c * new_d * new_h * new_w];
3353
3354 for batch in 0..n {
3355 for ch in 0..c {
3356 for od in 0..new_d {
3357 let id = ((od as isize - pf as isize).rem_euclid(d as isize)) as usize;
3358 for oh in 0..new_h {
3359 let ih = ((oh as isize - pt as isize).rem_euclid(h as isize)) as usize;
3360 for ow in 0..new_w {
3361 let iw = ((ow as isize - pl as isize).rem_euclid(w as isize)) as usize;
3362 out[batch * c * new_d * new_h * new_w
3363 + ch * new_d * new_h * new_w
3364 + od * new_h * new_w
3365 + oh * new_w
3366 + ow] = data
3367 [batch * c * d * h * w + ch * d * h * w + id * h * w + ih * w + iw];
3368 }
3369 }
3370 }
3371 }
3372 }
3373
3374 Tensor::from_storage(
3375 TensorStorage::cpu(out),
3376 vec![n, c, new_d, new_h, new_w],
3377 false,
3378 )
3379 }
3380}
3381
3382impl<T: Float> Default for CircularPad3d<T> {
3383 fn default() -> Self {
3384 Self::new((0, 0, 0, 0, 0, 0))
3385 }
3386}
3387
3388impl_padding_module!(CircularPad3d);
3389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::module::Module;

    /// Builds a CPU tensor that does not require grad.
    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    /// Asserts element-wise closeness with a clear failure message.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!((a - e).abs() < tol, "index {i}: actual={a} expected={e}");
        }
    }

    // ---------------------------------------------------------------------
    // Module-based padding
    // ---------------------------------------------------------------------

    #[test]
    fn test_constant_pad1d_basic() {
        let pad = ConstantPad1d::<f32>::new((2, 3), 9.0);
        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 8]);
        assert_close(
            output.data().unwrap(),
            &[9.0, 9.0, 1.0, 2.0, 3.0, 9.0, 9.0, 9.0],
            1e-7,
        );
    }

    #[test]
    fn test_zero_pad1d() {
        let pad = ZeroPad1d::<f32>::new((1, 2));
        let input = t(&[1.0, 2.0, 3.0], &[3]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[6]);
        assert_close(
            output.data().unwrap(),
            &[0.0, 1.0, 2.0, 3.0, 0.0, 0.0],
            1e-7,
        );
    }

    #[test]
    fn test_zero_pad2d() {
        let pad = ZeroPad2d::<f32>::new((1, 1, 1, 1));
        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 4, 4]);
        #[rustfmt::skip]
        let expected = [
            0.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 2.0, 0.0,
            0.0, 3.0, 4.0, 0.0,
            0.0, 0.0, 0.0, 0.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-7);
    }

    #[test]
    fn test_zero_pad3d_shape() {
        let pad = ZeroPad3d::<f32>::new((1, 1, 1, 1, 1, 1));
        let input = t(&[1.0; 2 * 2 * 2], &[1, 1, 2, 2, 2]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 4, 4, 4]);
    }

    #[test]
    fn test_reflection_pad1d() {
        let pad = ReflectionPad1d::<f32>::new((2, 2));
        let input = t(&[1.0, 2.0, 3.0, 4.0], &[4]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[8]);
        // The edge element is not repeated: reflection is about index 0 / 3.
        assert_close(
            output.data().unwrap(),
            &[3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0],
            1e-7,
        );
    }

    #[test]
    fn test_reflection_pad1d_too_large() {
        let pad = ReflectionPad1d::<f32>::new((4, 0));
        let input = t(&[1.0, 2.0, 3.0], &[3]);
        assert!(pad.forward(&input).is_err());
    }

    #[test]
    fn test_reflection_pad2d() {
        let pad = ReflectionPad2d::<f32>::new((1, 1, 1, 1));
        #[rustfmt::skip]
        let input = t(&[
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
        ], &[1, 1, 3, 3]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5]);
        // Rows are reflected as [r1, r0, r1, r2, r1]; each row likewise.
        #[rustfmt::skip]
        let expected = [
            5.0, 4.0, 5.0, 6.0, 5.0,
            2.0, 1.0, 2.0, 3.0, 2.0,
            5.0, 4.0, 5.0, 6.0, 5.0,
            8.0, 7.0, 8.0, 9.0, 8.0,
            5.0, 4.0, 5.0, 6.0, 5.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-7);
    }

    #[test]
    fn test_replication_pad1d() {
        let pad = ReplicationPad1d::<f32>::new((2, 3));
        let input = t(&[1.0, 2.0, 3.0], &[3]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[8]);
        assert_close(
            output.data().unwrap(),
            &[1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0],
            1e-7,
        );
    }

    #[test]
    fn test_replication_pad2d() {
        let pad = ReplicationPad2d::<f32>::new((1, 1, 1, 1));
        #[rustfmt::skip]
        let input = t(&[
            1.0, 2.0,
            3.0, 4.0,
        ], &[1, 1, 2, 2]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 4, 4]);
        #[rustfmt::skip]
        let expected = [
            1.0, 1.0, 2.0, 2.0,
            1.0, 1.0, 2.0, 2.0,
            3.0, 3.0, 4.0, 4.0,
            3.0, 3.0, 4.0, 4.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-7);
    }

    #[test]
    fn test_constant_pad2d() {
        let pad = ConstantPad2d::<f32>::new((1, 1, 1, 1), -1.0);
        let input = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[4, 4]);
        #[rustfmt::skip]
        let expected = [
            -1.0, -1.0, -1.0, -1.0,
            -1.0,  5.0,  6.0, -1.0,
            -1.0,  7.0,  8.0, -1.0,
            -1.0, -1.0, -1.0, -1.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-7);
    }

    #[test]
    fn test_constant_pad3d_shape() {
        let pad = ConstantPad3d::<f32>::new((1, 2, 1, 2, 1, 2), 0.0);
        let input = t(&[1.0; 3 * 4 * 5], &[1, 1, 3, 4, 5]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 6, 7, 8]);
    }

    #[test]
    fn test_circular_pad_1d() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let (out, new_shape) = pad_1d_circular(&data, &[4], 1, 2);
        assert_eq!(new_shape, &[7]);
        assert_close(&out, &[4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0], 1e-7);
    }

    #[test]
    fn test_circular_pad1d_module_wraps() {
        let pad = CircularPad1d::<f32>::new((1, 2));
        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 7]);
        assert_close(
            output.data().unwrap(),
            &[4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0],
            1e-7,
        );
    }

    #[test]
    fn test_circular_pad1d_module_rejects_non_3d() {
        let pad = CircularPad1d::<f32>::new((1, 1));
        let input = t(&[1.0, 2.0], &[2]);
        assert!(pad.forward(&input).is_err());
    }

    #[test]
    fn test_circular_pad2d_module_wraps() {
        let pad = CircularPad2d::<f32>::new((1, 1, 1, 1));
        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 4, 4]);
        #[rustfmt::skip]
        let expected = [
            4.0, 3.0, 4.0, 3.0,
            2.0, 1.0, 2.0, 1.0,
            4.0, 3.0, 4.0, 3.0,
            2.0, 1.0, 2.0, 1.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-7);
    }

    #[test]
    fn test_circular_pad3d_module_shape_and_corner() {
        let pad = CircularPad3d::<f32>::new((1, 1, 1, 1, 1, 1));
        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let input = t(&x_data, &[1, 1, 2, 2, 2]);
        let output = pad.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 4, 4, 4]);
        // Output (0, 0, 0) wraps to input (1, 1, 1), the last element.
        assert_close(&output.data().unwrap()[0..1], &[8.0], 1e-7);
    }

    #[test]
    fn test_padding_mode_eq() {
        assert_eq!(PaddingMode::Zeros, PaddingMode::Zeros);
        assert_ne!(PaddingMode::Zeros, PaddingMode::Reflect);
    }

    #[test]
    fn test_padding_module_no_params() {
        let pad = ZeroPad2d::<f32>::new((1, 1, 1, 1));
        assert!(pad.parameters().is_empty());
        assert!(pad.named_parameters().is_empty());
    }

    #[test]
    fn test_padding_module_train_eval() {
        let mut pad = ReflectionPad1d::<f32>::new((1, 1));
        assert!(pad.is_training());
        pad.eval();
        assert!(!pad.is_training());
        pad.train();
        assert!(pad.is_training());
    }

    // ---------------------------------------------------------------------
    // Constant helpers on zero-element inputs must not panic
    // ---------------------------------------------------------------------

    #[test]
    fn test_constant_pad1d_empty_numel_no_panic() {
        let (out, new_shape) = pad_1d_constant::<f32>(&[], &[0, 3], 2, 3, 7.0);
        assert_eq!(new_shape, vec![0, 8]);
        assert!(out.iter().all(|&v| v == 7.0));
    }

    #[test]
    fn test_constant_pad2d_empty_numel_no_panic() {
        let (out, new_shape) = pad_2d_constant::<f32>(&[], &[0, 2, 3], 1, 1, 1, 1, 5.0);
        assert_eq!(new_shape, vec![0, 4, 5]);
        assert!(out.iter().all(|&v| v == 5.0));
    }

    #[test]
    fn test_constant_pad3d_empty_numel_no_panic() {
        let (out, new_shape) = pad_3d_constant::<f32>(&[], &[0, 2, 2, 3], 1, 1, 1, 1, 1, 1, 3.0);
        assert_eq!(new_shape, vec![0, 4, 4, 5]);
        assert!(out.iter().all(|&v| v == 3.0));
    }

    // ---------------------------------------------------------------------
    // Functional padding: `PaddingMode::Zeros` fills with the given value
    // ---------------------------------------------------------------------

    #[test]
    fn test_functional_pad_1d_constant_uses_value() {
        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        let out = functional_pad_1d(&input, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
        assert_eq!(out.shape(), &[1, 1, 5]);
        assert_close(out.data().unwrap(), &[2.0, 1.0, 2.0, 3.0, 2.0], 1e-7);
    }

    #[test]
    fn test_functional_pad_2d_constant_uses_value() {
        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
        let out = functional_pad_2d(&input, 1, 1, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
        #[rustfmt::skip]
        let expected = [
            2.0, 2.0, 2.0, 2.0,
            2.0, 1.0, 2.0, 2.0,
            2.0, 3.0, 4.0, 2.0,
            2.0, 2.0, 2.0, 2.0,
        ];
        assert_close(out.data().unwrap(), &expected, 1e-7);
        assert!(out.data().unwrap().iter().all(|&v| v != 0.0));
    }

    #[test]
    fn test_functional_pad_3d_constant_uses_value() {
        let input = t(&[5.0], &[1, 1, 1, 1, 1]);
        let out = functional_pad_3d(&input, 1, 1, 0, 0, 0, 0, PaddingMode::Zeros, 2.0).unwrap();
        assert_eq!(out.shape(), &[1, 1, 1, 1, 3]);
        assert_close(out.data().unwrap(), &[2.0, 5.0, 2.0], 1e-7);
    }

    // ---------------------------------------------------------------------
    // Functional padding: autograd
    // ---------------------------------------------------------------------

    /// Builds a CPU leaf tensor that requires grad.
    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
    }

    #[test]
    fn test_functional_pad_1d_reflect_backward_matches_torch() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
        let y = functional_pad_1d(&x, 2, 2, PaddingMode::Reflect, 0.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 8]);
        assert!(
            y.grad_fn().is_some(),
            "functional_pad_1d Reflect lost grad_fn — would sever Conv1d autograd (#1550 class)"
        );
        assert_eq!(y.grad_fn().unwrap().name(), "Pad1dBackward");
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        // Output [3,2,1,2,3,4,3,2]: inner elements are used three times each.
        assert_close(g.data().unwrap(), &[1.0, 3.0, 3.0, 1.0], 1e-5);
    }

    #[test]
    fn test_functional_pad_3d_circular_backward_matches_torch() {
        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
        let y = functional_pad_3d(&x, 1, 1, 1, 1, 1, 1, PaddingMode::Circular, 0.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 4, 4, 4]);
        assert!(y.grad_fn().is_some());
        assert_eq!(y.grad_fn().unwrap().name(), "Pad3dBackward");
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        // 64 outputs spread evenly over 8 inputs.
        assert_close(g.data().unwrap(), &[8.0; 8], 1e-5);
    }

    // ---------------------------------------------------------------------
    // Signed padding: negative amounts crop
    // ---------------------------------------------------------------------

    #[test]
    fn test_functional_pad_1d_signed_crop_both_matches_torch() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
        let y = functional_pad_1d_signed(&x, -1, -1, PaddingMode::Zeros, 0.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 3]);
        assert_close(y.data().unwrap(), &[2.0, 3.0, 4.0], 1e-7);
        assert_eq!(y.grad_fn().unwrap().name(), "PadNdSignedBackward");
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        assert_close(g.data().unwrap(), &[0.0, 1.0, 1.0, 1.0, 0.0], 1e-7);
    }

    #[test]
    fn test_functional_pad_1d_signed_mixed_matches_torch() {
        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
        let y = functional_pad_1d_signed(&x, -1, 2, PaddingMode::Zeros, 9.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 5]);
        assert_close(y.data().unwrap(), &[2.0, 3.0, 4.0, 9.0, 9.0], 1e-7);
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        assert_close(g.data().unwrap(), &[0.0, 1.0, 1.0, 1.0], 1e-7);
    }

    #[test]
    fn test_functional_pad_2d_signed_crop_matches_torch() {
        #[rustfmt::skip]
        let x = leaf(&[
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
        ], &[1, 1, 3, 3]);
        let y = functional_pad_2d_signed(&x, -1, 0, 0, -1, PaddingMode::Zeros, 0.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 2, 2]);
        assert_close(y.data().unwrap(), &[2.0, 3.0, 5.0, 6.0], 1e-7);
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        assert_close(
            g.data().unwrap(),
            &[0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
            1e-7,
        );
    }

    #[test]
    fn test_functional_pad_2d_signed_mixed_matches_torch() {
        #[rustfmt::skip]
        let x = leaf(&[
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
        ], &[1, 1, 2, 3]);
        let y = functional_pad_2d_signed(&x, -1, 2, 1, -1, PaddingMode::Zeros, 7.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 2, 4]);
        #[rustfmt::skip]
        let expected = [
            7.0, 7.0, 7.0, 7.0,
            2.0, 3.0, 7.0, 7.0,
        ];
        assert_close(y.data().unwrap(), &expected, 1e-7);
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        assert_close(g.data().unwrap(), &[0.0, 1.0, 1.0, 0.0, 0.0, 0.0], 1e-7);
    }

    #[test]
    fn test_functional_pad_3d_signed_crop_matches_torch() {
        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
        let y = functional_pad_3d_signed(&x, -1, 0, 0, -1, -1, 0, PaddingMode::Zeros, 0.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1, 1, 1]);
        assert_close(y.data().unwrap(), &[6.0], 1e-7);
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        assert_close(
            g.data().unwrap(),
            &[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            1e-7,
        );
    }

    #[test]
    fn test_functional_pad_3d_signed_mixed_matches_torch() {
        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
        let y = functional_pad_3d_signed(&x, 1, -1, 0, 1, -1, 2, PaddingMode::Zeros, 3.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 3, 3, 2]);
        #[rustfmt::skip]
        let expected = [
            3.0, 5.0, 3.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0,
            3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,
        ];
        assert_close(y.data().unwrap(), &expected, 1e-7);
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let g = x.grad().unwrap().expect("grad must be populated");
        assert_close(
            g.data().unwrap(),
            &[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
            1e-7,
        );
    }

    #[test]
    fn test_functional_pad_1d_signed_over_crop_errors() {
        let x = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        assert!(
            functional_pad_1d_signed(&x, -4, 0, PaddingMode::Zeros, 0.0).is_err(),
            "single-side over-crop must error like torch narrow()"
        );
        assert!(
            functional_pad_1d_signed(&x, -2, -2, PaddingMode::Zeros, 0.0).is_err(),
            "combined net-negative crop must error like torch"
        );
        assert!(
            functional_pad_1d_signed(&x, -1, -3, PaddingMode::Zeros, 0.0).is_err(),
            "right-after-left over-crop must error like torch"
        );
    }

    #[test]
    fn test_functional_pad_1d_signed_net_zero_empty_dim_matches_torch() {
        let x = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        let y = functional_pad_1d_signed(&x, -1, -2, PaddingMode::Zeros, 0.0).unwrap();
        assert_eq!(y.shape(), &[1, 1, 0]);
        assert!(y.data().unwrap().is_empty());
    }

    #[test]
    fn test_functional_pad_signed_negative_non_constant_crops() {
        let x = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
        for mode in [
            PaddingMode::Reflect,
            PaddingMode::Replicate,
            PaddingMode::Circular,
        ] {
            let y = functional_pad_1d_signed(&x, -1, 0, mode, 0.0)
                .unwrap_or_else(|_| panic!("negative pad under {mode:?} must crop, not error"));
            assert_eq!(
                y.shape(),
                &[1, 1, 3],
                "{mode:?} crops left -> shape [1,1,3]"
            );
            assert_close(y.data().unwrap(), &[2.0, 3.0, 4.0], 1e-7);
        }
    }

    #[test]
    fn test_functional_pad_1d_signed_nonneg_equals_positive_path() {
        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        let signed = functional_pad_1d_signed(&input, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
        let positive = functional_pad_1d(&input, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
        assert_eq!(signed.shape(), positive.shape());
        assert_close(signed.data().unwrap(), positive.data().unwrap(), 1e-7);
        assert_close(signed.data().unwrap(), &[2.0, 1.0, 2.0, 3.0, 2.0], 1e-7);
    }
}