1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
/// A combination of tensors that all have the same shape, stacked along a new dimension.
///
/// The `D` dimensional sources are presented as a single `D + 1` dimensional tensor.
/// `sources` holds the stacked tensors, either as an array `[S; N]` or as a tuple of 2 to 4
/// tensors.
///
/// `along` records two things about the extra dimension:
/// - `along.0` is the position at which it is inserted into the source shape.
/// - `along.1` is its name.
///
/// The length of the extra dimension is the number of sources. The element type `T` is
/// only tracked through `PhantomData`.
#[derive(Clone, Debug)]
pub struct TensorStack<T, S, const D: usize> {
    sources: S,
    _type: PhantomData<T>,
    along: (usize, Dimension),
}
64
65fn validate_shapes_equal<const D: usize, I>(mut shapes: I)
66where
67 I: Iterator<Item = [(Dimension, usize); D]>,
68{
69 let first_shape = shapes.next().unwrap();
72 for (i, shape) in shapes.enumerate() {
73 if shape != first_shape {
74 panic!(
75 "The shapes of each tensor in the sources to stack along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
76 i + 1,
77 shape,
78 first_shape
79 );
80 }
81 }
82}
83
84impl<T, S, const D: usize, const N: usize> TensorStack<T, [S; N], D>
85where
86 S: TensorRef<T, D>,
87{
88 #[track_caller]
103 pub fn from(sources: [S; N], along: (usize, Dimension)) -> Self {
104 if N == 0 {
105 panic!("No sources provided");
106 }
107 if along.0 > D {
108 panic!(
109 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
110 along
111 );
112 }
113 let shape = sources[0].view_shape();
114 if dimensions::contains(&shape, along.1) {
115 panic!(
116 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
117 along, shape
118 );
119 }
120 validate_shapes_equal(sources.iter().map(|tensor| tensor.view_shape()));
121 Self {
122 sources,
123 along,
124 _type: PhantomData,
125 }
126 }
127
128 pub fn sources(self) -> [S; N] {
132 self.sources
133 }
134
135 pub fn sources_ref(&self) -> &[S; N] {
145 &self.sources
146 }
147
148 fn source_view_shape(&self) -> [(Dimension, usize); D] {
152 self.sources[0].view_shape()
153 }
154
155 fn number_of_sources() -> usize {
156 N
157 }
158}
159
160impl<T, S1, S2, const D: usize> TensorStack<T, (S1, S2), D>
161where
162 S1: TensorRef<T, D>,
163 S2: TensorRef<T, D>,
164{
165 #[track_caller]
176 pub fn from(sources: (S1, S2), along: (usize, Dimension)) -> Self {
177 if along.0 > D {
178 panic!(
179 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
180 along
181 );
182 }
183 let shape = sources.0.view_shape();
184 if dimensions::contains(&shape, along.1) {
185 panic!(
186 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
187 along, shape
188 );
189 }
190 validate_shapes_equal([sources.0.view_shape(), sources.1.view_shape()].into_iter());
191 Self {
192 sources,
193 along,
194 _type: PhantomData,
195 }
196 }
197
198 pub fn sources(self) -> (S1, S2) {
202 self.sources
203 }
204
205 pub fn sources_ref(&self) -> &(S1, S2) {
215 &self.sources
216 }
217
218 fn source_view_shape(&self) -> [(Dimension, usize); D] {
222 self.sources.0.view_shape()
223 }
224
225 fn number_of_sources() -> usize {
226 2
227 }
228}
229
230impl<T, S1, S2, S3, const D: usize> TensorStack<T, (S1, S2, S3), D>
231where
232 S1: TensorRef<T, D>,
233 S2: TensorRef<T, D>,
234 S3: TensorRef<T, D>,
235{
236 #[track_caller]
247 pub fn from(sources: (S1, S2, S3), along: (usize, Dimension)) -> Self {
248 if along.0 > D {
249 panic!(
250 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
251 along
252 );
253 }
254 let shape = sources.0.view_shape();
255 if dimensions::contains(&shape, along.1) {
256 panic!(
257 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
258 along, shape
259 );
260 }
261 validate_shapes_equal(
262 [
263 sources.0.view_shape(),
264 sources.1.view_shape(),
265 sources.2.view_shape(),
266 ]
267 .into_iter(),
268 );
269 Self {
270 sources,
271 along,
272 _type: PhantomData,
273 }
274 }
275
276 pub fn sources(self) -> (S1, S2, S3) {
280 self.sources
281 }
282
283 pub fn sources_ref(&self) -> &(S1, S2, S3) {
293 &self.sources
294 }
295
296 fn source_view_shape(&self) -> [(Dimension, usize); D] {
300 self.sources.0.view_shape()
301 }
302
303 fn number_of_sources() -> usize {
304 3
305 }
306}
307
308impl<T, S1, S2, S3, S4, const D: usize> TensorStack<T, (S1, S2, S3, S4), D>
309where
310 S1: TensorRef<T, D>,
311 S2: TensorRef<T, D>,
312 S3: TensorRef<T, D>,
313 S4: TensorRef<T, D>,
314{
315 #[track_caller]
326 pub fn from(sources: (S1, S2, S3, S4), along: (usize, Dimension)) -> Self {
327 if along.0 > D {
328 panic!(
329 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
330 along
331 );
332 }
333 let shape = sources.0.view_shape();
334 if dimensions::contains(&shape, along.1) {
335 panic!(
336 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
337 along, shape
338 );
339 }
340 validate_shapes_equal(
341 [
342 sources.0.view_shape(),
343 sources.1.view_shape(),
344 sources.2.view_shape(),
345 sources.3.view_shape(),
346 ]
347 .into_iter(),
348 );
349 Self {
350 sources,
351 along,
352 _type: PhantomData,
353 }
354 }
355
356 pub fn sources(self) -> (S1, S2, S3, S4) {
360 self.sources
361 }
362
363 pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
373 &self.sources
374 }
375
376 fn source_view_shape(&self) -> [(Dimension, usize); D] {
380 self.sources.0.view_shape()
381 }
382
383 fn number_of_sources() -> usize {
384 4
385 }
386}
387
/// Implements `TensorRef` and `TensorMut` for every `TensorStack` variant whose sources
/// have `$d` dimensions.
///
/// The variants covered are an array of `N` sources and tuples of 2, 3 and 4 sources. Each
/// stacked view has `$d + 1` dimensions.
///
/// The dimensionality is taken as a literal so that `$d + 1` can appear where a constant is
/// expected. Presumably this is because arithmetic on a generic `const D` is not available
/// on stable Rust; confirm before replacing the macro with a generic impl.
///
/// Each expansion is wrapped in its own module `$mod`. This keeps the helper functions
/// `view_shape_impl` and `indexing` defined here from clashing between expansions, or with
/// the file-level `TensorChain` helpers of the same names.
macro_rules! tensor_stack_ref_impl {
    (unsafe impl TensorRef for TensorStack $d:literal $mod:ident) => {
        mod $mod {
            use crate::tensors::views::{TensorRef, TensorMut, DataLayout, TensorStack};
            use crate::tensors::Dimension;

            // Builds the `$d + 1` dimensional shape of the stacked view.
            // It is the source shape with `(along.1, sources)` inserted at position
            // `along.0`, so the new dimension's length is the number of stacked sources.
            fn view_shape_impl(
                shape: [(Dimension, usize); $d],
                along: (usize, Dimension),
                sources: usize,
            ) -> [(Dimension, usize); $d + 1] {
                // Placeholder entries; the loop below overwrites every one of them.
                let mut extra_shape = [("", 0); $d + 1];
                // Position in `shape` of the next source dimension to copy across.
                let mut i = 0;
                for (d, dimension) in extra_shape.iter_mut().enumerate() {
                    match d == along.0 {
                        true => {
                            *dimension = (along.1, sources);
                        },
                        false => {
                            *dimension = shape[i];
                            i += 1;
                        }
                    }
                }
                extra_shape
            }

            // Splits a `$d + 1` dimensional index into two parts:
            // - the index of the source, which is the component at position `along.0`;
            // - the `$d` remaining components, which index into that source.
            fn indexing(
                indexes: [usize; $d + 1],
                along: (usize, Dimension)
            ) -> (usize, [usize; $d]) {
                let mut indexes_into_source = [0; $d];
                let mut i = 0;
                for (d, &index) in indexes.iter().enumerate() {
                    if d != along.0 {
                        indexes_into_source[i] = index;
                        i += 1;
                    }
                }
                (indexes[along.0], indexes_into_source)
            }

            unsafe impl<T, S, const N: usize> TensorRef<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorRef<T, $d>
            {
                // Returns None when the source index is past the end of the array, or when
                // the chosen source rejects the remaining indexes.
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get(source)?.get_reference(indexes)
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                // `get_unchecked` performs no bounds check. It relies on the caller of this
                // unsafe method passing indexes within `view_shape`, which keeps the source
                // index below `N`.
                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_unchecked(source).get_reference_unchecked(indexes)
                }}

                // Elements come from several separate sources, so no single linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S, const N: usize> TensorMut<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorMut<T, $d>
            {
                // Mutable counterpart of `get_reference`; returns None for out of range indexes.
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_mut(source)?.get_reference_mut(indexes)
                }

                // As with `get_reference_unchecked`, correctness of `get_unchecked_mut` relies
                // on the caller passing in-range indexes.
                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_unchecked_mut(source).get_reference_unchecked_mut(indexes)
                }}
            }

            unsafe impl<T, S1, S2> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
            {
                // Source indexes beyond the last tuple element are out of range.
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                // With in-range indexes only the numbered arms are reached. The fallback
                // panics rather than doing anything undefined if that contract is broken.
                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                // Elements come from several separate sources, so no single linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
            {
                // Source indexes beyond the last tuple element are out of range.
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        _ => None
                    }
                }

                // With in-range indexes only the numbered arms are reached.
                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
            {
                // Source indexes beyond the last tuple element are out of range.
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                // With in-range indexes only the numbered arms are reached.
                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                // Elements come from several separate sources, so no single linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
            {
                // Source indexes beyond the last tuple element are out of range.
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        _ => None
                    }
                }

                // With in-range indexes only the numbered arms are reached.
                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3, S4> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
                S4: TensorRef<T, $d>,
            {
                // Source indexes beyond the last tuple element are out of range.
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        3 => self.sources.3.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                // With in-range indexes only the numbered arms are reached.
                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        3 => self.sources.3.get_reference_unchecked(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                // Elements come from several separate sources, so no single linear layout
                // is reported.
                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3, S4> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
                S4: TensorMut<T, $d>,
            {
                // Source indexes beyond the last tuple element are out of range.
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        3 => self.sources.3.get_reference_mut(indexes),
                        _ => None
                    }
                }

                // With in-range indexes only the numbered arms are reached.
                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        3 => self.sources.3.get_reference_unchecked_mut(indexes),
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }
        }
    }
}
684
// Stacking is implemented for sources of 0 to 5 dimensions, giving views of 1 to 6
// dimensions. Each expansion lives in its own private module named by the second argument.
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 0 zero);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 1 one);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 2 two);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 3 three);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 4 four);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 5 five);
691
/// Checks stacking in several configurations:
/// - three vectors stacked into matrices, with the new dimension last and then first;
/// - two boxed 2D tensors stacked into 3D tensors at each of the three possible insertion
///   positions.
#[test]
fn test_stacking() {
    use crate::tensors::Tensor;
    use crate::tensors::views::{TensorMut, TensorView};
    let vector1 = Tensor::from([("a", 3)], vec![9, 5, 2]);
    let vector2 = Tensor::from([("a", 3)], vec![3, 6, 0]);
    let vector3 = Tensor::from([("a", 3)], vec![8, 7, 1]);
    // New dimension "b" inserted after "a": each vector becomes a column.
    let matrix = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&vector1, &vector2, &vector3),
        (1, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        matrix,
        Tensor::from([("a", 3), ("b", 3)], vec![
            9, 3, 8,
            5, 6, 7,
            2, 0, 1,
        ])
    );
    // New dimension "b" inserted before "a": each vector becomes a row.
    let different_matrix = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&vector1, &vector2, &vector3),
        (0, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        different_matrix,
        Tensor::from([("b", 3), ("a", 3)], vec![
            9, 5, 2,
            3, 6, 0,
            8, 7, 1,
        ])
    );
    // Stack two type-erased matrices along a new last dimension "c". The second matrix
    // is renamed so both sources share the shape [("a", 3), ("b", 3)].
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    let tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (2, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        tensor.eq(
            &Tensor::from([("a", 3), ("b", 3), ("c", 2)], vec![
                9, 9,
                3, 5,
                8, 2,

                5, 3,
                6, 6,
                7, 0,

                2, 8,
                0, 7,
                1, 1
            ])
        ),
    );
    // Same two matrices, with "c" inserted in the middle.
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    let different_tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (1, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        different_tensor.eq(
            &Tensor::from([("a", 3), ("c", 2), ("b", 3)], vec![
                9, 3, 8,
                9, 5, 2,

                5, 6, 7,
                3, 6, 0,

                2, 0, 1,
                8, 7, 1
            ])
        ),
    );
    // Same two matrices erased to read-only `TensorRef`, with "c" inserted first.
    let matrix_erased: Box<dyn TensorRef<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorRef<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    let another_tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (0, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        another_tensor.eq(
            &Tensor::from([("c", 2), ("a", 3), ("b", 3)], vec![
                9, 3, 8,
                5, 6, 7,
                2, 0, 1,

                9, 5, 2,
                3, 6, 0,
                8, 7, 1,
            ])
        ),
    );
}
794
/// A combination of tensors joined end to end along one of their existing dimensions.
///
/// The sources must share the same shape, except for the length of the chained dimension.
/// The resulting `D` dimensional view has a chained dimension whose length is the sum of
/// the sources' lengths along it.
///
/// `sources` holds the chained tensors, either as an array `[S; N]` or as a tuple of 2 to 4
/// tensors. `along` is the position of the chained dimension within the source shape.
/// The element type `T` is only tracked through `PhantomData`.
#[derive(Clone, Debug)]
pub struct TensorChain<T, S, const D: usize> {
    sources: S,
    _type: PhantomData<T>,
    along: usize,
}
852
853fn validate_shapes_similar<const D: usize, I>(mut shapes: I, along: usize)
854where
855 I: Iterator<Item = [(Dimension, usize); D]>,
856{
857 let first_shape = shapes.next().unwrap();
860 for (i, shape) in shapes.enumerate() {
861 for d in 0..D {
862 let similar = if d == along {
863 shape[d].0 == first_shape[d].0
865 } else {
866 shape[d] == first_shape[d]
867 };
868 if !similar {
869 panic!(
870 "The shapes of each tensor in the sources to chain along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
871 i + 1,
872 shape,
873 first_shape
874 );
875 }
876 }
877 }
878}
879
880impl<T, S, const D: usize, const N: usize> TensorChain<T, [S; N], D>
881where
882 S: TensorRef<T, D>,
883{
884 #[track_caller]
898 pub fn from(sources: [S; N], along: Dimension) -> Self {
899 if N == 0 {
900 panic!("No sources provided");
901 }
902 if D == 0 {
903 panic!("Can't chain along 0 dimensional tensors");
904 }
905 let shape = sources[0].view_shape();
906 let along = match dimensions::position_of(&shape, along) {
907 Some(d) => d,
908 None => panic!(
909 "The dimension {:?} is not in the source's shapes: {:?}",
910 along, shape
911 ),
912 };
913 validate_shapes_similar(sources.iter().map(|tensor| tensor.view_shape()), along);
914 Self {
915 sources,
916 along,
917 _type: PhantomData,
918 }
919 }
920
921 pub fn sources(self) -> [S; N] {
925 self.sources
926 }
927
928 pub fn sources_ref(&self) -> &[S; N] {
938 &self.sources
939 }
940}
941
942impl<T, S1, S2, const D: usize> TensorChain<T, (S1, S2), D>
943where
944 S1: TensorRef<T, D>,
945 S2: TensorRef<T, D>,
946{
947 #[track_caller]
960 pub fn from(sources: (S1, S2), along: Dimension) -> Self {
961 if D == 0 {
962 panic!("Can't chain along 0 dimensional tensors");
963 }
964 let shape = sources.0.view_shape();
965 let along = match dimensions::position_of(&shape, along) {
966 Some(d) => d,
967 None => panic!(
968 "The dimension {:?} is not in the source's shapes: {:?}",
969 along, shape
970 ),
971 };
972 validate_shapes_similar(
973 [sources.0.view_shape(), sources.1.view_shape()].into_iter(),
974 along,
975 );
976 Self {
977 sources,
978 along,
979 _type: PhantomData,
980 }
981 }
982
983 pub fn sources(self) -> (S1, S2) {
987 self.sources
988 }
989
990 pub fn sources_ref(&self) -> &(S1, S2) {
1000 &self.sources
1001 }
1002}
1003
1004impl<T, S1, S2, S3, const D: usize> TensorChain<T, (S1, S2, S3), D>
1005where
1006 S1: TensorRef<T, D>,
1007 S2: TensorRef<T, D>,
1008 S3: TensorRef<T, D>,
1009{
1010 #[track_caller]
1023 pub fn from(sources: (S1, S2, S3), along: Dimension) -> Self {
1024 if D == 0 {
1025 panic!("Can't chain along 0 dimensional tensors");
1026 }
1027 let shape = sources.0.view_shape();
1028 let along = match dimensions::position_of(&shape, along) {
1029 Some(d) => d,
1030 None => panic!(
1031 "The dimension {:?} is not in the source's shapes: {:?}",
1032 along, shape
1033 ),
1034 };
1035 validate_shapes_similar(
1036 [
1037 sources.0.view_shape(),
1038 sources.1.view_shape(),
1039 sources.2.view_shape(),
1040 ]
1041 .into_iter(),
1042 along,
1043 );
1044 Self {
1045 sources,
1046 along,
1047 _type: PhantomData,
1048 }
1049 }
1050
1051 pub fn sources(self) -> (S1, S2, S3) {
1055 self.sources
1056 }
1057
1058 pub fn sources_ref(&self) -> &(S1, S2, S3) {
1068 &self.sources
1069 }
1070}
1071
1072impl<T, S1, S2, S3, S4, const D: usize> TensorChain<T, (S1, S2, S3, S4), D>
1073where
1074 S1: TensorRef<T, D>,
1075 S2: TensorRef<T, D>,
1076 S3: TensorRef<T, D>,
1077 S4: TensorRef<T, D>,
1078{
1079 #[track_caller]
1092 pub fn from(sources: (S1, S2, S3, S4), along: Dimension) -> Self {
1093 if D == 0 {
1094 panic!("Can't chain along 0 dimensional tensors");
1095 }
1096 let shape = sources.0.view_shape();
1097 let along = match dimensions::position_of(&shape, along) {
1098 Some(d) => d,
1099 None => panic!(
1100 "The dimension {:?} is not in the source's shapes: {:?}",
1101 along, shape
1102 ),
1103 };
1104 validate_shapes_similar(
1105 [
1106 sources.0.view_shape(),
1107 sources.1.view_shape(),
1108 sources.2.view_shape(),
1109 sources.3.view_shape(),
1110 ]
1111 .into_iter(),
1112 along,
1113 );
1114 Self {
1115 sources,
1116 along,
1117 _type: PhantomData,
1118 }
1119 }
1120
1121 pub fn sources(self) -> (S1, S2, S3, S4) {
1125 self.sources
1126 }
1127
1128 pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
1138 &self.sources
1139 }
1140}
1141
1142fn view_shape_impl<I, const D: usize>(
1143 first_shape: [(Dimension, usize); D],
1144 shapes: I,
1145 along: usize,
1146) -> [(Dimension, usize); D]
1147where
1148 I: Iterator<Item = [(Dimension, usize); D]>,
1149{
1150 let mut shape = first_shape;
1151 shape[along].1 = shapes.into_iter().map(|shape| shape[along].1).sum();
1152 shape
1153}
1154
1155fn indexing<I, const D: usize>(
1156 indexes: [usize; D],
1157 shapes: I,
1158 along: usize,
1159) -> Option<(usize, [usize; D])>
1160where
1161 I: Iterator<Item = [(Dimension, usize); D]>,
1162{
1163 let mut shapes = shapes.enumerate();
1164 let mut i = indexes[along];
1168 loop {
1169 let (source, next_shape) = shapes.next()?;
1170 let length_along_chained_dimension = next_shape[along].1;
1171 if i < length_along_chained_dimension {
1172 #[allow(clippy::clone_on_copy)]
1173 let mut indexes = indexes.clone();
1174 indexes[along] = i;
1175 return Some((source, indexes));
1176 }
1177 i -= length_along_chained_dimension;
1178 }
1179}
1180
// Read access to a chained array of sources.
// Each index is routed to a source by `indexing`, which walks the sources' shapes along the
// chained dimension. The constructor checked that all sources agree on every other dimension.
unsafe impl<T, S, const D: usize, const N: usize> TensorRef<T, D> for TensorChain<T, [S; N], D>
where
    S: TensorRef<T, D>,
{
    // Returns None if the index along the chained dimension is past the end of every source,
    // or if the chosen source rejects the remaining indexes.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )?;
        self.sources.get(source)?.get_reference(indexes)
    }

    // The first source's shape, with the chained dimension's length summed across all sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources[0].view_shape(),
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )
    }

    // `unwrap_unchecked` skips the `None` case. It relies on the caller of this unsafe method
    // passing indexes within `view_shape`. `indexing` only yields source positions below `N`,
    // so the checked `get(...).unwrap()` cannot fail.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                self.sources.iter().map(|s| s.view_shape()),
                self.along,
            )
            .unwrap_unchecked();
            self.sources
                .get(source)
                .unwrap()
                .get_reference_unchecked(indexes)
        }
    }

    // Elements come from several separate sources, so no single linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1225
