1use ferrotorch_core::grad_fns::shape::reshape;
28use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
29
30use crate::module::Module;
31use crate::parameter::Parameter;
32
33#[derive(Debug, Clone, Copy)]
51pub struct Identity {
52 training: bool,
53}
54
55impl Identity {
56 pub fn new() -> Self {
58 Self { training: true }
59 }
60}
61
62impl Default for Identity {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl<T: Float> Module<T> for Identity {
69 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
70 Ok(input.clone())
71 }
72
73 fn parameters(&self) -> Vec<&Parameter<T>> {
74 vec![]
75 }
76
77 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
78 vec![]
79 }
80
81 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
82 vec![]
83 }
84
85 fn train(&mut self) {
86 self.training = true;
87 }
88
89 fn eval(&mut self) {
90 self.training = false;
91 }
92
93 fn is_training(&self) -> bool {
94 self.training
95 }
96}
97
98#[derive(Debug, Clone, Copy)]
124pub struct Flatten {
125 pub start_dim: usize,
127 pub end_dim: isize,
129 training: bool,
130}
131
132impl Flatten {
133 pub fn new(start_dim: usize, end_dim: isize) -> Self {
141 Self {
142 start_dim,
143 end_dim,
144 training: true,
145 }
146 }
147
148 fn resolve_end_dim(&self, ndim: usize) -> FerrotorchResult<usize> {
150 let resolved = if self.end_dim < 0 {
151 let d = ndim as isize + self.end_dim;
152 if d < 0 {
153 return Err(FerrotorchError::InvalidArgument {
154 message: format!(
155 "Flatten: end_dim {} is out of range for input with {} dims",
156 self.end_dim, ndim
157 ),
158 });
159 }
160 d as usize
161 } else {
162 self.end_dim as usize
163 };
164
165 if resolved >= ndim {
166 return Err(FerrotorchError::InvalidArgument {
167 message: format!(
168 "Flatten: resolved end_dim {} is out of range for input with {} dims",
169 resolved, ndim
170 ),
171 });
172 }
173
174 Ok(resolved)
175 }
176}
177
178impl Default for Flatten {
179 fn default() -> Self {
180 Self::new(1, -1)
181 }
182}
183
184impl<T: Float> Module<T> for Flatten {
185 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
186 let shape = input.shape();
187 let ndim = shape.len();
188
189 if ndim == 0 {
191 return Err(FerrotorchError::InvalidArgument {
192 message: "Flatten: cannot flatten a 0-D (scalar) tensor".into(),
193 });
194 }
195
196 if ndim == 1 {
198 return Ok(input.clone());
199 }
200
201 if self.start_dim >= ndim {
202 return Err(FerrotorchError::InvalidArgument {
203 message: format!(
204 "Flatten: start_dim {} is out of range for input with {} dims",
205 self.start_dim, ndim
206 ),
207 });
208 }
209
210 let end_dim = self.resolve_end_dim(ndim)?;
211
212 if self.start_dim > end_dim {
213 return Err(FerrotorchError::InvalidArgument {
214 message: format!(
215 "Flatten: start_dim ({}) must be <= end_dim ({})",
216 self.start_dim, end_dim
217 ),
218 });
219 }
220
221 if self.start_dim == end_dim {
223 return Ok(input.clone());
224 }
225
226 let mut new_shape: Vec<isize> = Vec::with_capacity(ndim - (end_dim - self.start_dim));
228
229 for &d in &shape[..self.start_dim] {
230 new_shape.push(d as isize);
231 }
232
233 let flattened: usize = shape[self.start_dim..=end_dim].iter().product();
235 new_shape.push(flattened as isize);
236
237 for &d in &shape[end_dim + 1..] {
238 new_shape.push(d as isize);
239 }
240
241 reshape(input, &new_shape)
242 }
243
244 fn parameters(&self) -> Vec<&Parameter<T>> {
245 vec![]
246 }
247
248 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
249 vec![]
250 }
251
252 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
253 vec![]
254 }
255
256 fn train(&mut self) {
257 self.training = true;
258 }
259
260 fn eval(&mut self) {
261 self.training = false;
262 }
263
264 fn is_training(&self) -> bool {
265 self.training
266 }
267}
268
269#[derive(Debug, Clone)]
281pub struct Unflatten {
282 pub dim: usize,
284 pub unflattened_size: Vec<usize>,
286 training: bool,
287}
288
289impl Unflatten {
290 pub fn new(dim: usize, unflattened_size: Vec<usize>) -> Self {
291 Self {
292 dim,
293 unflattened_size,
294 training: true,
295 }
296 }
297
298 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
299 let shape = input.shape();
300 if self.dim >= shape.len() {
301 return Err(FerrotorchError::InvalidArgument {
302 message: format!(
303 "Unflatten: dim {} out of range for input with {} dims",
304 self.dim,
305 shape.len()
306 ),
307 });
308 }
309
310 let expected_size: usize = self.unflattened_size.iter().product();
311 if expected_size != shape[self.dim] {
312 return Err(FerrotorchError::InvalidArgument {
313 message: format!(
314 "Unflatten: unflattened_size {:?} (product={}) doesn't match dim {} size {}",
315 self.unflattened_size, expected_size, self.dim, shape[self.dim]
316 ),
317 });
318 }
319
320 let mut new_shape = Vec::with_capacity(shape.len() - 1 + self.unflattened_size.len());
321 new_shape.extend_from_slice(&shape[..self.dim]);
322 new_shape.extend_from_slice(&self.unflattened_size);
323 new_shape.extend_from_slice(&shape[self.dim + 1..]);
324
325 input.view_reshape(new_shape)
326 }
327}
328
329impl<T: Float> Module<T> for Unflatten {
330 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
331 Unflatten::forward(self, input)
332 }
333
334 fn parameters(&self) -> Vec<&Parameter<T>> {
335 vec![]
336 }
337 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
338 vec![]
339 }
340 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
341 vec![]
342 }
343 fn train(&mut self) {
344 self.training = true;
345 }
346 fn eval(&mut self) {
347 self.training = false;
348 }
349 fn is_training(&self) -> bool {
350 self.training
351 }
352}
353
354#[derive(Debug, Clone)]
367pub struct ChannelShuffle {
368 pub groups: usize,
369 training: bool,
370}
371
372impl ChannelShuffle {
373 pub fn new(groups: usize) -> Self {
374 Self {
375 groups,
376 training: true,
377 }
378 }
379
380 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
381 if input.ndim() < 2 {
382 return Err(FerrotorchError::InvalidArgument {
383 message: format!(
384 "ChannelShuffle: input must have at least 2 dims, got {:?}",
385 input.shape()
386 ),
387 });
388 }
389 if input.is_cuda() {
390 return Err(FerrotorchError::NotImplementedOnCuda {
391 op: "ChannelShuffle",
392 });
393 }
394
395 let shape = input.shape();
396 let channels = shape[1];
397 if channels % self.groups != 0 {
398 return Err(FerrotorchError::InvalidArgument {
399 message: format!(
400 "ChannelShuffle: channels ({}) must be divisible by groups ({})",
401 channels, self.groups
402 ),
403 });
404 }
405
406 let g = self.groups;
407 let cpg = channels / g; let batch = shape[0];
409 let spatial: usize = shape[2..].iter().product();
410 let data = input.data()?;
411
412 let mut out = vec![<T as num_traits::Zero>::zero(); data.len()];
414 for n in 0..batch {
415 for c_out in 0..channels {
416 let c_in = (c_out % g) * cpg + (c_out / g);
418 for s in 0..spatial {
419 out[n * channels * spatial + c_out * spatial + s] =
420 data[n * channels * spatial + c_in * spatial + s];
421 }
422 }
423 }
424
425 Tensor::from_storage(
426 ferrotorch_core::storage::TensorStorage::cpu(out),
427 shape.to_vec(),
428 false,
429 )
430 }
431}
432
433impl<T: Float> Module<T> for ChannelShuffle {
434 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
435 ChannelShuffle::forward(self, input)
436 }
437
438 fn parameters(&self) -> Vec<&Parameter<T>> {
439 vec![]
440 }
441 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
442 vec![]
443 }
444 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
445 vec![]
446 }
447 fn train(&mut self) {
448 self.training = true;
449 }
450 fn eval(&mut self) {
451 self.training = false;
452 }
453 fn is_training(&self) -> bool {
454 self.training
455 }
456}
457
458#[derive(Debug, Clone)]
468pub struct CosineSimilarity {
469 pub dim: usize,
471 pub eps: f64,
473}
474
475impl CosineSimilarity {
476 pub fn new(dim: usize, eps: f64) -> Self {
477 Self { dim, eps }
478 }
479
480 pub fn forward<T: Float>(&self, x1: &Tensor<T>, x2: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
481 if x1.shape() != x2.shape() {
482 return Err(FerrotorchError::ShapeMismatch {
483 message: format!(
484 "CosineSimilarity: shapes must match, got {:?} and {:?}",
485 x1.shape(),
486 x2.shape()
487 ),
488 });
489 }
490 if x1.is_cuda() || x2.is_cuda() {
491 return Err(FerrotorchError::NotImplementedOnCuda {
492 op: "CosineSimilarity",
493 });
494 }
495
496 let shape = x1.shape();
497 if self.dim >= shape.len() {
498 return Err(FerrotorchError::InvalidArgument {
499 message: format!(
500 "CosineSimilarity: dim {} out of range for shape {:?}",
501 self.dim, shape
502 ),
503 });
504 }
505
506 let d1 = x1.data()?;
507 let d2 = x2.data()?;
508 let dim_size = shape[self.dim];
509 let outer: usize = shape[..self.dim].iter().product();
510 let inner: usize = shape[self.dim + 1..].iter().product();
511 let eps_t = T::from(self.eps).unwrap();
512
513 let out_numel = outer * inner;
514 let mut result = Vec::with_capacity(out_numel);
515
516 for o in 0..outer {
517 for i in 0..inner {
518 let mut dot = <T as num_traits::Zero>::zero();
519 let mut n1 = <T as num_traits::Zero>::zero();
520 let mut n2 = <T as num_traits::Zero>::zero();
521 for d in 0..dim_size {
522 let idx = o * dim_size * inner + d * inner + i;
523 dot += d1[idx] * d2[idx];
524 n1 += d1[idx] * d1[idx];
525 n2 += d2[idx] * d2[idx];
526 }
527 let denom = (n1.sqrt() * n2.sqrt()).max(eps_t);
528 result.push(dot / denom);
529 }
530 }
531
532 let mut out_shape = shape.to_vec();
533 out_shape.remove(self.dim);
534 if out_shape.is_empty() {
535 out_shape.push(1);
536 }
537 Tensor::from_storage(
538 ferrotorch_core::storage::TensorStorage::cpu(result),
539 out_shape,
540 false,
541 )
542 }
543}
544
545impl Default for CosineSimilarity {
546 fn default() -> Self {
547 Self::new(1, 1e-8)
548 }
549}
550
551#[derive(Debug, Clone)]
561pub struct PairwiseDistance {
562 pub p: f64,
564 pub eps: f64,
566 pub keepdim: bool,
568}
569
570impl PairwiseDistance {
571 pub fn new(p: f64, eps: f64, keepdim: bool) -> Self {
572 Self { p, eps, keepdim }
573 }
574
575 pub fn forward<T: Float>(&self, x1: &Tensor<T>, x2: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
576 if x1.shape() != x2.shape() {
577 return Err(FerrotorchError::ShapeMismatch {
578 message: format!(
579 "PairwiseDistance: shapes must match, got {:?} and {:?}",
580 x1.shape(),
581 x2.shape()
582 ),
583 });
584 }
585 if x1.is_cuda() || x2.is_cuda() {
586 return Err(FerrotorchError::NotImplementedOnCuda {
587 op: "PairwiseDistance",
588 });
589 }
590
591 let shape = x1.shape();
592 let ndim = shape.len();
593 if ndim == 0 {
594 return Err(FerrotorchError::InvalidArgument {
595 message: "PairwiseDistance: input must have at least 1 dimension".into(),
596 });
597 }
598
599 let d1 = x1.data()?;
600 let d2 = x2.data()?;
601 let last_dim = shape[ndim - 1];
602 let outer: usize = d1.len() / last_dim;
603 let p_t = T::from(self.p).unwrap();
604 let inv_p = T::from(1.0 / self.p).unwrap();
605 let eps_t = T::from(self.eps).unwrap();
606
607 let mut result = Vec::with_capacity(outer);
608 for o in 0..outer {
609 let mut norm = <T as num_traits::Zero>::zero();
610 for i in 0..last_dim {
611 let diff = d1[o * last_dim + i] - d2[o * last_dim + i];
612 let abs_diff = if diff < <T as num_traits::Zero>::zero() {
613 <T as num_traits::Zero>::zero() - diff
614 } else {
615 diff
616 };
617 norm += (abs_diff + eps_t).powf(p_t);
618 }
619 result.push(norm.powf(inv_p));
620 }
621
622 let mut out_shape: Vec<usize> = shape[..ndim - 1].to_vec();
623 if self.keepdim {
624 out_shape.push(1);
625 }
626 if out_shape.is_empty() {
627 out_shape.push(1);
628 }
629 Tensor::from_storage(
630 ferrotorch_core::storage::TensorStorage::cpu(result),
631 out_shape,
632 false,
633 )
634 }
635}
636
637impl Default for PairwiseDistance {
638 fn default() -> Self {
639 Self::new(2.0, 1e-6, false)
640 }
641}
642
// Unit tests for the shape/utility modules defined in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::autograd::graph::backward;
    use ferrotorch_core::storage::TensorStorage;

    /// Builds a CPU `f64` tensor holding `data` with the given `shape`.
    fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    // ---- Identity ----

    #[test]
    fn test_identity_forward() {
        let id = Identity::new();
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let output: Tensor<f64> = id.forward(&input).unwrap();
        assert_eq!(output.shape(), input.shape());
        assert_eq!(output.data_vec().unwrap(), input.data_vec().unwrap());
    }

    #[test]
    fn test_identity_no_parameters() {
        let id = Identity::new();
        assert!(Module::<f64>::parameters(&id).is_empty());
        assert!(Module::<f64>::named_parameters(&id).is_empty());
    }

    #[test]
    fn test_identity_preserves_grad() {
        let id = Identity::new();
        let input = leaf(&[1.0, 2.0, 3.0], &[3], true);
        let output: Tensor<f64> = id.forward(&input).unwrap();
        assert!(output.requires_grad());
    }

    #[test]
    fn test_identity_train_eval() {
        let mut id = Identity::new();
        assert!(Module::<f64>::is_training(&id));
        Module::<f64>::eval(&mut id);
        assert!(!Module::<f64>::is_training(&id));
        Module::<f64>::train(&mut id);
        assert!(Module::<f64>::is_training(&id));
    }

    #[test]
    fn test_identity_empty_tensor() {
        let id = Identity::new();
        let input = leaf(&[], &[0], false);
        let output: Tensor<f64> = id.forward(&input).unwrap();
        assert_eq!(output.shape(), &[0]);
        assert_eq!(output.numel(), 0);
    }

    #[test]
    fn test_identity_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Identity>();
    }

    // ---- Flatten ----

    #[test]
    fn test_flatten_default() {
        let flatten = Flatten::default();
        let input = leaf(
            &(0..120).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4, 5],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 60]);
    }

    #[test]
    fn test_flatten_specific_range() {
        let flatten = Flatten::new(2, 3);
        let input = leaf(
            &(0..120).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4, 5],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 20]);
    }

    #[test]
    fn test_flatten_all_dims() {
        let flatten = Flatten::new(0, -1);
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[24]);
    }

    #[test]
    fn test_flatten_noop_single_dim() {
        let flatten = Flatten::new(1, 1);
        let input = leaf(
            &(0..12).map(|i| i as f64).collect::<Vec<_>>(),
            &[3, 4],
            false,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_flatten_1d_input() {
        let flatten = Flatten::new(0, -1);
        let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3]);
    }

    #[test]
    fn test_flatten_0d_error() {
        let flatten = Flatten::new(0, -1);
        let input = leaf(&[42.0], &[], false);
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_start_dim_out_of_range() {
        let flatten = Flatten::new(5, -1);
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_end_dim_out_of_range() {
        let flatten = Flatten::new(0, 10);
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_start_gt_end_error() {
        let flatten = Flatten::new(2, 1);
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4],
            false,
        );
        assert!(Module::<f64>::forward(&flatten, &input).is_err());
    }

    #[test]
    fn test_flatten_preserves_data() {
        let flatten = Flatten::default();
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let input = leaf(&data, &[2, 3, 4], false);
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.data_vec().unwrap(), data);
    }

    #[test]
    fn test_flatten_backward() {
        use ferrotorch_core::tensor::GradFn;
        use std::sync::Arc;

        // Stand-in for a `sum()` op: the gradient of a sum with respect to
        // each summed element is 1, so `backward` returns ones shaped like
        // the summed tensor and ignores the incoming gradient.
        #[derive(Debug)]
        struct SumBackwardHelper {
            input: Tensor<f64>,
        }

        impl GradFn<f64> for SumBackwardHelper {
            fn backward(
                &self,
                _grad_output: &Tensor<f64>,
            ) -> FerrotorchResult<Vec<Option<Tensor<f64>>>> {
                let ones_data = vec![1.0f64; self.input.numel()];
                let ones = Tensor::from_storage(
                    TensorStorage::cpu(ones_data),
                    self.input.shape().to_vec(),
                    false,
                )?;
                Ok(vec![Some(ones)])
            }

            fn inputs(&self) -> Vec<&Tensor<f64>> {
                vec![&self.input]
            }

            fn name(&self) -> &'static str {
                "SumBackwardHelper"
            }
        }

        let flatten = Flatten::default();
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 3, 4],
            true,
        );
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 12]);
        assert!(output.requires_grad());

        // Build a scalar "loss" = sum(output) wired to the helper grad fn.
        let out_data = output.data().unwrap();
        let total: f64 = out_data.iter().sum();
        let sum_gf = Arc::new(SumBackwardHelper {
            input: output.clone(),
        });
        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
        backward(&loss).unwrap();

        // The gradient must flow back through the reshape in the original shape.
        let grad = input.grad().unwrap().unwrap();
        assert_eq!(grad.shape(), &[2, 3, 4]);
        for &v in grad.data().unwrap().iter() {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_flatten_no_parameters() {
        let flatten = Flatten::default();
        assert!(Module::<f64>::parameters(&flatten).is_empty());
        assert!(Module::<f64>::named_parameters(&flatten).is_empty());
    }

    #[test]
    fn test_flatten_zero_size_dim() {
        let flatten = Flatten::default();
        let input = leaf(&[], &[2, 0, 4], false);
        let output: Tensor<f64> = flatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 0]);
    }

    #[test]
    fn test_flatten_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Flatten>();
    }

    // ---- Unflatten ----

    #[test]
    fn test_unflatten_splits_dim() {
        let unflatten = Unflatten::new(1, vec![3, 4]);
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let input = leaf(&data, &[2, 12], false);
        let output: Tensor<f64> = unflatten.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 4]);
        assert_eq!(output.data_vec().unwrap(), data);
    }

    #[test]
    fn test_unflatten_size_mismatch_error() {
        let unflatten = Unflatten::new(1, vec![5, 2]);
        let input = leaf(
            &(0..24).map(|i| i as f64).collect::<Vec<_>>(),
            &[2, 12],
            false,
        );
        assert!(unflatten.forward(&input).is_err());
    }

    #[test]
    fn test_unflatten_dim_out_of_range() {
        let unflatten = Unflatten::new(2, vec![1]);
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        assert!(unflatten.forward(&input).is_err());
    }

    // ---- ChannelShuffle ----

    #[test]
    fn test_channel_shuffle_permutes_channels() {
        // 6 channels in 3 groups: grid [[0, 1], [2, 3], [4, 5]] transposed.
        let shuffle = ChannelShuffle::new(3);
        let input = leaf(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[1, 6], false);
        let output: Tensor<f64> = shuffle.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 6]);
        assert_eq!(
            output.data_vec().unwrap(),
            vec![0.0, 2.0, 4.0, 1.0, 3.0, 5.0]
        );
    }

    #[test]
    fn test_channel_shuffle_moves_whole_spatial_rows() {
        // 4 channels of length 2 in 2 groups: channel order becomes [0, 2, 1, 3].
        let shuffle = ChannelShuffle::new(2);
        let input = leaf(
            &[0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5],
            &[1, 4, 2],
            false,
        );
        let output: Tensor<f64> = shuffle.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 4, 2]);
        assert_eq!(
            output.data_vec().unwrap(),
            vec![0.0, 0.5, 2.0, 2.5, 1.0, 1.5, 3.0, 3.5]
        );
    }

    #[test]
    fn test_channel_shuffle_indivisible_channels_error() {
        let shuffle = ChannelShuffle::new(3);
        let input = leaf(&[0.0, 1.0, 2.0, 3.0], &[1, 4], false);
        assert!(shuffle.forward(&input).is_err());
    }

    #[test]
    fn test_channel_shuffle_1d_error() {
        let shuffle = ChannelShuffle::new(1);
        let input = leaf(&[0.0, 1.0], &[2], false);
        assert!(shuffle.forward(&input).is_err());
    }

    // ---- CosineSimilarity ----

    #[test]
    fn test_cosine_similarity_rows() {
        let cos = CosineSimilarity::default();
        let x1 = leaf(&[1.0, 0.0, 1.0, 1.0], &[2, 2], false);
        let x2 = leaf(&[1.0, 0.0, -1.0, -1.0], &[2, 2], false);
        let output: Tensor<f64> = cos.forward(&x1, &x2).unwrap();
        assert_eq!(output.shape(), &[2]);
        let values = output.data_vec().unwrap();
        assert!((values[0] - 1.0).abs() < 1e-12);
        assert!((values[1] + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cosine_similarity_shape_mismatch_error() {
        let cos = CosineSimilarity::default();
        let x1 = leaf(&[1.0, 0.0, 1.0, 1.0], &[2, 2], false);
        let x2 = leaf(&[1.0, 0.0, 1.0, 1.0], &[4, 1], false);
        assert!(cos.forward(&x1, &x2).is_err());
    }

    // ---- PairwiseDistance ----

    #[test]
    fn test_pairwise_distance_euclidean() {
        let dist = PairwiseDistance::new(2.0, 0.0, false);
        let x1 = leaf(&[0.0, 0.0, 1.0, 1.0], &[2, 2], false);
        let x2 = leaf(&[3.0, 4.0, 1.0, 1.0], &[2, 2], false);
        let output: Tensor<f64> = dist.forward(&x1, &x2).unwrap();
        assert_eq!(output.shape(), &[2]);
        let values = output.data_vec().unwrap();
        assert!((values[0] - 5.0).abs() < 1e-12);
        assert!(values[1].abs() < 1e-12);
    }

    #[test]
    fn test_pairwise_distance_keepdim_manhattan() {
        let dist = PairwiseDistance::new(1.0, 0.0, true);
        let x1 = leaf(&[0.0, 0.0], &[1, 2], false);
        let x2 = leaf(&[3.0, -4.0], &[1, 2], false);
        let output: Tensor<f64> = dist.forward(&x1, &x2).unwrap();
        assert_eq!(output.shape(), &[1, 1]);
        assert!((output.data_vec().unwrap()[0] - 7.0).abs() < 1e-12);
    }
}