1use std::sync::Arc;
35
36use ferrotorch_core::autograd::no_grad::is_grad_enabled;
37use ferrotorch_core::tensor::GradFn;
38use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
39
40use crate::module::Module;
41use crate::parameter::Parameter;
42
/// Resampling kernel selected for [`interpolate`] and [`Upsample`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterpolateMode {
    /// Each output pixel copies a single source pixel.
    Nearest,
    /// Each output pixel blends a 2x2 neighbourhood with linear weights.
    Bilinear,
    /// Each output pixel blends a 4x4 neighbourhood with cubic-convolution weights.
    Bicubic,
}
57
/// How [`grid_sample`] treats sample positions that fall outside the input.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GridSamplePaddingMode {
    /// Out-of-range pixels contribute zero.
    Zeros,
    /// Out-of-range indices are clamped to the nearest edge pixel.
    Border,
    /// Out-of-range indices are mirrored back into the image.
    Reflection,
}
68
/// Interpolation kernel used by [`grid_sample`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GridSampleMode {
    /// Blend the four pixels surrounding the sample position.
    Bilinear,
    /// Read the single pixel closest to the sample position.
    Nearest,
}
77
78fn validate_4d<T: Float>(
84 input: &Tensor<T>,
85 fn_name: &str,
86) -> FerrotorchResult<(usize, usize, usize, usize)> {
87 let shape = input.shape();
88 if shape.len() != 4 {
89 return Err(FerrotorchError::InvalidArgument {
90 message: format!(
91 "{fn_name} expects 4D input [B, C, H, W], got shape {:?}",
92 shape
93 ),
94 });
95 }
96 Ok((shape[0], shape[1], shape[2], shape[3]))
97}
98
/// Cubic-convolution weight for a tap at signed distance `t` from the sample.
///
/// The kernel is symmetric, equals `1` at `t = 0`, and is zero for
/// `|t| >= 2`. The sharpness coefficient is fixed at `-0.75`.
#[inline]
fn cubic_weight(t: f64) -> f64 {
    const A: f64 = -0.75;

    match t.abs() {
        // Inner lobe: |t| in [0, 1].
        d if d <= 1.0 => (A + 2.0) * d * d * d - (A + 3.0) * d * d + 1.0,
        // Outer lobe: |t| in (1, 2).
        d if d < 2.0 => A * d * d * d - 5.0 * A * d * d + 8.0 * A * d - 4.0 * A,
        // Outside the kernel support (also reached for NaN).
        _ => 0.0,
    }
}
116
/// Maps output index `i` to a fractional source coordinate such that the
/// first and last output samples land on the first and last source samples.
///
/// Returns `0.0` when the output has at most one sample.
#[inline]
fn align_corners_coord(i: usize, in_size: usize, out_size: usize) -> f64 {
    match out_size {
        0 | 1 => 0.0,
        _ => {
            let src_span = (in_size - 1) as f64;
            let dst_span = (out_size - 1) as f64;
            (i as f64) * src_span / dst_span
        }
    }
}
128
/// Maps output index `i` to a fractional source coordinate by aligning pixel
/// centres: `(i + 0.5) * in_size / out_size - 0.5`.
///
/// The result can be negative or exceed `in_size - 1`; callers clamp indices.
#[inline]
fn half_pixel_coord(i: usize, in_size: usize, out_size: usize) -> f64 {
    let scale = in_size as f64 / out_size as f64;
    let dst_centre = i as f64 + 0.5;
    dst_centre * scale - 0.5
}
136
/// Clamps a signed index into the inclusive range `[0, max]`.
#[inline]
fn clamp_coord(val: isize, max: usize) -> usize {
    // Negative values collapse to 0 before the unsigned conversion.
    (val.max(0) as usize).min(max)
}
148
149pub fn interpolate<T: Float>(
158 input: &Tensor<T>,
159 size: Option<[usize; 2]>,
160 scale_factor: Option<[f64; 2]>,
161 mode: InterpolateMode,
162 align_corners: bool,
163) -> FerrotorchResult<Tensor<T>> {
164 let (batch, channels, h_in, w_in) = validate_4d(input, "interpolate")?;
165
166 let (h_out, w_out) = match (size, scale_factor) {
168 (Some(s), None) => (s[0], s[1]),
169 (None, Some(sf)) => {
170 let h = (h_in as f64 * sf[0]).round() as usize;
171 let w = (w_in as f64 * sf[1]).round() as usize;
172 if h == 0 || w == 0 {
173 return Err(FerrotorchError::InvalidArgument {
174 message: format!(
175 "interpolate: scale_factor {sf:?} with input ({h_in}, {w_in}) produces zero output"
176 ),
177 });
178 }
179 (h, w)
180 }
181 _ => {
182 return Err(FerrotorchError::InvalidArgument {
183 message: "interpolate: exactly one of size or scale_factor must be provided".into(),
184 });
185 }
186 };
187
188 if h_out == 0 || w_out == 0 {
189 return Err(FerrotorchError::InvalidArgument {
190 message: format!("interpolate: output size ({h_out}, {w_out}) must be > 0"),
191 });
192 }
193
194 if mode == InterpolateMode::Nearest && align_corners {
195 return Err(FerrotorchError::InvalidArgument {
196 message: "interpolate: align_corners is not supported with nearest mode".into(),
197 });
198 }
199
200 let input_device = input.device();
201 let data = input.data_vec()?;
202
203 let total = batch * channels * h_out * w_out;
204 let mut output = vec![T::from(0.0).unwrap(); total];
205
206 match mode {
207 InterpolateMode::Nearest => {
208 nearest_forward(
209 &data,
210 &mut output,
211 batch,
212 channels,
213 h_in,
214 w_in,
215 h_out,
216 w_out,
217 );
218 }
219 InterpolateMode::Bilinear => {
220 bilinear_forward(
221 &data,
222 &mut output,
223 batch,
224 channels,
225 h_in,
226 w_in,
227 h_out,
228 w_out,
229 align_corners,
230 );
231 }
232 InterpolateMode::Bicubic => {
233 bicubic_forward(
234 &data,
235 &mut output,
236 batch,
237 channels,
238 h_in,
239 w_in,
240 h_out,
241 w_out,
242 align_corners,
243 );
244 }
245 }
246
247 let out_shape = vec![batch, channels, h_out, w_out];
248 let storage = TensorStorage::cpu(output);
249
250 if is_grad_enabled() && input.requires_grad() {
251 Tensor::from_operation(
252 storage,
253 out_shape,
254 Arc::new(InterpolateBackward {
255 input: input.clone(),
256 h_out,
257 w_out,
258 mode,
259 align_corners,
260 }),
261 )?
262 .to(input_device)
263 } else {
264 Tensor::from_storage(storage, out_shape, false)?.to(input_device)
265 }
266}
267
268#[allow(clippy::too_many_arguments)]
276fn nearest_forward<T: Float>(
277 data: &[T],
278 output: &mut [T],
279 batch: usize,
280 channels: usize,
281 h_in: usize,
282 w_in: usize,
283 h_out: usize,
284 w_out: usize,
285) {
286 let h_scale = h_in as f64 / h_out as f64;
287 let w_scale = w_in as f64 / w_out as f64;
288
289 for b in 0..batch {
290 for c in 0..channels {
291 for oh in 0..h_out {
292 let ih = ((oh as f64 * h_scale).floor() as usize).min(h_in - 1);
293 for ow in 0..w_out {
294 let iw = ((ow as f64 * w_scale).floor() as usize).min(w_in - 1);
295 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
296 let in_idx = ((b * channels + c) * h_in + ih) * w_in + iw;
297 output[out_idx] = data[in_idx];
298 }
299 }
300 }
301 }
302}
303
304#[allow(clippy::too_many_arguments)]
306fn bilinear_forward<T: Float>(
307 data: &[T],
308 output: &mut [T],
309 batch: usize,
310 channels: usize,
311 h_in: usize,
312 w_in: usize,
313 h_out: usize,
314 w_out: usize,
315 align_corners: bool,
316) {
317 let one = T::from(1.0).unwrap();
318
319 for b in 0..batch {
320 for c in 0..channels {
321 for oh in 0..h_out {
322 let src_h = if align_corners {
323 align_corners_coord(oh, h_in, h_out)
324 } else {
325 half_pixel_coord(oh, h_in, h_out)
326 };
327
328 let h0 = src_h.floor() as isize;
329 let h1 = h0 + 1;
330 let th = T::from(src_h - h0 as f64).unwrap();
331
332 for ow in 0..w_out {
333 let src_w = if align_corners {
334 align_corners_coord(ow, w_in, w_out)
335 } else {
336 half_pixel_coord(ow, w_in, w_out)
337 };
338
339 let w0 = src_w.floor() as isize;
340 let w1 = w0 + 1;
341 let tw = T::from(src_w - w0 as f64).unwrap();
342
343 let ch0 = clamp_coord(h0, h_in - 1);
344 let ch1 = clamp_coord(h1, h_in - 1);
345 let cw0 = clamp_coord(w0, w_in - 1);
346 let cw1 = clamp_coord(w1, w_in - 1);
347
348 let base = (b * channels + c) * h_in;
349 let v00 = data[(base + ch0) * w_in + cw0];
350 let v01 = data[(base + ch0) * w_in + cw1];
351 let v10 = data[(base + ch1) * w_in + cw0];
352 let v11 = data[(base + ch1) * w_in + cw1];
353
354 let val = v00 * (one - th) * (one - tw)
355 + v01 * (one - th) * tw
356 + v10 * th * (one - tw)
357 + v11 * th * tw;
358
359 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
360 output[out_idx] = val;
361 }
362 }
363 }
364 }
365}
366
367#[allow(clippy::too_many_arguments)]
369fn bicubic_forward<T: Float>(
370 data: &[T],
371 output: &mut [T],
372 batch: usize,
373 channels: usize,
374 h_in: usize,
375 w_in: usize,
376 h_out: usize,
377 w_out: usize,
378 align_corners: bool,
379) {
380 for b in 0..batch {
381 for c in 0..channels {
382 for oh in 0..h_out {
383 let src_h = if align_corners {
384 align_corners_coord(oh, h_in, h_out)
385 } else {
386 half_pixel_coord(oh, h_in, h_out)
387 };
388
389 let h_floor = src_h.floor() as isize;
390 let frac_h = src_h - h_floor as f64;
391
392 let wh: [T; 4] = [
394 T::from(cubic_weight(frac_h + 1.0)).unwrap(),
395 T::from(cubic_weight(frac_h)).unwrap(),
396 T::from(cubic_weight(frac_h - 1.0)).unwrap(),
397 T::from(cubic_weight(frac_h - 2.0)).unwrap(),
398 ];
399
400 for ow in 0..w_out {
401 let src_w = if align_corners {
402 align_corners_coord(ow, w_in, w_out)
403 } else {
404 half_pixel_coord(ow, w_in, w_out)
405 };
406
407 let w_floor = src_w.floor() as isize;
408 let frac_w = src_w - w_floor as f64;
409
410 let ww: [T; 4] = [
411 T::from(cubic_weight(frac_w + 1.0)).unwrap(),
412 T::from(cubic_weight(frac_w)).unwrap(),
413 T::from(cubic_weight(frac_w - 1.0)).unwrap(),
414 T::from(cubic_weight(frac_w - 2.0)).unwrap(),
415 ];
416
417 let mut val = T::from(0.0).unwrap();
418 let base = (b * channels + c) * h_in;
419
420 for dy in 0..4isize {
421 let iy = clamp_coord(h_floor - 1 + dy, h_in - 1);
422 for dx in 0..4isize {
423 let ix = clamp_coord(w_floor - 1 + dx, w_in - 1);
424 let pixel = data[(base + iy) * w_in + ix];
425 val += pixel * wh[dy as usize] * ww[dx as usize];
426 }
427 }
428
429 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
430 output[out_idx] = val;
431 }
432 }
433 }
434 }
435}
436
437#[derive(Debug)]
442struct InterpolateBackward<T: Float> {
443 input: Tensor<T>,
444 h_out: usize,
445 w_out: usize,
446 mode: InterpolateMode,
447 align_corners: bool,
448}
449
450impl<T: Float> GradFn<T> for InterpolateBackward<T> {
451 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
452 if !self.input.requires_grad() {
453 return Ok(vec![None]);
454 }
455
456 let in_shape = self.input.shape();
457 let (batch, channels, h_in, w_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
458 let h_out = self.h_out;
459 let w_out = self.w_out;
460
461 let go_data = grad_output.data_vec()?;
462 let mut grad_input = vec![T::from(0.0).unwrap(); batch * channels * h_in * w_in];
463
464 match self.mode {
465 InterpolateMode::Nearest => {
466 nearest_backward(
467 &go_data,
468 &mut grad_input,
469 batch,
470 channels,
471 h_in,
472 w_in,
473 h_out,
474 w_out,
475 );
476 }
477 InterpolateMode::Bilinear => {
478 bilinear_backward(
479 &go_data,
480 &mut grad_input,
481 batch,
482 channels,
483 h_in,
484 w_in,
485 h_out,
486 w_out,
487 self.align_corners,
488 );
489 }
490 InterpolateMode::Bicubic => {
491 bicubic_backward(
492 &go_data,
493 &mut grad_input,
494 batch,
495 channels,
496 h_in,
497 w_in,
498 h_out,
499 w_out,
500 self.align_corners,
501 );
502 }
503 }
504
505 let grad_tensor = Tensor::from_storage(
506 TensorStorage::cpu(grad_input),
507 self.input.shape().to_vec(),
508 false,
509 )?;
510 Ok(vec![Some(grad_tensor)])
511 }
512
513 fn inputs(&self) -> Vec<&Tensor<T>> {
514 vec![&self.input]
515 }
516
517 fn name(&self) -> &'static str {
518 "InterpolateBackward"
519 }
520}
521
522#[allow(clippy::too_many_arguments)]
524fn nearest_backward<T: Float>(
525 go: &[T],
526 grad_input: &mut [T],
527 batch: usize,
528 channels: usize,
529 h_in: usize,
530 w_in: usize,
531 h_out: usize,
532 w_out: usize,
533) {
534 let h_scale = h_in as f64 / h_out as f64;
535 let w_scale = w_in as f64 / w_out as f64;
536
537 for b in 0..batch {
538 for c in 0..channels {
539 for oh in 0..h_out {
540 let ih = ((oh as f64 * h_scale).floor() as usize).min(h_in - 1);
541 for ow in 0..w_out {
542 let iw = ((ow as f64 * w_scale).floor() as usize).min(w_in - 1);
543 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
544 let in_idx = ((b * channels + c) * h_in + ih) * w_in + iw;
545 grad_input[in_idx] += go[out_idx];
546 }
547 }
548 }
549 }
550}
551
552#[allow(clippy::too_many_arguments)]
554fn bilinear_backward<T: Float>(
555 go: &[T],
556 grad_input: &mut [T],
557 batch: usize,
558 channels: usize,
559 h_in: usize,
560 w_in: usize,
561 h_out: usize,
562 w_out: usize,
563 align_corners: bool,
564) {
565 let one = T::from(1.0).unwrap();
566
567 for b in 0..batch {
568 for c in 0..channels {
569 for oh in 0..h_out {
570 let src_h = if align_corners {
571 align_corners_coord(oh, h_in, h_out)
572 } else {
573 half_pixel_coord(oh, h_in, h_out)
574 };
575
576 let h0 = src_h.floor() as isize;
577 let h1 = h0 + 1;
578 let th = T::from(src_h - h0 as f64).unwrap();
579
580 for ow in 0..w_out {
581 let src_w = if align_corners {
582 align_corners_coord(ow, w_in, w_out)
583 } else {
584 half_pixel_coord(ow, w_in, w_out)
585 };
586
587 let w0 = src_w.floor() as isize;
588 let w1 = w0 + 1;
589 let tw = T::from(src_w - w0 as f64).unwrap();
590
591 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
592 let g = go[out_idx];
593
594 let ch0 = clamp_coord(h0, h_in - 1);
595 let ch1 = clamp_coord(h1, h_in - 1);
596 let cw0 = clamp_coord(w0, w_in - 1);
597 let cw1 = clamp_coord(w1, w_in - 1);
598
599 let base = (b * channels + c) * h_in;
600
601 grad_input[(base + ch0) * w_in + cw0] += g * (one - th) * (one - tw);
602 grad_input[(base + ch0) * w_in + cw1] += g * (one - th) * tw;
603 grad_input[(base + ch1) * w_in + cw0] += g * th * (one - tw);
604 grad_input[(base + ch1) * w_in + cw1] += g * th * tw;
605 }
606 }
607 }
608 }
609}
610
611#[allow(clippy::too_many_arguments)]
613fn bicubic_backward<T: Float>(
614 go: &[T],
615 grad_input: &mut [T],
616 batch: usize,
617 channels: usize,
618 h_in: usize,
619 w_in: usize,
620 h_out: usize,
621 w_out: usize,
622 align_corners: bool,
623) {
624 for b in 0..batch {
625 for c in 0..channels {
626 for oh in 0..h_out {
627 let src_h: f64 = if align_corners {
628 align_corners_coord(oh, h_in, h_out)
629 } else {
630 half_pixel_coord(oh, h_in, h_out)
631 };
632
633 let h_floor = src_h.floor() as isize;
634 let frac_h = src_h - h_floor as f64;
635
636 let wh: [T; 4] = [
637 T::from(cubic_weight(frac_h + 1.0)).unwrap(),
638 T::from(cubic_weight(frac_h)).unwrap(),
639 T::from(cubic_weight(frac_h - 1.0)).unwrap(),
640 T::from(cubic_weight(frac_h - 2.0)).unwrap(),
641 ];
642
643 for ow in 0..w_out {
644 let src_w: f64 = if align_corners {
645 align_corners_coord(ow, w_in, w_out)
646 } else {
647 half_pixel_coord(ow, w_in, w_out)
648 };
649
650 let w_floor = src_w.floor() as isize;
651 let frac_w = src_w - w_floor as f64;
652
653 let ww: [T; 4] = [
654 T::from(cubic_weight(frac_w + 1.0)).unwrap(),
655 T::from(cubic_weight(frac_w)).unwrap(),
656 T::from(cubic_weight(frac_w - 1.0)).unwrap(),
657 T::from(cubic_weight(frac_w - 2.0)).unwrap(),
658 ];
659
660 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
661 let g = go[out_idx];
662 let base = (b * channels + c) * h_in;
663
664 for dy in 0..4isize {
665 let iy = clamp_coord(h_floor - 1 + dy, h_in - 1);
666 for dx in 0..4isize {
667 let ix = clamp_coord(w_floor - 1 + dx, w_in - 1);
668 grad_input[(base + iy) * w_in + ix] +=
669 g * wh[dy as usize] * ww[dx as usize];
670 }
671 }
672 }
673 }
674 }
675 }
676}
677
678#[derive(Debug, Clone)]
689pub struct Upsample {
690 pub size: Option<[usize; 2]>,
692 pub scale_factor: Option<[f64; 2]>,
694 pub mode: InterpolateMode,
696 pub align_corners: bool,
698}
699
700impl Upsample {
701 pub fn new(size: [usize; 2], mode: InterpolateMode) -> Self {
703 Self {
704 size: Some(size),
705 scale_factor: None,
706 mode,
707 align_corners: false,
708 }
709 }
710
711 pub fn with_scale_factor(scale_factor: [f64; 2], mode: InterpolateMode) -> Self {
713 Self {
714 size: None,
715 scale_factor: Some(scale_factor),
716 mode,
717 align_corners: false,
718 }
719 }
720
721 pub fn align_corners(mut self, align_corners: bool) -> Self {
723 self.align_corners = align_corners;
724 self
725 }
726}
727
728impl<T: Float> Module<T> for Upsample {
729 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
730 interpolate(
731 input,
732 self.size,
733 self.scale_factor,
734 self.mode,
735 self.align_corners,
736 )
737 }
738
739 fn parameters(&self) -> Vec<&Parameter<T>> {
740 vec![]
741 }
742
743 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
744 vec![]
745 }
746
747 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
748 vec![]
749 }
750
751 fn train(&mut self) {}
752 fn eval(&mut self) {}
753
754 fn is_training(&self) -> bool {
755 false
756 }
757}
758
759pub fn grid_sample<T: Float>(
775 input: &Tensor<T>,
776 grid: &Tensor<T>,
777 mode: GridSampleMode,
778 padding_mode: GridSamplePaddingMode,
779 align_corners: bool,
780) -> FerrotorchResult<Tensor<T>> {
781 let (batch, channels, h_in, w_in) = validate_4d(input, "grid_sample")?;
782
783 let grid_shape = grid.shape();
784 if grid_shape.len() != 4 || grid_shape[3] != 2 {
785 return Err(FerrotorchError::InvalidArgument {
786 message: format!(
787 "grid_sample: grid must be [B, H_out, W_out, 2], got {:?}",
788 grid_shape
789 ),
790 });
791 }
792 if grid_shape[0] != batch {
793 return Err(FerrotorchError::ShapeMismatch {
794 message: format!(
795 "grid_sample: batch mismatch between input ({batch}) and grid ({})",
796 grid_shape[0]
797 ),
798 });
799 }
800 let h_out = grid_shape[1];
801 let w_out = grid_shape[2];
802
803 let input_device = input.device();
804 let in_data = input.data_vec()?;
805 let grid_data = grid.data_vec()?;
806
807 let total = batch * channels * h_out * w_out;
808 let mut output = vec![T::from(0.0).unwrap(); total];
809
810 let one = T::from(1.0).unwrap();
811 let two = T::from(2.0).unwrap();
812 let zero = T::from(0.0).unwrap();
813
814 for b in 0..batch {
815 for oh in 0..h_out {
816 for ow in 0..w_out {
817 let grid_base = ((b * h_out + oh) * w_out + ow) * 2;
818 let gx = grid_data[grid_base]; let gy = grid_data[grid_base + 1]; let (src_x, src_y) = if align_corners {
823 let sx = (gx + one) * T::from(w_in - 1).unwrap() / two;
824 let sy = (gy + one) * T::from(h_in - 1).unwrap() / two;
825 (sx, sy)
826 } else {
827 let sx = ((gx + one) * T::from(w_in).unwrap() - one) / two;
828 let sy = ((gy + one) * T::from(h_in).unwrap() - one) / two;
829 (sx, sy)
830 };
831
832 for c in 0..channels {
833 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
834 let in_base = (b * channels + c) * h_in;
835
836 match mode {
837 GridSampleMode::Nearest => {
838 let ix = src_x.to_f64().unwrap().round() as isize;
839 let iy = src_y.to_f64().unwrap().round() as isize;
840 let (ix, iy) = apply_padding_mode(ix, iy, w_in, h_in, padding_mode);
841
842 if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
843 output[out_idx] =
844 in_data[(in_base + iy as usize) * w_in + ix as usize];
845 }
846 }
848 GridSampleMode::Bilinear => {
849 let sx = src_x.to_f64().unwrap();
850 let sy = src_y.to_f64().unwrap();
851 let x0 = sx.floor() as isize;
852 let y0 = sy.floor() as isize;
853 let x1 = x0 + 1;
854 let y1 = y0 + 1;
855 let tx = T::from(sx - x0 as f64).unwrap();
856 let ty = T::from(sy - y0 as f64).unwrap();
857
858 let get_pixel = |iy: isize, ix: isize| -> T {
859 let (ix, iy) = apply_padding_mode(ix, iy, w_in, h_in, padding_mode);
860 if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
861 in_data[(in_base + iy as usize) * w_in + ix as usize]
862 } else {
863 zero
864 }
865 };
866
867 let v00 = get_pixel(y0, x0);
868 let v01 = get_pixel(y0, x1);
869 let v10 = get_pixel(y1, x0);
870 let v11 = get_pixel(y1, x1);
871
872 output[out_idx] = v00 * (one - ty) * (one - tx)
873 + v01 * (one - ty) * tx
874 + v10 * ty * (one - tx)
875 + v11 * ty * tx;
876 }
877 }
878 }
879 }
880 }
881 }
882
883 let out_shape = vec![batch, channels, h_out, w_out];
884 let storage = TensorStorage::cpu(output);
885
886 if is_grad_enabled() && (input.requires_grad() || grid.requires_grad()) {
887 Tensor::from_operation(
888 storage,
889 out_shape,
890 Arc::new(GridSampleBackward {
891 input: input.clone(),
892 grid: grid.clone(),
893 mode,
894 padding_mode,
895 align_corners,
896 }),
897 )?
898 .to(input_device)
899 } else {
900 Tensor::from_storage(storage, out_shape, false)?.to(input_device)
901 }
902}
903
904fn apply_padding_mode(
906 ix: isize,
907 iy: isize,
908 w: usize,
909 h: usize,
910 padding_mode: GridSamplePaddingMode,
911) -> (isize, isize) {
912 match padding_mode {
913 GridSamplePaddingMode::Zeros => (ix, iy),
914 GridSamplePaddingMode::Border => {
915 let cx = ix.max(0).min(w as isize - 1);
916 let cy = iy.max(0).min(h as isize - 1);
917 (cx, cy)
918 }
919 GridSamplePaddingMode::Reflection => {
920 let reflect = |v: isize, size: usize| -> isize {
921 if size <= 1 {
922 return 0;
923 }
924 let max = size as isize - 1;
925 let mut v = v;
926 if v < 0 {
927 v = -v;
928 }
929 let period = 2 * max;
931 v %= period;
932 if v > max {
933 v = period - v;
934 }
935 v
936 };
937 (reflect(ix, w), reflect(iy, h))
938 }
939 }
940}
941
942#[derive(Debug)]
947struct GridSampleBackward<T: Float> {
948 input: Tensor<T>,
949 grid: Tensor<T>,
950 mode: GridSampleMode,
951 padding_mode: GridSamplePaddingMode,
952 align_corners: bool,
953}
954
955impl<T: Float> GradFn<T> for GridSampleBackward<T> {
956 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
957 let in_shape = self.input.shape();
958 let (batch, channels, h_in, w_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
959 let grid_shape = self.grid.shape();
960 let h_out = grid_shape[1];
961 let w_out = grid_shape[2];
962
963 let go_data = grad_output.data_vec()?;
964 let in_data = self.input.data_vec()?;
965 let grid_data = self.grid.data_vec()?;
966
967 let one = T::from(1.0).unwrap();
968 let two = T::from(2.0).unwrap();
969 let zero = T::from(0.0).unwrap();
970
971 let grad_input_needed = self.input.requires_grad();
972 let grad_grid_needed = self.grid.requires_grad();
973
974 let mut grad_input = if grad_input_needed {
975 vec![zero; batch * channels * h_in * w_in]
976 } else {
977 vec![]
978 };
979 let mut grad_grid = if grad_grid_needed {
980 vec![zero; batch * h_out * w_out * 2]
981 } else {
982 vec![]
983 };
984
985 for b in 0..batch {
986 for oh in 0..h_out {
987 for ow in 0..w_out {
988 let grid_base = ((b * h_out + oh) * w_out + ow) * 2;
989 let gx = grid_data[grid_base];
990 let gy = grid_data[grid_base + 1];
991
992 let (src_x, src_y) = if self.align_corners {
993 let sx = (gx + one) * T::from(w_in - 1).unwrap() / two;
994 let sy = (gy + one) * T::from(h_in - 1).unwrap() / two;
995 (sx, sy)
996 } else {
997 let sx = ((gx + one) * T::from(w_in).unwrap() - one) / two;
998 let sy = ((gy + one) * T::from(h_in).unwrap() - one) / two;
999 (sx, sy)
1000 };
1001
1002 match self.mode {
1003 GridSampleMode::Bilinear => {
1004 let sx = src_x.to_f64().unwrap();
1005 let sy = src_y.to_f64().unwrap();
1006 let x0 = sx.floor() as isize;
1007 let y0 = sy.floor() as isize;
1008 let x1 = x0 + 1;
1009 let y1 = y0 + 1;
1010 let tx = T::from(sx - x0 as f64).unwrap();
1011 let ty = T::from(sy - y0 as f64).unwrap();
1012
1013 let get_clamped = |iy: isize, ix: isize| -> (isize, isize) {
1014 apply_padding_mode(ix, iy, w_in, h_in, self.padding_mode)
1015 };
1016
1017 for c in 0..channels {
1018 let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
1019 let g = go_data[out_idx];
1020 let in_base = (b * channels + c) * h_in;
1021
1022 if grad_input_needed {
1024 let coords = [
1025 (y0, x0, (one - ty) * (one - tx)),
1026 (y0, x1, (one - ty) * tx),
1027 (y1, x0, ty * (one - tx)),
1028 (y1, x1, ty * tx),
1029 ];
1030 for (iy, ix, w) in coords {
1031 let (ix, iy) = get_clamped(iy, ix);
1032 if ix >= 0
1033 && ix < w_in as isize
1034 && iy >= 0
1035 && iy < h_in as isize
1036 {
1037 grad_input
1038 [(in_base + iy as usize) * w_in + ix as usize] +=
1039 g * w;
1040 }
1041 }
1042 }
1043
1044 if grad_grid_needed {
1046 let get_pixel = |iy: isize, ix: isize| -> T {
1047 let (ix, iy) = get_clamped(iy, ix);
1048 if ix >= 0
1049 && ix < w_in as isize
1050 && iy >= 0
1051 && iy < h_in as isize
1052 {
1053 in_data[(in_base + iy as usize) * w_in + ix as usize]
1054 } else {
1055 zero
1056 }
1057 };
1058
1059 let v00 = get_pixel(y0, x0);
1060 let v01 = get_pixel(y0, x1);
1061 let v10 = get_pixel(y1, x0);
1062 let v11 = get_pixel(y1, x1);
1063
1064 let dout_dsx = (one - ty) * (v01 - v00) + ty * (v11 - v10);
1066 let dout_dsy = (one - tx) * (v10 - v00) + tx * (v11 - v01);
1068
1069 let dsx_dgx = if self.align_corners {
1071 T::from(w_in - 1).unwrap() / two
1072 } else {
1073 T::from(w_in).unwrap() / two
1074 };
1075 let dsy_dgy = if self.align_corners {
1076 T::from(h_in - 1).unwrap() / two
1077 } else {
1078 T::from(h_in).unwrap() / two
1079 };
1080
1081 grad_grid[grid_base] += g * dout_dsx * dsx_dgx;
1082 grad_grid[grid_base + 1] += g * dout_dsy * dsy_dgy;
1083 }
1084 }
1085 }
1086 GridSampleMode::Nearest => {
1087 if grad_input_needed {
1090 let ix = src_x.to_f64().unwrap().round() as isize;
1091 let iy = src_y.to_f64().unwrap().round() as isize;
1092 let (ix, iy) =
1093 apply_padding_mode(ix, iy, w_in, h_in, self.padding_mode);
1094
1095 if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
1096 for c in 0..channels {
1097 let out_idx =
1098 ((b * channels + c) * h_out + oh) * w_out + ow;
1099 let in_base = (b * channels + c) * h_in;
1100 grad_input[(in_base + iy as usize) * w_in + ix as usize] +=
1101 go_data[out_idx];
1102 }
1103 }
1104 }
1105 }
1106 }
1107 }
1108 }
1109 }
1110
1111 let gi = if grad_input_needed {
1112 Some(Tensor::from_storage(
1113 TensorStorage::cpu(grad_input),
1114 self.input.shape().to_vec(),
1115 false,
1116 )?)
1117 } else {
1118 None
1119 };
1120
1121 let gg = if grad_grid_needed {
1122 Some(Tensor::from_storage(
1123 TensorStorage::cpu(grad_grid),
1124 self.grid.shape().to_vec(),
1125 false,
1126 )?)
1127 } else {
1128 None
1129 };
1130
1131 Ok(vec![gi, gg])
1132 }
1133
1134 fn inputs(&self) -> Vec<&Tensor<T>> {
1135 vec![&self.input, &self.grid]
1136 }
1137
1138 fn name(&self) -> &'static str {
1139 "GridSampleBackward"
1140 }
1141}
1142
1143pub fn affine_grid<T: Float>(
1157 theta: &Tensor<T>,
1158 size: [usize; 4],
1159 align_corners: bool,
1160) -> FerrotorchResult<Tensor<T>> {
1161 let theta_shape = theta.shape();
1162 if theta_shape.len() != 3 || theta_shape[1] != 2 || theta_shape[2] != 3 {
1163 return Err(FerrotorchError::InvalidArgument {
1164 message: format!(
1165 "affine_grid: theta must be [B, 2, 3], got {:?}",
1166 theta_shape
1167 ),
1168 });
1169 }
1170 let batch = theta_shape[0];
1171 if size[0] != batch {
1172 return Err(FerrotorchError::ShapeMismatch {
1173 message: format!(
1174 "affine_grid: batch mismatch: theta batch {batch}, size batch {}",
1175 size[0]
1176 ),
1177 });
1178 }
1179
1180 let h = size[2];
1181 let w = size[3];
1182 let one = T::from(1.0).unwrap();
1183 let two = T::from(2.0).unwrap();
1184
1185 let theta_data = theta.data_vec()?;
1186 let theta_device = theta.device();
1187 let total = batch * h * w * 2;
1188 let mut grid = vec![T::from(0.0).unwrap(); total];
1189
1190 for b in 0..batch {
1191 let t_base = b * 6;
1192 let t00 = theta_data[t_base];
1193 let t01 = theta_data[t_base + 1];
1194 let t02 = theta_data[t_base + 2];
1195 let t10 = theta_data[t_base + 3];
1196 let t11 = theta_data[t_base + 4];
1197 let t12 = theta_data[t_base + 5];
1198
1199 for iy in 0..h {
1200 let y_norm = if align_corners {
1201 if h <= 1 {
1202 T::from(0.0).unwrap()
1203 } else {
1204 two * T::from(iy).unwrap() / T::from(h - 1).unwrap() - one
1205 }
1206 } else {
1207 (two * T::from(iy).unwrap() + one) / T::from(h).unwrap() - one
1208 };
1209
1210 for ix in 0..w {
1211 let x_norm = if align_corners {
1212 if w <= 1 {
1213 T::from(0.0).unwrap()
1214 } else {
1215 two * T::from(ix).unwrap() / T::from(w - 1).unwrap() - one
1216 }
1217 } else {
1218 (two * T::from(ix).unwrap() + one) / T::from(w).unwrap() - one
1219 };
1220
1221 let out_base = ((b * h + iy) * w + ix) * 2;
1222 grid[out_base] = t00 * x_norm + t01 * y_norm + t02;
1223 grid[out_base + 1] = t10 * x_norm + t11 * y_norm + t12;
1224 }
1225 }
1226 }
1227
1228 let out_shape = vec![batch, h, w, 2];
1229 let storage = TensorStorage::cpu(grid);
1230
1231 if is_grad_enabled() && theta.requires_grad() {
1232 Tensor::from_operation(
1233 storage,
1234 out_shape,
1235 Arc::new(AffineGridBackward {
1236 theta: theta.clone(),
1237 size,
1238 align_corners,
1239 }),
1240 )?
1241 .to(theta_device)
1242 } else {
1243 Tensor::from_storage(storage, out_shape, false)?.to(theta_device)
1244 }
1245}
1246
1247#[derive(Debug)]
1248struct AffineGridBackward<T: Float> {
1249 theta: Tensor<T>,
1250 size: [usize; 4],
1251 align_corners: bool,
1252}
1253
1254impl<T: Float> GradFn<T> for AffineGridBackward<T> {
1255 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1256 if !self.theta.requires_grad() {
1257 return Ok(vec![None]);
1258 }
1259
1260 let batch = self.size[0];
1261 let h = self.size[2];
1262 let w = self.size[3];
1263 let one = T::from(1.0).unwrap();
1264 let two = T::from(2.0).unwrap();
1265 let zero = T::from(0.0).unwrap();
1266
1267 let go_data = grad_output.data_vec()?;
1268 let mut grad_theta = vec![zero; batch * 6];
1269
1270 for b in 0..batch {
1271 for iy in 0..h {
1272 let y_norm = if self.align_corners {
1273 if h <= 1 {
1274 zero
1275 } else {
1276 two * T::from(iy).unwrap() / T::from(h - 1).unwrap() - one
1277 }
1278 } else {
1279 (two * T::from(iy).unwrap() + one) / T::from(h).unwrap() - one
1280 };
1281
1282 for ix in 0..w {
1283 let x_norm = if self.align_corners {
1284 if w <= 1 {
1285 zero
1286 } else {
1287 two * T::from(ix).unwrap() / T::from(w - 1).unwrap() - one
1288 }
1289 } else {
1290 (two * T::from(ix).unwrap() + one) / T::from(w).unwrap() - one
1291 };
1292
1293 let go_base = ((b * h + iy) * w + ix) * 2;
1294 let gx = go_data[go_base];
1295 let gy = go_data[go_base + 1];
1296
1297 let t_base = b * 6;
1298 grad_theta[t_base] += gx * x_norm;
1300 grad_theta[t_base + 1] += gx * y_norm;
1301 grad_theta[t_base + 2] += gx;
1302 grad_theta[t_base + 3] += gy * x_norm;
1304 grad_theta[t_base + 4] += gy * y_norm;
1305 grad_theta[t_base + 5] += gy;
1306 }
1307 }
1308 }
1309
1310 let grad_tensor = Tensor::from_storage(
1311 TensorStorage::cpu(grad_theta),
1312 self.theta.shape().to_vec(),
1313 false,
1314 )?;
1315 Ok(vec![Some(grad_tensor)])
1316 }
1317
1318 fn inputs(&self) -> Vec<&Tensor<T>> {
1319 vec![&self.theta]
1320 }
1321
1322 fn name(&self) -> &'static str {
1323 "AffineGridBackward"
1324 }
1325}
1326
/// Module wrapper around [`pixel_shuffle`].
#[derive(Clone, Copy, Debug)]
pub struct PixelShuffle {
    /// Factor `r` by which both spatial dimensions grow.
    pub upscale_factor: usize,
}