// Mutable access to a chained array of sources, routed through `indexing` in the same way
// as the `TensorRef` impl.
unsafe impl<T, S, const D: usize, const N: usize> TensorMut<T, D> for TensorChain<T, [S; N], D>
where
    S: TensorMut<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )?;
        self.sources.get_mut(source)?.get_reference_mut(indexes)
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The source position
    // from `indexing` is always below `N`, so `get_mut(...).unwrap()` cannot fail.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                self.sources.iter().map(|s| s.view_shape()),
                self.along,
            )
            .unwrap_unchecked();
            self.sources
                .get_mut(source)
                .unwrap()
                .get_reference_unchecked_mut(indexes)
        }
    }
}
1256
// Read access to a chained pair of sources. `indexing` is given exactly two shapes, so it
// can only yield source position 0 or 1.
unsafe impl<T, S1, S2, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            _ => None,
        }
    }

    // The first source's shape, with the chained dimension's length summed across both sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The `_` arm is
    // unreachable because `indexing` only sees two shapes.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
                self.along,
            )
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    // Elements come from several separate sources, so no single linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1308
// Mutable access to a chained pair of sources, routed through `indexing` in the same way as
// the `TensorRef` impl.
unsafe impl<T, S1, S2, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            _ => None,
        }
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The `_` arm is
    // unreachable because `indexing` only sees two shapes.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
                self.along,
            )
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1346
// Read access to three chained sources. `indexing` is given exactly three shapes, so it can
// only yield source positions 0 to 2.
unsafe impl<T, S1, S2, S3, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2, S3), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
    S3: TensorRef<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            2 => self.sources.2.get_reference(indexes),
            _ => None,
        }
    }

    // The first source's shape, with the chained dimension's length summed across all sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The `_` arm is
    // unreachable because `indexing` only sees three shapes.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                2 => self.sources.2.get_reference_unchecked(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    // Elements come from several separate sources, so no single linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1416
