1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
/// A combination of several tensors that all have the same shape, viewed as a single tensor
/// with one extra dimension whose length is the number of sources.
///
/// The sources are held either as an array `[S; N]` or as a tuple of 2, 3 or 4 tensors, each
/// with its own `from` constructor. The index along the extra dimension selects which source
/// an element is read from; the remaining indexes are passed on to that source unchanged.
#[derive(Clone, Debug)]
pub struct TensorStack<T, S, const D: usize> {
    // The tensors being stacked (an array or a tuple of sources).
    sources: S,
    _type: PhantomData<T>,
    // (position the extra dimension is inserted at in the source shape, name of the extra
    // dimension)
    along: (usize, Dimension),
}
61
62fn validate_shapes_equal<const D: usize, I>(mut shapes: I)
63where
64 I: Iterator<Item = [(Dimension, usize); D]>,
65{
66 let first_shape = shapes.next().unwrap();
69 for (i, shape) in shapes.enumerate() {
70 if shape != first_shape {
71 panic!(
72 "The shapes of each tensor in the sources to stack along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
73 i + 1,
74 shape,
75 first_shape
76 );
77 }
78 }
79}
80
81impl<T, S, const D: usize, const N: usize> TensorStack<T, [S; N], D>
82where
83 S: TensorRef<T, D>,
84{
85 #[track_caller]
100 pub fn from(sources: [S; N], along: (usize, Dimension)) -> Self {
101 if N == 0 {
102 panic!("No sources provided");
103 }
104 if along.0 > D {
105 panic!(
106 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
107 along
108 );
109 }
110 let shape = sources[0].view_shape();
111 if dimensions::contains(&shape, along.1) {
112 panic!(
113 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
114 along, shape
115 );
116 }
117 validate_shapes_equal(sources.iter().map(|tensor| tensor.view_shape()));
118 Self {
119 sources,
120 along,
121 _type: PhantomData,
122 }
123 }
124
125 pub fn sources(self) -> [S; N] {
129 self.sources
130 }
131
132 pub fn sources_ref(&self) -> &[S; N] {
142 &self.sources
143 }
144
145 fn source_view_shape(&self) -> [(Dimension, usize); D] {
149 self.sources[0].view_shape()
150 }
151
152 fn number_of_sources() -> usize {
153 N
154 }
155}
156
157impl<T, S1, S2, const D: usize> TensorStack<T, (S1, S2), D>
158where
159 S1: TensorRef<T, D>,
160 S2: TensorRef<T, D>,
161{
162 #[track_caller]
173 pub fn from(sources: (S1, S2), along: (usize, Dimension)) -> Self {
174 if along.0 > D {
175 panic!(
176 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
177 along
178 );
179 }
180 let shape = sources.0.view_shape();
181 if dimensions::contains(&shape, along.1) {
182 panic!(
183 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
184 along, shape
185 );
186 }
187 validate_shapes_equal([sources.0.view_shape(), sources.1.view_shape()].into_iter());
188 Self {
189 sources,
190 along,
191 _type: PhantomData,
192 }
193 }
194
195 pub fn sources(self) -> (S1, S2) {
199 self.sources
200 }
201
202 pub fn sources_ref(&self) -> &(S1, S2) {
212 &self.sources
213 }
214
215 fn source_view_shape(&self) -> [(Dimension, usize); D] {
219 self.sources.0.view_shape()
220 }
221
222 fn number_of_sources() -> usize {
223 2
224 }
225}
226
227impl<T, S1, S2, S3, const D: usize> TensorStack<T, (S1, S2, S3), D>
228where
229 S1: TensorRef<T, D>,
230 S2: TensorRef<T, D>,
231 S3: TensorRef<T, D>,
232{
233 #[track_caller]
244 pub fn from(sources: (S1, S2, S3), along: (usize, Dimension)) -> Self {
245 if along.0 > D {
246 panic!(
247 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
248 along
249 );
250 }
251 let shape = sources.0.view_shape();
252 if dimensions::contains(&shape, along.1) {
253 panic!(
254 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
255 along, shape
256 );
257 }
258 validate_shapes_equal(
259 [
260 sources.0.view_shape(),
261 sources.1.view_shape(),
262 sources.2.view_shape(),
263 ]
264 .into_iter(),
265 );
266 Self {
267 sources,
268 along,
269 _type: PhantomData,
270 }
271 }
272
273 pub fn sources(self) -> (S1, S2, S3) {
277 self.sources
278 }
279
280 pub fn sources_ref(&self) -> &(S1, S2, S3) {
290 &self.sources
291 }
292
293 fn source_view_shape(&self) -> [(Dimension, usize); D] {
297 self.sources.0.view_shape()
298 }
299
300 fn number_of_sources() -> usize {
301 3
302 }
303}
304
305impl<T, S1, S2, S3, S4, const D: usize> TensorStack<T, (S1, S2, S3, S4), D>
306where
307 S1: TensorRef<T, D>,
308 S2: TensorRef<T, D>,
309 S3: TensorRef<T, D>,
310 S4: TensorRef<T, D>,
311{
312 #[track_caller]
323 pub fn from(sources: (S1, S2, S3, S4), along: (usize, Dimension)) -> Self {
324 if along.0 > D {
325 panic!(
326 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
327 along
328 );
329 }
330 let shape = sources.0.view_shape();
331 if dimensions::contains(&shape, along.1) {
332 panic!(
333 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
334 along, shape
335 );
336 }
337 validate_shapes_equal(
338 [
339 sources.0.view_shape(),
340 sources.1.view_shape(),
341 sources.2.view_shape(),
342 sources.3.view_shape(),
343 ]
344 .into_iter(),
345 );
346 Self {
347 sources,
348 along,
349 _type: PhantomData,
350 }
351 }
352
353 pub fn sources(self) -> (S1, S2, S3, S4) {
357 self.sources
358 }
359
360 pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
370 &self.sources
371 }
372
373 fn source_view_shape(&self) -> [(Dimension, usize); D] {
377 self.sources.0.view_shape()
378 }
379
380 fn number_of_sources() -> usize {
381 4
382 }
383}
384
/// Generates, for one source dimensionality `$d`, the `TensorRef` and `TensorMut`
/// implementations that let a `TensorStack` of `$d` dimensional sources be used as a single
/// `$d + 1` dimensional tensor. Implementations are generated for array sources `[S; N]` and
/// for tuples of 2, 3 and 4 sources.
///
/// Each expansion is wrapped in its own module `$mod`, so the private helpers
/// `view_shape_impl` and `indexing` of different expansions do not collide.
macro_rules! tensor_stack_ref_impl {
    (unsafe impl TensorRef for TensorStack $d:literal $mod:ident) => {
        mod $mod {
            use crate::tensors::views::{TensorRef, TensorMut, DataLayout, TensorStack};
            use crate::tensors::Dimension;

            // Builds the shape of the stacked view. It is the source `shape` with the extra
            // dimension `(along.1, sources)` inserted at position `along.0`, and the source
            // dimensions keep their relative order. The `("", 0)` entries are only placeholders;
            // the loop overwrites every position exactly once.
            fn view_shape_impl(
                shape: [(Dimension, usize); $d],
                along: (usize, Dimension),
                sources: usize,
            ) -> [(Dimension, usize); $d + 1] {
                let mut extra_shape = [("", 0); $d + 1];
                // Next unread position in the source shape.
                let mut i = 0;
                for (d, dimension) in extra_shape.iter_mut().enumerate() {
                    match d == along.0 {
                        true => {
                            *dimension = (along.1, sources);
                        },
                        false => {
                            *dimension = shape[i];
                            i += 1;
                        }
                    }
                }
                extra_shape
            }

            // Splits a `$d + 1` dimensional index into:
            // - the index along the extra dimension, which selects the source, and
            // - the remaining `$d` indexes into that source, kept in their original order.
            // `TensorStack::from` ensures `along.0 <= $d`, so `indexes[along.0]` is in bounds.
            fn indexing(
                indexes: [usize; $d + 1],
                along: (usize, Dimension)
            ) -> (usize, [usize; $d]) {
                let mut indexes_into_source = [0; $d];
                // Next position to fill in `indexes_into_source`.
                let mut i = 0;
                for (d, &index) in indexes.iter().enumerate() {
                    if d != along.0 {
                        indexes_into_source[i] = index;
                        i += 1;
                    }
                }
                (indexes[along.0], indexes_into_source)
            }

            // SAFETY: `view_shape` reports the first source's shape with the extra dimension
            // inserted, and its length equals the number of sources. `TensorStack::from`
            // checked that every source shares that shape. Every element reference is
            // delegated to exactly one source: the one chosen by the index along the extra
            // dimension, receiving the remaining indexes unchanged.
            // NOTE(review): this relies on the sources' shapes not changing after construction;
            // confirm against the `TensorRef` safety contract.
            unsafe impl<T, S, const N: usize> TensorRef<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorRef<T, $d>
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // `get` returns None when the extra-dimension index is past the last source.
                    self.sources.get(source)?.get_reference(indexes)
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so the
                    // source index is below N and the remaining indexes are within the shared
                    // source shape. An out of range source index would panic on `unwrap` rather
                    // than reach the source's unchecked access.
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get(source).unwrap().get_reference_unchecked(indexes)
                }}

                // Elements are spread across several independent sources, so no linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            // SAFETY: mutable access is dispatched exactly as for the `TensorRef`
            // implementation of the same source type. Each index maps to one element of one
            // source, and the source is selected by the index along the extra dimension.
            unsafe impl<T, S, const N: usize> TensorMut<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorMut<T, $d>
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_mut(source)?.get_reference_mut(indexes)
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`; see
                    // `get_reference_unchecked` for how the index is split.
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_mut(source).unwrap().get_reference_unchecked_mut(indexes)
                }}
            }

            // SAFETY: `view_shape` reports the first source's shape with an extra dimension of
            // length 2, and `TensorStack::from` checked that both sources share that shape.
            // Index 0/1 along the extra dimension selects the first/second source, which
            // receives the remaining indexes unchanged.
            unsafe impl<T, S1, S2> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        // Extra-dimension index is past the last source.
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so
                    // `source` is 0 or 1 and the remaining indexes are within the shared shape.
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                // Elements are spread across several independent sources, so no linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            // SAFETY: mutable access is dispatched exactly as for the `TensorRef`
            // implementation of the same source type. Each index maps to one element of one
            // source.
            unsafe impl<T, S1, S2> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        // Extra-dimension index is past the last source.
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so
                    // `source` is 0 or 1 and the remaining indexes are within the shared shape.
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            // SAFETY: `view_shape` reports the first source's shape with an extra dimension of
            // length 3, and `TensorStack::from` checked that all sources share that shape.
            // Index i along the extra dimension selects tuple element i, which receives the
            // remaining indexes unchanged.
            unsafe impl<T, S1, S2, S3> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        // Extra-dimension index is past the last source.
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so
                    // `source` is 0, 1 or 2 and the remaining indexes are within the shared shape.
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                // Elements are spread across several independent sources, so no linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            // SAFETY: mutable access is dispatched exactly as for the `TensorRef`
            // implementation of the same source type. Each index maps to one element of one
            // source.
            unsafe impl<T, S1, S2, S3> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        // Extra-dimension index is past the last source.
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so
                    // `source` is 0, 1 or 2 and the remaining indexes are within the shared shape.
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            // SAFETY: `view_shape` reports the first source's shape with an extra dimension of
            // length 4, and `TensorStack::from` checked that all sources share that shape.
            // Index i along the extra dimension selects tuple element i, which receives the
            // remaining indexes unchanged.
            unsafe impl<T, S1, S2, S3, S4> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
                S4: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        3 => self.sources.3.get_reference(indexes),
                        // Extra-dimension index is past the last source.
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so
                    // `source` is 0 to 3 and the remaining indexes are within the shared shape.
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        3 => self.sources.3.get_reference_unchecked(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                // Elements are spread across several independent sources, so no linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            // SAFETY: mutable access is dispatched exactly as for the `TensorRef`
            // implementation of the same source type. Each index maps to one element of one
            // source.
            unsafe impl<T, S1, S2, S3, S4> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
                S4: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        3 => self.sources.3.get_reference_mut(indexes),
                        // Extra-dimension index is past the last source.
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so
                    // `source` is 0 to 3 and the remaining indexes are within the shared shape.
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        3 => self.sources.3.get_reference_unchecked_mut(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }
        }
    }
}

// Stacking implementations for sources of 0 to 5 dimensions (producing 1 to 6 dimensional
// views), each in its own module.
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 0 zero);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 1 one);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 2 two);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 3 three);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 4 four);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 5 five);
690
// Checks TensorStack for tuple and array sources, for boxed `TensorMut` and `TensorRef` trait
// objects, and for the extra dimension inserted at the front, middle and end of the shape.
#[test]
fn test_stacking() {
    use crate::tensors::Tensor;
    use crate::tensors::views::{TensorMut, TensorView};
    let vector1 = Tensor::from([("a", 3)], vec![9, 5, 2]);
    let vector2 = Tensor::from([("a", 3)], vec![3, 6, 0]);
    let vector3 = Tensor::from([("a", 3)], vec![8, 7, 1]);
    // Inserting "b" after "a" makes each vector a column of the resulting matrix.
    let matrix = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&vector1, &vector2, &vector3),
        (1, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        matrix,
        Tensor::from([("a", 3), ("b", 3)], vec![
            9, 3, 8,
            5, 6, 7,
            2, 0, 1,
        ])
    );
    // Inserting "b" before "a" instead makes each vector a row.
    let different_matrix = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&vector1, &vector2, &vector3),
        (0, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        different_matrix,
        Tensor::from([("b", 3), ("a", 3)], vec![
            9, 5, 2,
            3, 6, 0,
            8, 7, 1,
        ])
    );
    // Box both matrices (via an identity `map`) as `TensorMut` trait objects so they can share
    // an array. `rename_view` relabels different_matrix's ("b", "a") dimensions as ("a", "b"),
    // giving both sources the same shape.
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    // New dimension "c" appended last: element [a, b, c] comes from source c at [a, b].
    let tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (2, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        tensor.eq(
            &Tensor::from([("a", 3), ("b", 3), ("c", 2)], vec![
                9, 9,
                3, 5,
                8, 2,

                5, 3,
                6, 6,
                7, 0,

                2, 8,
                0, 7,
                1, 1
            ])
        ),
    );
    // The same sources with "c" inserted in the middle of the shape.
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    let different_tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (1, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        different_tensor.eq(
            &Tensor::from([("a", 3), ("c", 2), ("b", 3)], vec![
                9, 3, 8,
                9, 5, 2,

                5, 6, 7,
                3, 6, 0,

                2, 0, 1,
                8, 7, 1
            ])
        ),
    );
    // Read-only `TensorRef` trait objects work too; "c" inserted first stacks whole matrices.
    let matrix_erased: Box<dyn TensorRef<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorRef<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    let another_tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (0, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        another_tensor.eq(
            &Tensor::from([("c", 2), ("a", 3), ("b", 3)], vec![
                9, 3, 8,
                5, 6, 7,
                2, 0, 1,

                9, 5, 2,
                3, 6, 0,
                8, 7, 1,
            ])
        ),
    );
}
793
/// A combination of several tensors joined end to end along one of their existing dimensions,
/// viewed as a single tensor.
///
/// All sources must have the same shape, except that the length of the dimension being
/// chained along may differ. In the resulting view, that dimension's length is the sum of the
/// sources' lengths, and the sources appear in order along it. Sources are held either as an
/// array `[S; N]` or as a tuple of 2, 3 or 4 tensors.
#[derive(Clone, Debug)]
pub struct TensorChain<T, S, const D: usize> {
    // The tensors being chained (an array or a tuple of sources).
    sources: S,
    _type: PhantomData<T>,
    // Position in the source shape of the dimension being chained along.
    along: usize,
}
848
849fn validate_shapes_similar<const D: usize, I>(mut shapes: I, along: usize)
850where
851 I: Iterator<Item = [(Dimension, usize); D]>,
852{
853 let first_shape = shapes.next().unwrap();
856 for (i, shape) in shapes.enumerate() {
857 for d in 0..D {
858 let similar = if d == along {
859 shape[d].0 == first_shape[d].0
861 } else {
862 shape[d] == first_shape[d]
863 };
864 if !similar {
865 panic!(
866 "The shapes of each tensor in the sources to chain along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
867 i + 1,
868 shape,
869 first_shape
870 );
871 }
872 }
873 }
874}
875
876impl<T, S, const D: usize, const N: usize> TensorChain<T, [S; N], D>
877where
878 S: TensorRef<T, D>,
879{
880 #[track_caller]
894 pub fn from(sources: [S; N], along: Dimension) -> Self {
895 if N == 0 {
896 panic!("No sources provided");
897 }
898 if D == 0 {
899 panic!("Can't chain along 0 dimensional tensors");
900 }
901 let shape = sources[0].view_shape();
902 let along = match dimensions::position_of(&shape, along) {
903 Some(d) => d,
904 None => panic!(
905 "The dimension {:?} is not in the source's shapes: {:?}",
906 along, shape
907 ),
908 };
909 validate_shapes_similar(sources.iter().map(|tensor| tensor.view_shape()), along);
910 Self {
911 sources,
912 along,
913 _type: PhantomData,
914 }
915 }
916
917 pub fn sources(self) -> [S; N] {
921 self.sources
922 }
923
924 pub fn sources_ref(&self) -> &[S; N] {
934 &self.sources
935 }
936}
937
938impl<T, S1, S2, const D: usize> TensorChain<T, (S1, S2), D>
939where
940 S1: TensorRef<T, D>,
941 S2: TensorRef<T, D>,
942{
943 #[track_caller]
956 pub fn from(sources: (S1, S2), along: Dimension) -> Self {
957 if D == 0 {
958 panic!("Can't chain along 0 dimensional tensors");
959 }
960 let shape = sources.0.view_shape();
961 let along = match dimensions::position_of(&shape, along) {
962 Some(d) => d,
963 None => panic!(
964 "The dimension {:?} is not in the source's shapes: {:?}",
965 along, shape
966 ),
967 };
968 validate_shapes_similar(
969 [sources.0.view_shape(), sources.1.view_shape()].into_iter(),
970 along,
971 );
972 Self {
973 sources,
974 along,
975 _type: PhantomData,
976 }
977 }
978
979 pub fn sources(self) -> (S1, S2) {
983 self.sources
984 }
985
986 pub fn sources_ref(&self) -> &(S1, S2) {
996 &self.sources
997 }
998}
999
1000impl<T, S1, S2, S3, const D: usize> TensorChain<T, (S1, S2, S3), D>
1001where
1002 S1: TensorRef<T, D>,
1003 S2: TensorRef<T, D>,
1004 S3: TensorRef<T, D>,
1005{
1006 #[track_caller]
1019 pub fn from(sources: (S1, S2, S3), along: Dimension) -> Self {
1020 if D == 0 {
1021 panic!("Can't chain along 0 dimensional tensors");
1022 }
1023 let shape = sources.0.view_shape();
1024 let along = match dimensions::position_of(&shape, along) {
1025 Some(d) => d,
1026 None => panic!(
1027 "The dimension {:?} is not in the source's shapes: {:?}",
1028 along, shape
1029 ),
1030 };
1031 validate_shapes_similar(
1032 [
1033 sources.0.view_shape(),
1034 sources.1.view_shape(),
1035 sources.2.view_shape(),
1036 ]
1037 .into_iter(),
1038 along,
1039 );
1040 Self {
1041 sources,
1042 along,
1043 _type: PhantomData,
1044 }
1045 }
1046
1047 pub fn sources(self) -> (S1, S2, S3) {
1051 self.sources
1052 }
1053
1054 pub fn sources_ref(&self) -> &(S1, S2, S3) {
1064 &self.sources
1065 }
1066}
1067
1068impl<T, S1, S2, S3, S4, const D: usize> TensorChain<T, (S1, S2, S3, S4), D>
1069where
1070 S1: TensorRef<T, D>,
1071 S2: TensorRef<T, D>,
1072 S3: TensorRef<T, D>,
1073 S4: TensorRef<T, D>,
1074{
1075 #[track_caller]
1088 pub fn from(sources: (S1, S2, S3, S4), along: Dimension) -> Self {
1089 if D == 0 {
1090 panic!("Can't chain along 0 dimensional tensors");
1091 }
1092 let shape = sources.0.view_shape();
1093 let along = match dimensions::position_of(&shape, along) {
1094 Some(d) => d,
1095 None => panic!(
1096 "The dimension {:?} is not in the source's shapes: {:?}",
1097 along, shape
1098 ),
1099 };
1100 validate_shapes_similar(
1101 [
1102 sources.0.view_shape(),
1103 sources.1.view_shape(),
1104 sources.2.view_shape(),
1105 sources.3.view_shape(),
1106 ]
1107 .into_iter(),
1108 along,
1109 );
1110 Self {
1111 sources,
1112 along,
1113 _type: PhantomData,
1114 }
1115 }
1116
1117 pub fn sources(self) -> (S1, S2, S3, S4) {
1121 self.sources
1122 }
1123
1124 pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
1134 &self.sources
1135 }
1136}
1137
1138fn view_shape_impl<I, const D: usize>(
1139 first_shape: [(Dimension, usize); D],
1140 shapes: I,
1141 along: usize,
1142) -> [(Dimension, usize); D]
1143where
1144 I: Iterator<Item = [(Dimension, usize); D]>,
1145{
1146 let mut shape = first_shape;
1147 shape[along].1 = shapes.into_iter().map(|shape| shape[along].1).sum();
1148 shape
1149}
1150
1151fn indexing<I, const D: usize>(
1152 indexes: [usize; D],
1153 shapes: I,
1154 along: usize,
1155) -> Option<(usize, [usize; D])>
1156where
1157 I: Iterator<Item = [(Dimension, usize); D]>,
1158{
1159 let mut shapes = shapes.enumerate();
1160 let mut i = indexes[along];
1164 loop {
1165 let (source, next_shape) = shapes.next()?;
1166 let length_along_chained_dimension = next_shape[along].1;
1167 if i < length_along_chained_dimension {
1168 #[allow(clippy::clone_on_copy)]
1169 let mut indexes = indexes.clone();
1170 indexes[along] = i;
1171 return Some((source, indexes));
1172 }
1173 i -= length_along_chained_dimension;
1174 }
1175}
1176
// SAFETY: `TensorChain::from` checked that every source shares the first source's shape,
// except for the length of the chained dimension. `view_shape` reports that shape with the
// chained length replaced by the sum of all sources' lengths. `indexing` maps an index along
// the chained dimension onto exactly one source (or `None` past the end) and passes every
// other index component through unchanged, so each element reference comes from one source.
// NOTE(review): this relies on the sources' shapes not changing after construction; confirm
// against the `TensorRef` safety contract.
unsafe impl<T, S, const D: usize, const N: usize> TensorRef<T, D> for TensorChain<T, [S; N], D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        // Translate the chained index into (source, index within that source). The first `?`
        // returns None for an index past the end of the last source.
        let (source, indexes) = indexing(
            indexes,
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )?;
        self.sources.get(source)?.get_reference(indexes)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // `from` rejects N == 0, so the first source always exists.
        view_shape_impl(
            self.sources[0].view_shape(),
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields a source and an index within that source's shape. An index past the end of
        // the chained dimension panics on the first `unwrap`. Other components are passed to
        // the source unchanged, so the caller's guarantee is still required for them.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                self.sources.iter().map(|s| s.view_shape()),
                self.along,
            )
            .unwrap();
            self.sources
                .get(source)
                .unwrap()
                .get_reference_unchecked(indexes)
        }
    }

    // Elements are spread across several independent sources, so no linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1220
// SAFETY: `indexing` maps each index of the chained view onto exactly one source and one
// index within it (sources were checked in `TensorChain::from` to agree on every dimension
// except the chained one). Each mutable reference is therefore to a single element of a
// single source.
unsafe impl<T, S, const D: usize, const N: usize> TensorMut<T, D> for TensorChain<T, [S; N], D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        // Returns None for an index past the end of the last source along the chained dimension.
        let (source, indexes) = indexing(
            indexes,
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )?;
        self.sources.get_mut(source)?.get_reference_mut(indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields a source and an index within that source's shape. An index past the end of
        // the chained dimension panics on the first `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                self.sources.iter().map(|s| s.view_shape()),
                self.along,
            )
            .unwrap();
            self.sources
                .get_mut(source)
                .unwrap()
                .get_reference_unchecked_mut(indexes)
        }
    }
}
1250
// SAFETY: `TensorChain::from` checked that both sources share the first source's shape,
// except for the length of the chained dimension. `view_shape` reports that shape with the
// chained length replaced by the sum of both lengths. `indexing` maps an index along the
// chained dimension onto exactly one source (or `None` past the end), so each element
// reference comes from one source.
// NOTE(review): this relies on the sources' shapes not changing after construction; confirm
// against the `TensorRef` safety contract.
unsafe impl<T, S1, S2, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            // `indexing` only enumerates two shapes here, so this arm is not reached.
            _ => None,
        }
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields source 0 or 1 and an index within that source's shape. An index past the end
        // of the chained dimension panics on `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
                self.along,
            )
            .unwrap();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    // Elements are spread across several independent sources, so no linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1301