impl PixelShuffle {
    /// Creates the module. The factor is validated when `forward` runs.
    pub fn new(upscale_factor: usize) -> Self {
        PixelShuffle {
            upscale_factor: upscale_factor,
        }
    }
}
1348
1349impl<T: Float> Module<T> for PixelShuffle {
1350 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1351 pixel_shuffle(input, self.upscale_factor)
1352 }
1353
1354 fn parameters(&self) -> Vec<&Parameter<T>> {
1355 vec![]
1356 }
1357 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1358 vec![]
1359 }
1360 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1361 vec![]
1362 }
1363 fn train(&mut self) {}
1364 fn eval(&mut self) {}
1365 fn is_training(&self) -> bool {
1366 false
1367 }
1368}
1369
/// Module wrapper around [`pixel_unshuffle`].
#[derive(Clone, Copy, Debug)]
pub struct PixelUnshuffle {
    /// Factor `r` by which both spatial dimensions shrink.
    pub downscale_factor: usize,
}

impl PixelUnshuffle {
    /// Creates the module. The factor is validated when `forward` runs.
    pub fn new(downscale_factor: usize) -> Self {
        PixelUnshuffle {
            downscale_factor: downscale_factor,
        }
    }
}
1384
1385impl<T: Float> Module<T> for PixelUnshuffle {
1386 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1387 pixel_unshuffle(input, self.downscale_factor)
1388 }
1389
1390 fn parameters(&self) -> Vec<&Parameter<T>> {
1391 vec![]
1392 }
1393 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1394 vec![]
1395 }
1396 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1397 vec![]
1398 }
1399 fn train(&mut self) {}
1400 fn eval(&mut self) {}
1401 fn is_training(&self) -> bool {
1402 false
1403 }
1404}
1405
1406pub fn pixel_shuffle<T: Float>(
1410 input: &Tensor<T>,
1411 upscale_factor: usize,
1412) -> FerrotorchResult<Tensor<T>> {
1413 let (batch, channels_in, h, w) = validate_4d(input, "pixel_shuffle")?;
1414 let r = upscale_factor;
1415
1416 if r == 0 {
1417 return Err(FerrotorchError::InvalidArgument {
1418 message: "pixel_shuffle: upscale_factor must be > 0".into(),
1419 });
1420 }
1421 if channels_in % (r * r) != 0 {
1422 return Err(FerrotorchError::InvalidArgument {
1423 message: format!(
1424 "pixel_shuffle: channels ({channels_in}) must be divisible by r^2 ({})",
1425 r * r
1426 ),
1427 });
1428 }
1429
1430 let c_out = channels_in / (r * r);
1431 let h_out = h * r;
1432 let w_out = w * r;
1433
1434 let input_device = input.device();
1435 let data = input.data_vec()?;
1436
1437 let total = batch * c_out * h_out * w_out;
1438 let mut output = vec![T::from(0.0).unwrap(); total];
1439
1440 for b in 0..batch {
1442 for c in 0..c_out {
1443 for ih in 0..h {
1444 for iw in 0..w {
1445 for rh in 0..r {
1446 for rw in 0..r {
1447 let in_c = c * r * r + rh * r + rw;
1448 let in_idx = ((b * channels_in + in_c) * h + ih) * w + iw;
1449
1450 let oh = ih * r + rh;
1451 let ow_pos = iw * r + rw;
1452 let out_idx = ((b * c_out + c) * h_out + oh) * w_out + ow_pos;
1453
1454 output[out_idx] = data[in_idx];
1455 }
1456 }
1457 }
1458 }
1459 }
1460 }
1461
1462 let out_shape = vec![batch, c_out, h_out, w_out];
1463 let storage = TensorStorage::cpu(output);
1464
1465 if is_grad_enabled() && input.requires_grad() {
1466 Tensor::from_operation(
1467 storage,
1468 out_shape,
1469 Arc::new(PixelShuffleBackward {
1470 input: input.clone(),
1471 upscale_factor: r,
1472 }),
1473 )?
1474 .to(input_device)
1475 } else {
1476 Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1477 }
1478}
1479
1480pub fn pixel_unshuffle<T: Float>(
1484 input: &Tensor<T>,
1485 downscale_factor: usize,
1486) -> FerrotorchResult<Tensor<T>> {
1487 let (batch, channels, h_in, w_in) = validate_4d(input, "pixel_unshuffle")?;
1488 let r = downscale_factor;
1489
1490 if r == 0 {
1491 return Err(FerrotorchError::InvalidArgument {
1492 message: "pixel_unshuffle: downscale_factor must be > 0".into(),
1493 });
1494 }
1495 if h_in % r != 0 || w_in % r != 0 {
1496 return Err(FerrotorchError::InvalidArgument {
1497 message: format!(
1498 "pixel_unshuffle: spatial dims ({h_in}, {w_in}) must be divisible by r={r}"
1499 ),
1500 });
1501 }
1502
1503 let h_out = h_in / r;
1504 let w_out = w_in / r;
1505 let c_out = channels * r * r;
1506
1507 let input_device = input.device();
1508 let data = input.data_vec()?;
1509
1510 let total = batch * c_out * h_out * w_out;
1511 let mut output = vec![T::from(0.0).unwrap(); total];
1512
1513 for b in 0..batch {
1514 for c in 0..channels {
1515 for oh in 0..h_out {
1516 for ow in 0..w_out {
1517 for rh in 0..r {
1518 for rw in 0..r {
1519 let in_h = oh * r + rh;
1520 let in_w = ow * r + rw;
1521 let in_idx = ((b * channels + c) * h_in + in_h) * w_in + in_w;
1522
1523 let out_c = c * r * r + rh * r + rw;
1524 let out_idx = ((b * c_out + out_c) * h_out + oh) * w_out + ow;
1525
1526 output[out_idx] = data[in_idx];
1527 }
1528 }
1529 }
1530 }
1531 }
1532 }
1533
1534 let out_shape = vec![batch, c_out, h_out, w_out];
1535 let storage = TensorStorage::cpu(output);
1536
1537 if is_grad_enabled() && input.requires_grad() {
1538 Tensor::from_operation(
1539 storage,
1540 out_shape,
1541 Arc::new(PixelUnshuffleBackward {
1542 input: input.clone(),
1543 downscale_factor: r,
1544 }),
1545 )?
1546 .to(input_device)
1547 } else {
1548 Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1549 }
1550}
1551
1552#[derive(Debug)]
1553struct PixelShuffleBackward<T: Float> {
1554 input: Tensor<T>,
1555 upscale_factor: usize,
1556}
1557
1558impl<T: Float> GradFn<T> for PixelShuffleBackward<T> {
1559 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1560 if !self.input.requires_grad() {
1561 return Ok(vec![None]);
1562 }
1563 let grad_input = pixel_unshuffle(grad_output, self.upscale_factor)?;
1565 Ok(vec![Some(grad_input)])
1566 }
1567
1568 fn inputs(&self) -> Vec<&Tensor<T>> {
1569 vec![&self.input]
1570 }
1571
1572 fn name(&self) -> &'static str {
1573 "PixelShuffleBackward"
1574 }
1575}
1576
1577#[derive(Debug)]
1578struct PixelUnshuffleBackward<T: Float> {
1579 input: Tensor<T>,
1580 downscale_factor: usize,
1581}
1582
1583impl<T: Float> GradFn<T> for PixelUnshuffleBackward<T> {
1584 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1585 if !self.input.requires_grad() {
1586 return Ok(vec![None]);
1587 }
1588 let grad_input = pixel_shuffle(grad_output, self.downscale_factor)?;
1590 Ok(vec![Some(grad_input)])
1591 }
1592
1593 fn inputs(&self) -> Vec<&Tensor<T>> {
1594 vec![&self.input]
1595 }
1596
1597 fn name(&self) -> &'static str {
1598 "PixelUnshuffleBackward"
1599 }
1600}
1601
/// Module wrapper around [`unfold`].
///
/// Each field is a two-element array; [`unfold`] applies index 0 to the
/// height axis and index 1 to the width axis.
#[derive(Debug, Clone, Copy)]
pub struct Unfold {
    /// Sliding-window size.
    pub kernel_size: [usize; 2],
    /// Spacing between kernel elements.
    pub dilation: [usize; 2],
    /// Implicit zero padding on both sides of each spatial axis.
    pub padding: [usize; 2],
    /// Step between consecutive windows.
    pub stride: [usize; 2],
}