// Mutable access to three chained sources, routed through `indexing` in the same way as the
// `TensorRef` impl.
unsafe impl<T, S1, S2, S3, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2, S3), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
    S3: TensorMut<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            2 => self.sources.2.get_reference_mut(indexes),
            _ => None,
        }
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The `_` arm is
    // unreachable because `indexing` only sees three shapes.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                2 => self.sources.2.get_reference_unchecked_mut(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1467
// Read access to four chained sources. `indexing` is given exactly four shapes, so it can
// only yield source positions 0 to 3.
unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorRef<T, D>
    for TensorChain<T, (S1, S2, S3, S4), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
    S3: TensorRef<T, D>,
    S4: TensorRef<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            2 => self.sources.2.get_reference(indexes),
            3 => self.sources.3.get_reference(indexes),
            _ => None,
        }
    }

    // The first source's shape, with the chained dimension's length summed across all sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The `_` arm is
    // unreachable because `indexing` only sees four shapes.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                    self.sources.3.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                2 => self.sources.2.get_reference_unchecked(indexes),
                3 => self.sources.3.get_reference_unchecked(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    // Elements come from several separate sources, so no single linear layout is reported.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
1544
// Mutable access to four chained sources, routed through `indexing` in the same way as the
// `TensorRef` impl.
unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorMut<T, D>
    for TensorChain<T, (S1, S2, S3, S4), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
    S3: TensorMut<T, D>,
    S4: TensorMut<T, D>,
{
    // Returns None for indexes outside the chained view.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            2 => self.sources.2.get_reference_mut(indexes),
            3 => self.sources.3.get_reference_mut(indexes),
            _ => None,
        }
    }

    // `unwrap_unchecked` relies on the caller passing in-range indexes. The `_` arm is
    // unreachable because `indexing` only sees four shapes.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                    self.sources.3.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                2 => self.sources.2.get_reference_unchecked_mut(indexes),
                3 => self.sources.3.get_reference_unchecked_mut(indexes),
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1601
/// Checks chaining in two configurations:
/// - two borrowed matrices of different lengths joined along "a";
/// - the result joined, through boxed `TensorMut` trait objects, with a single column matrix
///   along "b".
#[test]
fn test_chaining() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    #[rustfmt::skip]
    let matrix1 = Tensor::from(
        [("a", 3), ("b", 2)],
        vec![
            9, 5,
            2, 1,
            3, 5
        ]
    );
    #[rustfmt::skip]
    let matrix2 = Tensor::from(
        [("a", 4), ("b", 2)],
        vec![
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ]
    );
    // Rows of matrix2 follow the rows of matrix1: "a" has length 3 + 4.
    let matrix = TensorView::from(TensorChain::<_, (_, _), 2>::from((&matrix1, &matrix2), "a"));
    #[rustfmt::skip]
    assert_eq!(
        matrix,
        Tensor::from([("a", 7), ("b", 2)], vec![
            9, 5,
            2, 1,
            3, 5,
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ])
    );
    // Append the column 0..7 after the two existing columns: "b" has length 2 + 1.
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix = Tensor::from([("a", 7), ("b", 1)], (0..7).collect());
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(different_matrix);
    let another_matrix = TensorView::from(TensorChain::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        "b",
    ));
    #[rustfmt::skip]
    assert!(
        another_matrix.eq(
            &Tensor::from([("a", 7), ("b", 3)], vec![
                9, 5, 0,
                2, 1, 1,
                3, 5, 2,
                0, 1, 3,
                8, 4, 4,
                1, 7, 5,
                6, 3, 6
            ])
        )
    );
}