// SAFETY: `indexing` maps each index of the chained view onto exactly one of the two sources
// and one index within it (sources were checked in `TensorChain::from` to agree on every
// dimension except the chained one). Each mutable reference is therefore to a single element
// of a single source.
unsafe impl<T, S1, S2, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            // `indexing` only enumerates two shapes here, so this arm is not reached.
            _ => None,
        }
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields source 0 or 1 and an index within that source's shape. An index past the end
        // of the chained dimension panics on `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
                self.along,
            )
            .unwrap();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1338
// SAFETY: `TensorChain::from` checked that all three sources share the first source's shape,
// except for the length of the chained dimension. `view_shape` reports that shape with the
// chained length replaced by the sum of all three lengths. `indexing` maps an index along the
// chained dimension onto exactly one source (or `None` past the end), so each element
// reference comes from one source.
// NOTE(review): this relies on the sources' shapes not changing after construction; confirm
// against the `TensorRef` safety contract.
unsafe impl<T, S1, S2, S3, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2, S3), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
    S3: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            2 => self.sources.2.get_reference(indexes),
            // `indexing` only enumerates three shapes here, so this arm is not reached.
            _ => None,
        }
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields source 0, 1 or 2 and an index within that source's shape. An index past the
        // end of the chained dimension panics on `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                2 => self.sources.2.get_reference_unchecked(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    // Elements are spread across several independent sources, so no linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1407