impl Unfold {
    /// Creates the layer from its window configuration.
    ///
    /// The values are stored as given; they are validated by [`unfold`] on
    /// the forward pass, not here.
    pub fn new(
        kernel_size: [usize; 2],
        dilation: [usize; 2],
        padding: [usize; 2],
        stride: [usize; 2],
    ) -> Unfold {
        Unfold {
            kernel_size,
            dilation,
            padding,
            stride,
        }
    }
}
1636
1637impl<T: Float> Module<T> for Unfold {
1638 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1639 unfold(
1640 input,
1641 self.kernel_size,
1642 self.dilation,
1643 self.padding,
1644 self.stride,
1645 )
1646 }
1647
1648 fn parameters(&self) -> Vec<&Parameter<T>> {
1649 vec![]
1650 }
1651 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1652 vec![]
1653 }
1654 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1655 vec![]
1656 }
1657 fn train(&mut self) {}
1658 fn eval(&mut self) {}
1659 fn is_training(&self) -> bool {
1660 false
1661 }
1662}
1663
/// Module wrapper around [`fold`].
///
/// Each field is a two-element array; [`fold`] applies index 0 to the height
/// axis and index 1 to the width axis.
#[derive(Debug, Clone, Copy)]
pub struct Fold {
    /// Spatial size `[H, W]` of the folded output.
    pub output_size: [usize; 2],
    /// Sliding-window size.
    pub kernel_size: [usize; 2],
    /// Spacing between kernel elements.
    pub dilation: [usize; 2],
    /// Implicit padding on both sides of each spatial axis.
    pub padding: [usize; 2],
    /// Step between consecutive windows.
    pub stride: [usize; 2],
}

impl Fold {
    /// Creates the layer from its output size and window configuration.
    ///
    /// The values are stored as given; they are validated by [`fold`] on the
    /// forward pass, not here.
    pub fn new(
        output_size: [usize; 2],
        kernel_size: [usize; 2],
        dilation: [usize; 2],
        padding: [usize; 2],
        stride: [usize; 2],
    ) -> Fold {
        Fold {
            output_size,
            kernel_size,
            dilation,
            padding,
            stride,
        }
    }
}
1696
1697impl<T: Float> Module<T> for Fold {
1698 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1699 fold(
1700 input,
1701 self.output_size,
1702 self.kernel_size,
1703 self.dilation,
1704 self.padding,
1705 self.stride,
1706 )
1707 }
1708
1709 fn parameters(&self) -> Vec<&Parameter<T>> {
1710 vec![]
1711 }
1712 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1713 vec![]
1714 }
1715 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1716 vec![]
1717 }
1718 fn train(&mut self) {}
1719 fn eval(&mut self) {}
1720 fn is_training(&self) -> bool {
1721 false
1722 }
1723}
1724
/// Number of sliding-window positions along one spatial axis.
///
/// Evaluates `(input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`
/// without ever underflowing `usize`.
///
/// Returns `0` when no window fits: the dilated kernel extent
/// `dilation * (kernel_size - 1) + 1` exceeds the padded input, or
/// `kernel_size` or `stride` is zero. The plain formula would panic in debug
/// builds in those cases; the subtraction would wrap to a huge count in
/// release builds.
#[inline]
fn unfold_output_size(
    input_size: usize,
    kernel_size: usize,
    dilation: usize,
    padding: usize,
    stride: usize,
) -> usize {
    if stride == 0 {
        return 0;
    }

    // Extent covered by one dilated window; `None` for a zero-sized kernel or
    // an extent too large to represent (which cannot fit in any input).
    let window = kernel_size
        .checked_sub(1)
        .and_then(|taps| taps.checked_mul(dilation))
        .and_then(|span| span.checked_add(1));

    let padded = padding.saturating_mul(2).saturating_add(input_size);

    match window {
        Some(window) if window <= padded => (padded - window) / stride + 1,
        _ => 0,
    }
}
1736
1737pub fn unfold<T: Float>(
1741 input: &Tensor<T>,
1742 kernel_size: [usize; 2],
1743 dilation: [usize; 2],
1744 padding: [usize; 2],
1745 stride: [usize; 2],
1746) -> FerrotorchResult<Tensor<T>> {
1747 let (batch, channels, h, w) = validate_4d(input, "unfold")?;
1748
1749 if kernel_size[0] == 0
1750 || kernel_size[1] == 0
1751 || stride[0] == 0
1752 || stride[1] == 0
1753 || dilation[0] == 0
1754 || dilation[1] == 0
1755 {
1756 return Err(FerrotorchError::InvalidArgument {
1757 message: "unfold: kernel_size, stride, dilation must all be > 0".into(),
1758 });
1759 }
1760
1761 let out_h = unfold_output_size(h, kernel_size[0], dilation[0], padding[0], stride[0]);
1762 let out_w = unfold_output_size(w, kernel_size[1], dilation[1], padding[1], stride[1]);
1763 let l = out_h * out_w;
1764 let k = channels * kernel_size[0] * kernel_size[1];
1765
1766 let input_device = input.device();
1767 let data = input.data_vec()?;
1768
1769 let total = batch * k * l;
1770 let mut output = vec![T::from(0.0).unwrap(); total];
1771
1772 for b in 0..batch {
1773 for c in 0..channels {
1774 for kh in 0..kernel_size[0] {
1775 for kw in 0..kernel_size[1] {
1776 let k_idx = (c * kernel_size[0] + kh) * kernel_size[1] + kw;
1777 for oh in 0..out_h {
1778 for ow in 0..out_w {
1779 let ih = oh * stride[0] + kh * dilation[0];
1780 let iw = ow * stride[1] + kw * dilation[1];
1781 let ih = ih as isize - padding[0] as isize;
1782 let iw = iw as isize - padding[1] as isize;
1783
1784 let l_idx = oh * out_w + ow;
1785 let out_idx = (b * k + k_idx) * l + l_idx;
1786
1787 if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1788 let in_idx =
1789 ((b * channels + c) * h + ih as usize) * w + iw as usize;
1790 output[out_idx] = data[in_idx];
1791 }
1792 }
1794 }
1795 }
1796 }
1797 }
1798 }
1799
1800 let out_shape = vec![batch, k, l];
1801 let storage = TensorStorage::cpu(output);
1802
1803 if is_grad_enabled() && input.requires_grad() {
1804 Tensor::from_operation(
1805 storage,
1806 out_shape,
1807 Arc::new(UnfoldBackward {
1808 input: input.clone(),
1809 kernel_size,
1810 dilation,
1811 padding,
1812 stride,
1813 }),
1814 )?
1815 .to(input_device)
1816 } else {
1817 Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1818 }
1819}
1820
1821pub fn fold<T: Float>(
1825 input: &Tensor<T>,
1826 output_size: [usize; 2],
1827 kernel_size: [usize; 2],
1828 dilation: [usize; 2],
1829 padding: [usize; 2],
1830 stride: [usize; 2],
1831) -> FerrotorchResult<Tensor<T>> {
1832 let shape = input.shape();
1833 if shape.len() != 3 {
1834 return Err(FerrotorchError::InvalidArgument {
1835 message: format!(
1836 "fold expects 3D input [B, C*kH*kW, L], got shape {:?}",
1837 shape
1838 ),
1839 });
1840 }
1841
1842 if kernel_size[0] == 0
1843 || kernel_size[1] == 0
1844 || stride[0] == 0
1845 || stride[1] == 0
1846 || dilation[0] == 0
1847 || dilation[1] == 0
1848 {
1849 return Err(FerrotorchError::InvalidArgument {
1850 message: "fold: kernel_size, stride, dilation must all be > 0".into(),
1851 });
1852 }
1853
1854 let batch = shape[0];
1855 let k = shape[1]; let l = shape[2]; let [h_out, w_out] = output_size;
1859 let k_area = kernel_size[0] * kernel_size[1];
1860
1861 if k % k_area != 0 {
1862 return Err(FerrotorchError::InvalidArgument {
1863 message: format!("fold: dim 1 ({k}) must be divisible by kH*kW ({})", k_area),
1864 });
1865 }
1866 let channels = k / k_area;
1867
1868 let expected_out_h =
1869 unfold_output_size(h_out, kernel_size[0], dilation[0], padding[0], stride[0]);
1870 let expected_out_w =
1871 unfold_output_size(w_out, kernel_size[1], dilation[1], padding[1], stride[1]);
1872 let expected_l = expected_out_h * expected_out_w;
1873
1874 if l != expected_l {
1875 return Err(FerrotorchError::InvalidArgument {
1876 message: format!(
1877 "fold: L={l} does not match expected {expected_l} for output_size ({h_out}, {w_out})"
1878 ),
1879 });
1880 }
1881
1882 let input_device = input.device();
1883 let data = input.data_vec()?;
1884
1885 let total = batch * channels * h_out * w_out;
1886 let mut output = vec![T::from(0.0).unwrap(); total];
1887
1888 for b in 0..batch {
1889 for c in 0..channels {
1890 for kh in 0..kernel_size[0] {
1891 for kw in 0..kernel_size[1] {
1892 let k_idx = (c * kernel_size[0] + kh) * kernel_size[1] + kw;
1893 for oh in 0..expected_out_h {
1894 for ow in 0..expected_out_w {
1895 let ih = oh * stride[0] + kh * dilation[0];
1896 let iw = ow * stride[1] + kw * dilation[1];
1897 let ih = ih as isize - padding[0] as isize;
1898 let iw = iw as isize - padding[1] as isize;
1899
1900 if ih >= 0 && ih < h_out as isize && iw >= 0 && iw < w_out as isize {
1901 let l_idx = oh * expected_out_w + ow;
1902 let in_idx = (b * k + k_idx) * l + l_idx;
1903 let out_idx = ((b * channels + c) * h_out + ih as usize) * w_out
1904 + iw as usize;
1905 output[out_idx] += data[in_idx];
1906 }
1907 }
1908 }
1909 }
1910 }
1911 }
1912 }
1913
1914 let out_shape = vec![batch, channels, h_out, w_out];
1915 let storage = TensorStorage::cpu(output);
1916
1917 if is_grad_enabled() && input.requires_grad() {
1918 Tensor::from_operation(
1919 storage,
1920 out_shape,
1921 Arc::new(FoldBackward {
1922 input: input.clone(),
1923 kernel_size,
1924 dilation,
1925 padding,
1926 stride,
1927 }),
1928 )?
1929 .to(input_device)
1930 } else {
1931 Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1932 }
1933}
1934
1935#[derive(Debug)]
1936struct UnfoldBackward<T: Float> {
1937 input: Tensor<T>,
1938 kernel_size: [usize; 2],
1939 dilation: [usize; 2],
1940 padding: [usize; 2],
1941 stride: [usize; 2],
1942}
1943
1944impl<T: Float> GradFn<T> for UnfoldBackward<T> {
1945 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1946 if !self.input.requires_grad() {
1947 return Ok(vec![None]);
1948 }
1949 let in_shape = self.input.shape();
1951 let h = in_shape[2];
1952 let w = in_shape[3];
1953 let grad_input = fold(
1954 grad_output,
1955 [h, w],
1956 self.kernel_size,
1957 self.dilation,
1958 self.padding,
1959 self.stride,
1960 )?;
1961 Ok(vec![Some(grad_input)])
1962 }
1963
1964 fn inputs(&self) -> Vec<&Tensor<T>> {
1965 vec![&self.input]
1966 }
1967
1968 fn name(&self) -> &'static str {
1969 "UnfoldBackward"
1970 }
1971}
1972
1973#[derive(Debug)]
1974struct FoldBackward<T: Float> {
1975 input: Tensor<T>,
1976 kernel_size: [usize; 2],
1977 dilation: [usize; 2],
1978 padding: [usize; 2],
1979 stride: [usize; 2],
1980}
1981
1982impl<T: Float> GradFn<T> for FoldBackward<T> {
1983 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1984 if !self.input.requires_grad() {
1985 return Ok(vec![None]);
1986 }
1987 let grad_input = unfold(
1989 grad_output,
1990 self.kernel_size,
1991 self.dilation,
1992 self.padding,
1993 self.stride,
1994 )?;
1995 Ok(vec![Some(grad_input)])
1996 }
1997
1998 fn inputs(&self) -> Vec<&Tensor<T>> {
1999 vec![&self.input]
2000 }
2001
2002 fn name(&self) -> &'static str {
2003 "FoldBackward"
2004 }
2005}
2006
/// Unit tests for the interpolation, pixel (un)shuffle, unfold/fold,
/// grid-sampling and affine-grid operations and their module wrappers.
#[cfg(test)]
mod tests {
    use super::*;