// SAFETY: `indexing` maps each index of the chained view onto exactly one of the three
// sources and one index within it (sources were checked in `TensorChain::from` to agree on
// every dimension except the chained one). Each mutable reference is therefore to a single
// element of a single source.
unsafe impl<T, S1, S2, S3, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2, S3), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
    S3: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            2 => self.sources.2.get_reference_mut(indexes),
            // `indexing` only enumerates three shapes here, so this arm is not reached.
            _ => None,
        }
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields source 0, 1 or 2 and an index within that source's shape. An index past the
        // end of the chained dimension panics on `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                2 => self.sources.2.get_reference_unchecked_mut(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1457
// SAFETY: `TensorChain::from` checked that all four sources share the first source's shape,
// except for the length of the chained dimension. `view_shape` reports that shape with the
// chained length replaced by the sum of all four lengths. `indexing` maps an index along the
// chained dimension onto exactly one source (or `None` past the end), so each element
// reference comes from one source.
// NOTE(review): this relies on the sources' shapes not changing after construction; confirm
// against the `TensorRef` safety contract.
unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorRef<T, D>
    for TensorChain<T, (S1, S2, S3, S4), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
    S3: TensorRef<T, D>,
    S4: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            2 => self.sources.2.get_reference(indexes),
            3 => self.sources.3.get_reference(indexes),
            // `indexing` only enumerates four shapes here, so this arm is not reached.
            _ => None,
        }
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields a source from 0 to 3 and an index within that source's shape. An index past
        // the end of the chained dimension panics on `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                    self.sources.3.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                2 => self.sources.2.get_reference_unchecked(indexes),
                3 => self.sources.3.get_reference_unchecked(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    // Elements are spread across several independent sources, so no linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1533
// SAFETY: `indexing` maps each index of the chained view onto exactly one of the four
// sources and one index within it (sources were checked in `TensorChain::from` to agree on
// every dimension except the chained one). Each mutable reference is therefore to a single
// element of a single source.
unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorMut<T, D>
    for TensorChain<T, (S1, S2, S3, S4), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
    S3: TensorMut<T, D>,
    S4: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            2 => self.sources.2.get_reference_mut(indexes),
            3 => self.sources.3.get_reference_mut(indexes),
            // `indexing` only enumerates four shapes here, so this arm is not reached.
            _ => None,
        }
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`, so `indexing`
        // yields a source from 0 to 3 and an index within that source's shape. An index past
        // the end of the chained dimension panics on `unwrap`.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                    self.sources.3.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                2 => self.sources.2.get_reference_unchecked_mut(indexes),
                3 => self.sources.3.get_reference_unchecked_mut(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1589
// Checks TensorChain for a tuple of borrowed tensors chained along the first dimension. It
// then checks an array of boxed `TensorMut` trait objects chained along the second dimension,
// where the sources have different lengths along the chained dimension.
#[test]
fn test_chaining() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    #[rustfmt::skip]
    let matrix1 = Tensor::from(
        [("a", 3), ("b", 2)],
        vec![
            9, 5,
            2, 1,
            3, 5
        ]
    );
    #[rustfmt::skip]
    let matrix2 = Tensor::from(
        [("a", 4), ("b", 2)],
        vec![
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ]
    );
    // Chaining along "a" appends matrix2's 4 rows after matrix1's 3 rows.
    let matrix = TensorView::from(TensorChain::<_, (_, _), 2>::from((&matrix1, &matrix2), "a"));
    #[rustfmt::skip]
    assert_eq!(
        matrix,
        Tensor::from([("a", 7), ("b", 2)], vec![
            9, 5,
            2, 1,
            3, 5,
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ])
    );
    // Chaining along "b" appends a single column of 0..7 to the right of the 7x2 matrix.
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix = Tensor::from([("a", 7), ("b", 1)], (0..7).collect());
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(different_matrix);
    let another_matrix = TensorView::from(TensorChain::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        "b",
    ));
    #[rustfmt::skip]
    assert!(
        another_matrix.eq(
            &Tensor::from([("a", 7), ("b", 3)], vec![
                9, 5, 0,
                2, 1, 1,
                3, 5, 2,
                0, 1, 3,
                8, 4, 4,
                1, 7, 5,
                6, 3, 6
            ])
        )
    );
}