    // ----------------------------------------------------------------------
    // Helpers
    // ----------------------------------------------------------------------

    /// Builds a CPU leaf tensor with an arbitrary shape.
    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    /// Builds a 4-D CPU leaf tensor; thin wrapper over [`leaf`].
    fn leaf_4d(data: &[f32], shape: [usize; 4], requires_grad: bool) -> Tensor<f32> {
        leaf(data, &shape, requires_grad)
    }

    /// Asserts that two slices have equal length and agree element-wise
    /// within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "index {i}: actual={a} expected={e} diff={}",
                (a - e).abs(),
            );
        }
    }

    /// Gradient node for a "sum everything" loss: its backward pass hands a
    /// tensor of ones, shaped like `input`, to the node below it.
    #[derive(Debug)]
    struct TestSumBackward {
        input: Tensor<f32>,
    }

    impl GradFn<f32> for TestSumBackward {
        fn backward(
            &self,
            _grad_output: &Tensor<f32>,
        ) -> FerrotorchResult<Vec<Option<Tensor<f32>>>> {
            let ones_data = vec![1.0f32; self.input.numel()];
            let ones = Tensor::from_storage(
                TensorStorage::cpu(ones_data),
                self.input.shape().to_vec(),
                false,
            )?;
            Ok(vec![Some(ones)])
        }

        fn inputs(&self) -> Vec<&Tensor<f32>> {
            vec![&self.input]
        }

        fn name(&self) -> &'static str {
            "TestSumBackward"
        }
    }

    /// Wraps `out` in a scalar loss equal to the sum of its elements and runs
    /// the backward pass from that loss.
    fn backward_through_sum(out: Tensor<f32>) {
        let total: f32 = out.data().unwrap().iter().sum();
        let loss = Tensor::from_operation(
            TensorStorage::cpu(vec![total]),
            vec![],
            Arc::new(TestSumBackward { input: out }),
        )
        .unwrap();
        loss.backward().unwrap();
    }

    /// Asserts that every element of `grad` is within `1e-5` of `expected`.
    fn assert_all_close_to(grad: &[f32], expected: f32) {
        for (i, &val) in grad.iter().enumerate() {
            assert!(
                (val - expected).abs() < 1e-5,
                "grad[{i}]: expected {expected}, got {val}"
            );
        }
    }

    // ----------------------------------------------------------------------
    // interpolate: nearest
    // ----------------------------------------------------------------------

    #[test]
    fn test_interpolate_nearest_upsample_2x() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);
        let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);

        let d = out.data().unwrap();
        #[rustfmt::skip]
        let expected: Vec<f32> = vec![
            1.0, 1.0, 2.0, 2.0,
            1.0, 1.0, 2.0, 2.0,
            3.0, 3.0, 4.0, 4.0,
            3.0, 3.0, 4.0, 4.0,
        ];
        assert_close(d, &expected, 1e-6);
    }

    #[test]
    fn test_interpolate_nearest_downsample() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            1.0,  2.0,  3.0,  4.0,
            5.0,  6.0,  7.0,  8.0,
            9.0,  10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = leaf_4d(&data, [1, 1, 4, 4], false);
        let out = interpolate(&input, Some([2, 2]), None, InterpolateMode::Nearest, false).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        let d = out.data().unwrap();
        assert_close(d, &[1.0, 3.0, 9.0, 11.0], 1e-6);
    }

    #[test]
    fn test_interpolate_nearest_scale_factor() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);
        let out = interpolate(
            &input,
            None,
            Some([2.0, 2.0]),
            InterpolateMode::Nearest,
            false,
        )
        .unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
    }

    // ----------------------------------------------------------------------
    // interpolate: bilinear
    // ----------------------------------------------------------------------

    #[test]
    fn test_interpolate_bilinear_upsample() {
        let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);
        let out = interpolate(&input, Some([3, 3]), None, InterpolateMode::Bilinear, true).unwrap();
        assert_eq!(out.shape(), &[1, 1, 3, 3]);

        let d = out.data().unwrap();
        // The four corners of the 3x3 output must equal the input corners.
        assert!((d[0] - 0.0).abs() < 1e-5);
        assert!((d[2] - 1.0).abs() < 1e-5);
        assert!((d[6] - 2.0).abs() < 1e-5);
        assert!((d[8] - 3.0).abs() < 1e-5);
        // The centre is the mean of the four input values.
        assert!((d[4] - 1.5).abs() < 1e-5);
    }

    #[test]
    fn test_interpolate_bilinear_identity() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let input = leaf_4d(&data, [1, 1, 3, 3], false);
        let out = interpolate(&input, Some([3, 3]), None, InterpolateMode::Bilinear, true).unwrap();
        assert_eq!(out.shape(), &[1, 1, 3, 3]);
        assert_close(out.data().unwrap(), &data, 1e-5);
    }

    // ----------------------------------------------------------------------
    // interpolate: bicubic
    // ----------------------------------------------------------------------

    #[test]
    fn test_interpolate_bicubic_output_shape() {
        let data: Vec<f32> = vec![0.0; 64];
        let input = leaf_4d(&data, [1, 1, 8, 8], false);
        let out = interpolate(
            &input,
            Some([16, 16]),
            None,
            InterpolateMode::Bicubic,
            false,
        )
        .unwrap();
        assert_eq!(out.shape(), &[1, 1, 16, 16]);
    }

    #[test]
    fn test_interpolate_bicubic_corners_align() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);
        let out = interpolate(&input, Some([5, 5]), None, InterpolateMode::Bicubic, true).unwrap();
        assert_eq!(out.shape(), &[1, 1, 5, 5]);
        let d = out.data().unwrap();
        // The four corners of the 5x5 output must equal the input corners.
        assert!((d[0] - 1.0).abs() < 1e-4);
        assert!((d[4] - 2.0).abs() < 1e-4);
        assert!((d[20] - 3.0).abs() < 1e-4);
        assert!((d[24] - 4.0).abs() < 1e-4);
    }

    // ----------------------------------------------------------------------
    // Upsample module
    // ----------------------------------------------------------------------

    #[test]
    fn test_upsample_module_nearest() {
        let up = Upsample::new([6, 6], InterpolateMode::Nearest);
        let input = leaf_4d(&[0.0; 9], [1, 1, 3, 3], false);
        let out: Tensor<f32> = Module::<f32>::forward(&up, &input).unwrap();
        assert_eq!(out.shape(), &[1, 1, 6, 6]);
    }

    #[test]
    fn test_upsample_module_bilinear_scale() {
        let up = Upsample::with_scale_factor([2.0, 2.0], InterpolateMode::Bilinear);
        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
        let out: Tensor<f32> = Module::<f32>::forward(&up, &input).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
    }

    #[test]
    fn test_upsample_no_parameters() {
        let up = Upsample::new([4, 4], InterpolateMode::Nearest);
        assert!(Module::<f32>::parameters(&up).is_empty());
    }

    // ----------------------------------------------------------------------
    // interpolate: argument validation
    // ----------------------------------------------------------------------

    #[test]
    fn test_interpolate_no_size_no_scale() {
        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
        assert!(interpolate(&input, None, None, InterpolateMode::Nearest, false).is_err());
    }

    #[test]
    fn test_interpolate_both_size_and_scale() {
        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
        assert!(
            interpolate(
                &input,
                Some([4, 4]),
                Some([2.0, 2.0]),
                InterpolateMode::Nearest,
                false
            )
            .is_err()
        );
    }

    #[test]
    fn test_interpolate_nearest_align_corners_rejected() {
        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
        assert!(interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, true).is_err());
    }

    #[test]
    fn test_interpolate_3d_rejected() {
        let input = leaf(&[0.0; 6], &[2, 3], false);
        assert!(interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).is_err());
    }

    // ----------------------------------------------------------------------
    // interpolate: backward
    // ----------------------------------------------------------------------

    #[test]
    fn test_interpolate_nearest_backward() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], true);
        let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        // 2x nearest upsampling copies every input pixel to 4 output pixels.
        assert_all_close_to(grad.data().unwrap(), 4.0);
    }

    #[test]
    fn test_interpolate_bilinear_backward() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], true);
        let out =
            interpolate(&input, Some([4, 4]), None, InterpolateMode::Bilinear, false).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        let g = grad.data().unwrap();
        let grad_sum: f32 = g.iter().sum();
        // The summed input gradient must equal the number of output pixels.
        assert!(
            (grad_sum - 16.0).abs() < 1e-3,
            "grad sum = {grad_sum}, expected 16.0"
        );
    }

    // ----------------------------------------------------------------------
    // pixel_shuffle / pixel_unshuffle
    // ----------------------------------------------------------------------

    #[test]
    fn test_pixel_shuffle_shape() {
        let data = vec![0.0f32; 16];
        let input = leaf_4d(&data, [1, 4, 2, 2], false);
        let out = pixel_shuffle(&input, 2).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
    }

    #[test]
    fn test_pixel_shuffle_values() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 4, 1, 1], false);
        let out = pixel_shuffle(&input, 2).unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        assert_close(out.data().unwrap(), &[1.0, 2.0, 3.0, 4.0], 1e-6);
    }

    #[test]
    fn test_pixel_shuffle_not_divisible() {
        let input = leaf_4d(&[0.0; 12], [1, 3, 2, 2], false);
        assert!(pixel_shuffle(&input, 2).is_err());
    }

    #[test]
    fn test_pixel_shuffle_zero_factor_rejected() {
        let input = leaf_4d(&[0.0; 16], [1, 4, 2, 2], false);
        assert!(pixel_shuffle(&input, 0).is_err());
    }

    #[test]
    fn test_pixel_unshuffle_shape() {
        let data = vec![0.0f32; 16];
        let input = leaf_4d(&data, [1, 1, 4, 4], false);
        let out = pixel_unshuffle(&input, 2).unwrap();
        assert_eq!(out.shape(), &[1, 4, 2, 2]);
    }

    #[test]
    fn test_pixel_shuffle_unshuffle_roundtrip() {
        let data: Vec<f32> = (0..36).map(|i| i as f32).collect();
        let input = leaf_4d(&data, [1, 4, 3, 3], false);
        let shuffled = pixel_shuffle(&input, 2).unwrap();
        assert_eq!(shuffled.shape(), &[1, 1, 6, 6]);
        let roundtrip = pixel_unshuffle(&shuffled, 2).unwrap();
        assert_eq!(roundtrip.shape(), &[1, 4, 3, 3]);
        assert_close(roundtrip.data().unwrap(), &data, 1e-6);
    }

    #[test]
    fn test_pixel_unshuffle_spatial_not_divisible() {
        let input = leaf_4d(&[0.0; 9], [1, 1, 3, 3], false);
        assert!(pixel_unshuffle(&input, 2).is_err());
    }

    #[test]
    fn test_pixel_unshuffle_zero_factor_rejected() {
        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
        assert!(pixel_unshuffle(&input, 0).is_err());
    }

    #[test]
    fn test_pixel_shuffle_backward() {
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let input = leaf_4d(&data, [1, 4, 2, 2], true);
        let out = pixel_shuffle(&input, 2).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        // A pure permutation: every input element receives gradient 1.
        assert_all_close_to(grad.data().unwrap(), 1.0);
    }

    #[test]
    fn test_pixel_unshuffle_backward() {
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let input = leaf_4d(&data, [1, 1, 4, 4], true);
        let out = pixel_unshuffle(&input, 2).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
        // A pure permutation: every input element receives gradient 1.
        assert_all_close_to(grad.data().unwrap(), 1.0);
    }

    // ----------------------------------------------------------------------
    // unfold
    // ----------------------------------------------------------------------

    #[test]
    fn test_unfold_shape() {
        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
        assert_eq!(out.shape(), &[1, 4, 9]);
    }

    #[test]
    fn test_unfold_values() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
        ];
        let input = leaf_4d(&data, [1, 1, 3, 3], false);
        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
        assert_eq!(out.shape(), &[1, 4, 4]);

        // Each row holds one kernel offset across the four window positions.
        let d = out.data().unwrap();
        assert_close(&d[0..4], &[1.0, 2.0, 4.0, 5.0], 1e-6);
        assert_close(&d[4..8], &[2.0, 3.0, 5.0, 6.0], 1e-6);
        assert_close(&d[8..12], &[4.0, 5.0, 7.0, 8.0], 1e-6);
        assert_close(&d[12..16], &[5.0, 6.0, 8.0, 9.0], 1e-6);
    }

    #[test]
    fn test_unfold_with_padding() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);
        let out = unfold(&input, [2, 2], [1, 1], [1, 1], [1, 1]).unwrap();
        assert_eq!(out.shape(), &[1, 4, 9]);
    }

    #[test]
    fn test_unfold_with_stride() {
        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
        assert_eq!(out.shape(), &[1, 4, 4]);
    }

    #[test]
    fn test_unfold_zero_kernel_rejected() {
        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
        assert!(unfold(&input, [0, 2], [1, 1], [0, 0], [1, 1]).is_err());
    }

    // ----------------------------------------------------------------------
    // fold
    // ----------------------------------------------------------------------

    #[test]
    fn test_fold_shape() {
        let data = vec![0.0f32; 36];
        let input = leaf(&data, &[1, 4, 9], false);
        let out = fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
    }

    #[test]
    fn test_unfold_fold_roundtrip() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            1.0,  2.0,  3.0,  4.0,
            5.0,  6.0,  7.0,  8.0,
            9.0,  10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        let input = leaf_4d(&data, [1, 1, 4, 4], false);
        // Stride equals the kernel size, so windows do not overlap and fold
        // reproduces the input exactly.
        let unfolded = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
        let refolded = fold(&unfolded, [4, 4], [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
        assert_eq!(refolded.shape(), &[1, 1, 4, 4]);
        assert_close(refolded.data().unwrap(), &data, 1e-6);
    }

    #[test]
    fn test_fold_l_mismatch() {
        let data = vec![0.0f32; 20];
        let input = leaf(&data, &[1, 4, 5], false);
        assert!(fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [1, 1]).is_err());
    }

    #[test]
    fn test_fold_backward() {
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let input = leaf(&data, &[1, 4, 4], true);
        let out = fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        assert_eq!(grad.shape(), &[1, 4, 4]);
        // Without padding every block element lands on exactly one output
        // pixel, so each receives gradient 1.
        assert_all_close_to(grad.data().unwrap(), 1.0);
    }

    // ----------------------------------------------------------------------
    // grid_sample
    // ----------------------------------------------------------------------

    #[test]
    fn test_grid_sample_identity() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);

        // One (x, y) pair per output pixel, placed on the four input corners.
        let grid_data: Vec<f32> = vec![-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
        let grid = leaf(&grid_data, &[1, 2, 2, 2], false);

        let out = grid_sample(
            &input,
            &grid,
            GridSampleMode::Bilinear,
            GridSamplePaddingMode::Zeros,
            true,
        )
        .unwrap();
        assert_eq!(out.shape(), &[1, 1, 2, 2]);
        assert_close(out.data().unwrap(), &data, 1e-5);
    }

    #[test]
    fn test_grid_sample_nearest() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let input = leaf_4d(&data, [1, 1, 2, 2], false);

        let grid_data: Vec<f32> = vec![-1.0, -1.0];
        let grid = leaf(&grid_data, &[1, 1, 1, 2], false);

        let out = grid_sample(
            &input,
            &grid,
            GridSampleMode::Nearest,
            GridSamplePaddingMode::Zeros,
            true,
        )
        .unwrap();
        assert_eq!(out.shape(), &[1, 1, 1, 1]);
        assert!((out.data().unwrap()[0] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_grid_sample_batch_mismatch() {
        let input = leaf_4d(&[0.0; 8], [2, 1, 2, 2], false);
        let grid = leaf(&[0.0; 8], &[1, 2, 2, 2], false);
        assert!(
            grid_sample(
                &input,
                &grid,
                GridSampleMode::Bilinear,
                GridSamplePaddingMode::Zeros,
                true
            )
            .is_err()
        );
    }

    #[test]
    fn test_grid_sample_wrong_grid_shape() {
        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
        let grid = leaf(&[0.0; 8], &[1, 2, 4], false);
        assert!(
            grid_sample(
                &input,
                &grid,
                GridSampleMode::Bilinear,
                GridSamplePaddingMode::Zeros,
                true
            )
            .is_err()
        );
    }

    // ----------------------------------------------------------------------
    // affine_grid
    // ----------------------------------------------------------------------

    #[test]
    fn test_affine_grid_identity() {
        let theta_data: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let theta = leaf(&theta_data, &[1, 2, 3], false);
        let grid = affine_grid(&theta, [1, 1, 3, 3], true).unwrap();
        assert_eq!(grid.shape(), &[1, 3, 3, 2]);

        let d = grid.data().unwrap();
        // First grid point of the top row: (-1, -1).
        assert!((d[0] - (-1.0)).abs() < 1e-5);
        assert!((d[1] - (-1.0)).abs() < 1e-5);
        // Third grid point of the top row: (1, -1).
        assert!((d[4] - 1.0).abs() < 1e-5);
        assert!((d[5] - (-1.0)).abs() < 1e-5);
    }

    #[test]
    fn test_affine_grid_theta_shape_error() {
        let theta = leaf(&[0.0; 12], &[2, 3, 2], false);
        assert!(affine_grid(&theta, [2, 1, 3, 3], true).is_err());
    }

    #[test]
    fn test_affine_grid_batch_mismatch() {
        let theta = leaf(&[0.0; 6], &[1, 2, 3], false);
        assert!(affine_grid(&theta, [2, 1, 3, 3], true).is_err());
    }

    // ----------------------------------------------------------------------
    // Module wrappers
    // ----------------------------------------------------------------------

    #[test]
    fn test_pixel_shuffle_module() {
        let ps = PixelShuffle::new(2);
        let input = leaf_4d(&[0.0; 16], [1, 4, 2, 2], false);
        let out: Tensor<f32> = Module::<f32>::forward(&ps, &input).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
    }

    #[test]
    fn test_pixel_unshuffle_module() {
        let pus = PixelUnshuffle::new(2);
        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
        let out: Tensor<f32> = Module::<f32>::forward(&pus, &input).unwrap();
        assert_eq!(out.shape(), &[1, 4, 2, 2]);
    }

    #[test]
    fn test_unfold_module() {
        let uf = Unfold::new([2, 2], [1, 1], [0, 0], [1, 1]);
        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
        let out: Tensor<f32> = Module::<f32>::forward(&uf, &input).unwrap();
        assert_eq!(out.shape(), &[1, 4, 9]);
    }

    #[test]
    fn test_fold_module() {
        let f = Fold::new([4, 4], [2, 2], [1, 1], [0, 0], [1, 1]);
        let data = vec![0.0f32; 36];
        let input = leaf(&data, &[1, 4, 9], false);
        let out: Tensor<f32> = Module::<f32>::forward(&f, &input).unwrap();
        assert_eq!(out.shape(), &[1, 1, 4, 4]);
    }

    // ----------------------------------------------------------------------
    // unfold: backward
    // ----------------------------------------------------------------------

    #[test]
    fn test_unfold_backward() {
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let input = leaf_4d(&data, [1, 1, 4, 4], true);
        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        // Non-overlapping windows: every pixel appears in exactly one block.
        assert_all_close_to(grad.data().unwrap(), 1.0);
    }

    #[test]
    fn test_unfold_backward_overlapping() {
        let data: Vec<f32> = (0..9).map(|i| i as f32).collect();
        let input = leaf_4d(&data, [1, 1, 3, 3], true);
        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();

        backward_through_sum(out);

        let grad = input.grad().unwrap().unwrap();
        let g = grad.data().unwrap();
        // Each entry counts how many 2x2 windows cover that pixel.
        #[rustfmt::skip]
        let expected: Vec<f32> = vec![
            1.0, 2.0, 1.0,
            2.0, 4.0, 2.0,
            1.0, 2.0, 1.0,
        ];
        assert_close(g, &expected, 1e-5);
    }

    // ----------------------------------------------------------------------
    // Multi-channel / batched input
    // ----------------------------------------------------------------------

    #[test]
    fn test_interpolate_multichannel_batch() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let input = leaf_4d(&data, [2, 3, 2, 2], false);
        let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
        assert_eq!(out.shape(), &[2, 3, 4, 4]);
    }
}