1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
/// A view over several `D` dimensional sources of identical shape, stacked along a new extra
/// dimension to present a `D + 1` dimensional tensor.
///
/// `S` holds the sources. Constructors in this file exist for an array `[S; N]` and for tuples
/// of 2, 3 or 4 sources. The stacked dimension's length is the number of sources.
#[derive(Clone, Debug)]
pub struct TensorStack<T, S, const D: usize> {
    // The tensors being stacked; all validated to share one shape at construction.
    sources: S,
    _type: PhantomData<T>,
    // (position in the stacked shape where the new dimension is inserted, its name).
    along: (usize, Dimension),
}
61
62fn validate_shapes_equal<const D: usize, I>(mut shapes: I)
63where
64 I: Iterator<Item = [(Dimension, usize); D]>,
65{
66 let first_shape = shapes.next().unwrap();
69 for (i, shape) in shapes.enumerate() {
70 if shape != first_shape {
71 panic!(
72 "The shapes of each tensor in the sources to stack along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
73 i + 1,
74 shape,
75 first_shape
76 );
77 }
78 }
79}
80
81impl<T, S, const D: usize, const N: usize> TensorStack<T, [S; N], D>
82where
83 S: TensorRef<T, D>,
84{
85 #[track_caller]
100 pub fn from(sources: [S; N], along: (usize, Dimension)) -> Self {
101 if N == 0 {
102 panic!("No sources provided");
103 }
104 if along.0 > D {
105 panic!(
106 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
107 along
108 );
109 }
110 let shape = sources[0].view_shape();
111 if dimensions::contains(&shape, along.1) {
112 panic!(
113 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
114 along, shape
115 );
116 }
117 validate_shapes_equal(sources.iter().map(|tensor| tensor.view_shape()));
118 Self {
119 sources,
120 along,
121 _type: PhantomData,
122 }
123 }
124
125 pub fn sources(self) -> [S; N] {
129 self.sources
130 }
131
132 pub fn sources_ref(&self) -> &[S; N] {
142 &self.sources
143 }
144
145 fn source_view_shape(&self) -> [(Dimension, usize); D] {
149 self.sources[0].view_shape()
150 }
151
152 fn number_of_sources() -> usize {
153 N
154 }
155}
156
157impl<T, S1, S2, const D: usize> TensorStack<T, (S1, S2), D>
158where
159 S1: TensorRef<T, D>,
160 S2: TensorRef<T, D>,
161{
162 #[track_caller]
173 pub fn from(sources: (S1, S2), along: (usize, Dimension)) -> Self {
174 if along.0 > D {
175 panic!(
176 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
177 along
178 );
179 }
180 let shape = sources.0.view_shape();
181 if dimensions::contains(&shape, along.1) {
182 panic!(
183 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
184 along, shape
185 );
186 }
187 validate_shapes_equal([sources.0.view_shape(), sources.1.view_shape()].into_iter());
188 Self {
189 sources,
190 along,
191 _type: PhantomData,
192 }
193 }
194
195 pub fn sources(self) -> (S1, S2) {
199 self.sources
200 }
201
202 pub fn sources_ref(&self) -> &(S1, S2) {
212 &self.sources
213 }
214
215 fn source_view_shape(&self) -> [(Dimension, usize); D] {
219 self.sources.0.view_shape()
220 }
221
222 fn number_of_sources() -> usize {
223 2
224 }
225}
226
227impl<T, S1, S2, S3, const D: usize> TensorStack<T, (S1, S2, S3), D>
228where
229 S1: TensorRef<T, D>,
230 S2: TensorRef<T, D>,
231 S3: TensorRef<T, D>,
232{
233 #[track_caller]
244 pub fn from(sources: (S1, S2, S3), along: (usize, Dimension)) -> Self {
245 if along.0 > D {
246 panic!(
247 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
248 along
249 );
250 }
251 let shape = sources.0.view_shape();
252 if dimensions::contains(&shape, along.1) {
253 panic!(
254 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
255 along, shape
256 );
257 }
258 validate_shapes_equal(
259 [
260 sources.0.view_shape(),
261 sources.1.view_shape(),
262 sources.2.view_shape(),
263 ]
264 .into_iter(),
265 );
266 Self {
267 sources,
268 along,
269 _type: PhantomData,
270 }
271 }
272
273 pub fn sources(self) -> (S1, S2, S3) {
277 self.sources
278 }
279
280 pub fn sources_ref(&self) -> &(S1, S2, S3) {
290 &self.sources
291 }
292
293 fn source_view_shape(&self) -> [(Dimension, usize); D] {
297 self.sources.0.view_shape()
298 }
299
300 fn number_of_sources() -> usize {
301 3
302 }
303}
304
305impl<T, S1, S2, S3, S4, const D: usize> TensorStack<T, (S1, S2, S3, S4), D>
306where
307 S1: TensorRef<T, D>,
308 S2: TensorRef<T, D>,
309 S3: TensorRef<T, D>,
310 S4: TensorRef<T, D>,
311{
312 #[track_caller]
323 pub fn from(sources: (S1, S2, S3, S4), along: (usize, Dimension)) -> Self {
324 if along.0 > D {
325 panic!(
326 "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
327 along
328 );
329 }
330 let shape = sources.0.view_shape();
331 if dimensions::contains(&shape, along.1) {
332 panic!(
333 "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
334 along, shape
335 );
336 }
337 validate_shapes_equal(
338 [
339 sources.0.view_shape(),
340 sources.1.view_shape(),
341 sources.2.view_shape(),
342 sources.3.view_shape(),
343 ]
344 .into_iter(),
345 );
346 Self {
347 sources,
348 along,
349 _type: PhantomData,
350 }
351 }
352
353 pub fn sources(self) -> (S1, S2, S3, S4) {
357 self.sources
358 }
359
360 pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
370 &self.sources
371 }
372
373 fn source_view_shape(&self) -> [(Dimension, usize); D] {
377 self.sources.0.view_shape()
378 }
379
380 fn number_of_sources() -> usize {
381 4
382 }
383}
384
/// Implements [`TensorRef`] and [`TensorMut`] for [`TensorStack`]s of `$d` dimensional sources,
/// each viewed as a `$d + 1` dimensional tensor.
///
/// The generated impls cover an array `[S; N]` of sources and tuples of 2, 3 and 4 sources. They
/// are placed in a module named `$mod` so that the helper functions generated by each invocation
/// do not collide. `$d` is a literal, so the stacked dimensionality is written in the generated
/// code as the constant expression `$d + 1`.
macro_rules! tensor_stack_ref_impl {
    (unsafe impl TensorRef for TensorStack $d:literal $mod:ident) => {
        mod $mod {
            use crate::tensors::views::{TensorRef, TensorMut, DataLayout, TensorStack};
            use crate::tensors::Dimension;

            // Shape of the stacked view: the common source `shape` with the new dimension
            // `(along.1, sources)` inserted at position `along.0`.
            fn view_shape_impl(
                shape: [(Dimension, usize); $d],
                along: (usize, Dimension),
                sources: usize,
            ) -> [(Dimension, usize); $d + 1] {
                // Placeholder entries; the loop below overwrites every one of them.
                let mut extra_shape = [("", 0); $d + 1];
                // Position of the next source dimension to copy across.
                let mut i = 0;
                for (d, dimension) in extra_shape.iter_mut().enumerate() {
                    match d == along.0 {
                        true => {
                            *dimension = (along.1, sources);
                        },
                        false => {
                            *dimension = shape[i];
                            i += 1;
                        }
                    }
                }
                extra_shape
            }

            // Splits an index into the stacked view into the position along the stacked
            // dimension (which selects the source) and the remaining `$d` indexes, in their
            // original order, into that source.
            fn indexing(
                indexes: [usize; $d + 1],
                along: (usize, Dimension)
            ) -> (usize, [usize; $d]) {
                let mut indexes_into_source = [0; $d];
                let mut i = 0;
                for (d, &index) in indexes.iter().enumerate() {
                    if d != along.0 {
                        indexes_into_source[i] = index;
                        i += 1;
                    }
                }
                (indexes[along.0], indexes_into_source)
            }

            unsafe impl<T, S, const N: usize> TensorRef<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorRef<T, $d>
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // `get` yields None when the stacked index is N or more.
                    self.sources.get(source)?.get_reference(indexes)
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    // Relies on the caller passing indexes within `view_shape`, whose length
                    // along the stacked dimension is N, so `source` is a valid array position.
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_unchecked(source).get_reference_unchecked(indexes)
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The stacked view is always reported as non linear.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S, const N: usize> TensorMut<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorMut<T, $d>
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // `get_mut` yields None when the stacked index is N or more.
                    self.sources.get_mut(source)?.get_reference_mut(indexes)
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    // Same precondition as `get_reference_unchecked`: `source` must be below N.
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_unchecked_mut(source).get_reference_unchecked_mut(indexes)
                }}
            }

            unsafe impl<T, S1, S2> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // The stacked index picks the tuple field; out of range gives None.
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        // Only reachable if the caller broke the unchecked precondition.
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The stacked view is always reported as non linear.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        // Only reachable if the caller broke the unchecked precondition.
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        // Only reachable if the caller broke the unchecked precondition.
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The stacked view is always reported as non linear.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        // Only reachable if the caller broke the unchecked precondition.
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3, S4> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
                S4: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        3 => self.sources.3.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        3 => self.sources.3.get_reference_unchecked(indexes),
                        // Only reachable if the caller broke the unchecked precondition.
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The stacked view is always reported as non linear.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3, S4> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
                S4: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        3 => self.sources.3.get_reference_mut(indexes),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        3 => self.sources.3.get_reference_unchecked_mut(indexes),
                        // Only reachable if the caller broke the unchecked precondition.
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }
        }
    }
}
681
// Generate the `TensorRef`/`TensorMut` impls for stacks of 0 through 5 dimensional sources
// (stacked views of 1 through 6 dimensions), each set in its own named module.
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 0 zero);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 1 one);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 2 two);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 3 three);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 4 four);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 5 five);
688
#[test]
fn test_stacking() {
    use crate::tensors::Tensor;
    use crate::tensors::views::{TensorMut, TensorView};

    let first = Tensor::from([("a", 3)], vec![9, 5, 2]);
    let second = Tensor::from([("a", 3)], vec![3, 6, 0]);
    let third = Tensor::from([("a", 3)], vec![8, 7, 1]);

    // Inserting "b" after "a" turns each vector into a column.
    let by_column = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&first, &second, &third),
        (1, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        by_column,
        Tensor::from([("a", 3), ("b", 3)], vec![
            9, 3, 8,
            5, 6, 7,
            2, 0, 1,
        ])
    );

    // Inserting "b" before "a" turns each vector into a row.
    let by_row = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&first, &second, &third),
        (0, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        by_row,
        Tensor::from([("b", 3), ("a", 3)], vec![
            9, 5, 2,
            3, 6, 0,
            8, 7, 1,
        ])
    );

    // Stacking two type erased matrices along a new last dimension "c".
    let column_copy: Box<dyn TensorMut<_, 2>> = Box::new(by_column.map(|x| x));
    let row_copy: Box<dyn TensorMut<_, 2>> =
        Box::new(by_row.rename_view(["a", "b"]).map(|x| x));
    let depth_last = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [column_copy, row_copy],
        (2, "c"),
    ));
    #[rustfmt::skip]
    assert!(depth_last.eq(&Tensor::from([("a", 3), ("b", 3), ("c", 2)], vec![
        9, 9,
        3, 5,
        8, 2,

        5, 3,
        6, 6,
        7, 0,

        2, 8,
        0, 7,
        1, 1
    ])));

    // The same pair stacked along a new middle dimension "c".
    let column_copy: Box<dyn TensorMut<_, 2>> = Box::new(by_column.map(|x| x));
    let row_copy: Box<dyn TensorMut<_, 2>> =
        Box::new(by_row.rename_view(["a", "b"]).map(|x| x));
    let depth_middle = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [column_copy, row_copy],
        (1, "c"),
    ));
    #[rustfmt::skip]
    assert!(depth_middle.eq(&Tensor::from([("a", 3), ("c", 2), ("b", 3)], vec![
        9, 3, 8,
        9, 5, 2,

        5, 6, 7,
        3, 6, 0,

        2, 0, 1,
        8, 7, 1
    ])));

    // Read-only erasure, stacked along a new first dimension "c".
    let column_copy: Box<dyn TensorRef<_, 2>> = Box::new(by_column.map(|x| x));
    let row_copy: Box<dyn TensorRef<_, 2>> =
        Box::new(by_row.rename_view(["a", "b"]).map(|x| x));
    let depth_first = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [column_copy, row_copy],
        (0, "c"),
    ));
    #[rustfmt::skip]
    assert!(depth_first.eq(&Tensor::from([("c", 2), ("a", 3), ("b", 3)], vec![
        9, 3, 8,
        5, 6, 7,
        2, 0, 1,

        9, 5, 2,
        3, 6, 0,
        8, 7, 1,
    ])));
}
791
/// A view that joins several `D` dimensional sources end to end along one of their existing
/// dimensions, presenting a `D` dimensional tensor whose length along that dimension is the sum
/// of the sources' lengths.
///
/// `S` holds the sources. Constructors in this file exist for an array `[S; N]` and for tuples
/// of 2, 3 or 4 sources.
#[derive(Clone, Debug)]
pub struct TensorChain<T, S, const D: usize> {
    // The tensors being chained; their shapes may differ only in length along `along`.
    sources: S,
    _type: PhantomData<T>,
    // Position (not name) of the chained dimension, resolved from its name in `from`.
    along: usize,
}
846
847fn validate_shapes_similar<const D: usize, I>(mut shapes: I, along: usize)
848where
849 I: Iterator<Item = [(Dimension, usize); D]>,
850{
851 let first_shape = shapes.next().unwrap();
854 for (i, shape) in shapes.enumerate() {
855 for d in 0..D {
856 let similar = if d == along {
857 shape[d].0 == first_shape[d].0
859 } else {
860 shape[d] == first_shape[d]
861 };
862 if !similar {
863 panic!(
864 "The shapes of each tensor in the sources to chain along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
865 i + 1,
866 shape,
867 first_shape
868 );
869 }
870 }
871 }
872}
873
874impl<T, S, const D: usize, const N: usize> TensorChain<T, [S; N], D>
875where
876 S: TensorRef<T, D>,
877{
878 #[track_caller]
892 pub fn from(sources: [S; N], along: Dimension) -> Self {
893 if N == 0 {
894 panic!("No sources provided");
895 }
896 if D == 0 {
897 panic!("Can't chain along 0 dimensional tensors");
898 }
899 let shape = sources[0].view_shape();
900 let along = match dimensions::position_of(&shape, along) {
901 Some(d) => d,
902 None => panic!(
903 "The dimension {:?} is not in the source's shapes: {:?}",
904 along, shape
905 ),
906 };
907 validate_shapes_similar(sources.iter().map(|tensor| tensor.view_shape()), along);
908 Self {
909 sources,
910 along,
911 _type: PhantomData,
912 }
913 }
914
915 pub fn sources(self) -> [S; N] {
919 self.sources
920 }
921
922 pub fn sources_ref(&self) -> &[S; N] {
932 &self.sources
933 }
934}
935
936impl<T, S1, S2, const D: usize> TensorChain<T, (S1, S2), D>
937where
938 S1: TensorRef<T, D>,
939 S2: TensorRef<T, D>,
940{
941 #[track_caller]
954 pub fn from(sources: (S1, S2), along: Dimension) -> Self {
955 if D == 0 {
956 panic!("Can't chain along 0 dimensional tensors");
957 }
958 let shape = sources.0.view_shape();
959 let along = match dimensions::position_of(&shape, along) {
960 Some(d) => d,
961 None => panic!(
962 "The dimension {:?} is not in the source's shapes: {:?}",
963 along, shape
964 ),
965 };
966 validate_shapes_similar(
967 [sources.0.view_shape(), sources.1.view_shape()].into_iter(),
968 along,
969 );
970 Self {
971 sources,
972 along,
973 _type: PhantomData,
974 }
975 }
976
977 pub fn sources(self) -> (S1, S2) {
981 self.sources
982 }
983
984 pub fn sources_ref(&self) -> &(S1, S2) {
994 &self.sources
995 }
996}
997
998impl<T, S1, S2, S3, const D: usize> TensorChain<T, (S1, S2, S3), D>
999where
1000 S1: TensorRef<T, D>,
1001 S2: TensorRef<T, D>,
1002 S3: TensorRef<T, D>,
1003{
1004 #[track_caller]
1017 pub fn from(sources: (S1, S2, S3), along: Dimension) -> Self {
1018 if D == 0 {
1019 panic!("Can't chain along 0 dimensional tensors");
1020 }
1021 let shape = sources.0.view_shape();
1022 let along = match dimensions::position_of(&shape, along) {
1023 Some(d) => d,
1024 None => panic!(
1025 "The dimension {:?} is not in the source's shapes: {:?}",
1026 along, shape
1027 ),
1028 };
1029 validate_shapes_similar(
1030 [
1031 sources.0.view_shape(),
1032 sources.1.view_shape(),
1033 sources.2.view_shape(),
1034 ]
1035 .into_iter(),
1036 along,
1037 );
1038 Self {
1039 sources,
1040 along,
1041 _type: PhantomData,
1042 }
1043 }
1044
1045 pub fn sources(self) -> (S1, S2, S3) {
1049 self.sources
1050 }
1051
1052 pub fn sources_ref(&self) -> &(S1, S2, S3) {
1062 &self.sources
1063 }
1064}
1065
1066impl<T, S1, S2, S3, S4, const D: usize> TensorChain<T, (S1, S2, S3, S4), D>
1067where
1068 S1: TensorRef<T, D>,
1069 S2: TensorRef<T, D>,
1070 S3: TensorRef<T, D>,
1071 S4: TensorRef<T, D>,
1072{
1073 #[track_caller]
1086 pub fn from(sources: (S1, S2, S3, S4), along: Dimension) -> Self {
1087 if D == 0 {
1088 panic!("Can't chain along 0 dimensional tensors");
1089 }
1090 let shape = sources.0.view_shape();
1091 let along = match dimensions::position_of(&shape, along) {
1092 Some(d) => d,
1093 None => panic!(
1094 "The dimension {:?} is not in the source's shapes: {:?}",
1095 along, shape
1096 ),
1097 };
1098 validate_shapes_similar(
1099 [
1100 sources.0.view_shape(),
1101 sources.1.view_shape(),
1102 sources.2.view_shape(),
1103 sources.3.view_shape(),
1104 ]
1105 .into_iter(),
1106 along,
1107 );
1108 Self {
1109 sources,
1110 along,
1111 _type: PhantomData,
1112 }
1113 }
1114
1115 pub fn sources(self) -> (S1, S2, S3, S4) {
1119 self.sources
1120 }
1121
1122 pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
1132 &self.sources
1133 }
1134}
1135
1136fn view_shape_impl<I, const D: usize>(
1137 first_shape: [(Dimension, usize); D],
1138 shapes: I,
1139 along: usize,
1140) -> [(Dimension, usize); D]
1141where
1142 I: Iterator<Item = [(Dimension, usize); D]>,
1143{
1144 let mut shape = first_shape;
1145 shape[along].1 = shapes.into_iter().map(|shape| shape[along].1).sum();
1146 shape
1147}
1148
1149fn indexing<I, const D: usize>(
1150 indexes: [usize; D],
1151 shapes: I,
1152 along: usize,
1153) -> Option<(usize, [usize; D])>
1154where
1155 I: Iterator<Item = [(Dimension, usize); D]>,
1156{
1157 let mut shapes = shapes.enumerate();
1158 let mut i = indexes[along];
1162 loop {
1163 let (source, next_shape) = shapes.next()?;
1164 let length_along_chained_dimension = next_shape[along].1;
1165 if i < length_along_chained_dimension {
1166 #[allow(clippy::clone_on_copy)]
1167 let mut indexes = indexes.clone();
1168 indexes[along] = i;
1169 return Some((source, indexes));
1170 }
1171 i -= length_along_chained_dimension;
1172 }
1173}
1174
1175unsafe impl<T, S, const D: usize, const N: usize> TensorRef<T, D> for TensorChain<T, [S; N], D>
1176where
1177 S: TensorRef<T, D>,
1178{
1179 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1180 let (source, indexes) = indexing(
1181 indexes,
1182 self.sources.iter().map(|s| s.view_shape()),
1183 self.along,
1184 )?;
1185 self.sources.get(source)?.get_reference(indexes)
1186 }
1187
1188 fn view_shape(&self) -> [(Dimension, usize); D] {
1189 view_shape_impl(
1190 self.sources[0].view_shape(),
1191 self.sources.iter().map(|s| s.view_shape()),
1192 self.along,
1193 )
1194 }
1195
1196 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1197 unsafe {
1198 let (source, indexes) = indexing(
1199 indexes,
1200 self.sources.iter().map(|s| s.view_shape()),
1201 self.along,
1202 )
1203 .unwrap_unchecked();
1206 self.sources
1207 .get(source)
1208 .unwrap()
1209 .get_reference_unchecked(indexes)
1210 }
1211 }
1212
1213 fn data_layout(&self) -> DataLayout<D> {
1214 DataLayout::NonLinear
1217 }
1218}
1219
1220unsafe impl<T, S, const D: usize, const N: usize> TensorMut<T, D> for TensorChain<T, [S; N], D>
1221where
1222 S: TensorMut<T, D>,
1223{
1224 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1225 let (source, indexes) = indexing(
1226 indexes,
1227 self.sources.iter().map(|s| s.view_shape()),
1228 self.along,
1229 )?;
1230 self.sources.get_mut(source)?.get_reference_mut(indexes)
1231 }
1232
1233 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1234 unsafe {
1235 let (source, indexes) = indexing(
1236 indexes,
1237 self.sources.iter().map(|s| s.view_shape()),
1238 self.along,
1239 )
1240 .unwrap_unchecked();
1243 self.sources
1244 .get_mut(source)
1245 .unwrap()
1246 .get_reference_unchecked_mut(indexes)
1247 }
1248 }
1249}
1250
1251unsafe impl<T, S1, S2, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2), D>
1252where
1253 S1: TensorRef<T, D>,
1254 S2: TensorRef<T, D>,
1255{
1256 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1257 let (source, indexes) = indexing(
1258 indexes,
1259 [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1260 self.along,
1261 )?;
1262 match source {
1263 0 => self.sources.0.get_reference(indexes),
1264 1 => self.sources.1.get_reference(indexes),
1265 _ => None,
1266 }
1267 }
1268
1269 fn view_shape(&self) -> [(Dimension, usize); D] {
1270 view_shape_impl(
1271 self.sources.0.view_shape(),
1272 [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1273 self.along,
1274 )
1275 }
1276
1277 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1278 unsafe {
1279 let (source, indexes) = indexing(
1280 indexes,
1281 [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1282 self.along,
1283 )
1284 .unwrap_unchecked();
1287 match source {
1288 0 => self.sources.0.get_reference_unchecked(indexes),
1289 1 => self.sources.1.get_reference_unchecked(indexes),
1290 _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1292 }
1293 }
1294 }
1295
1296 fn data_layout(&self) -> DataLayout<D> {
1297 DataLayout::NonLinear
1300 }
1301}
1302
1303unsafe impl<T, S1, S2, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2), D>
1304where
1305 S1: TensorMut<T, D>,
1306 S2: TensorMut<T, D>,
1307{
1308 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1309 let (source, indexes) = indexing(
1310 indexes,
1311 [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1312 self.along,
1313 )?;
1314 match source {
1315 0 => self.sources.0.get_reference_mut(indexes),
1316 1 => self.sources.1.get_reference_mut(indexes),
1317 _ => None,
1318 }
1319 }
1320
1321 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1322 unsafe {
1323 let (source, indexes) = indexing(
1324 indexes,
1325 [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1326 self.along,
1327 )
1328 .unwrap_unchecked();
1331 match source {
1332 0 => self.sources.0.get_reference_unchecked_mut(indexes),
1333 1 => self.sources.1.get_reference_unchecked_mut(indexes),
1334 _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1336 }
1337 }
1338 }
1339}
1340
1341unsafe impl<T, S1, S2, S3, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2, S3), D>
1342where
1343 S1: TensorRef<T, D>,
1344 S2: TensorRef<T, D>,
1345 S3: TensorRef<T, D>,
1346{
1347 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1348 let (source, indexes) = indexing(
1349 indexes,
1350 [
1351 self.sources.0.view_shape(),
1352 self.sources.1.view_shape(),
1353 self.sources.2.view_shape(),
1354 ]
1355 .into_iter(),
1356 self.along,
1357 )?;
1358 match source {
1359 0 => self.sources.0.get_reference(indexes),
1360 1 => self.sources.1.get_reference(indexes),
1361 2 => self.sources.2.get_reference(indexes),
1362 _ => None,
1363 }
1364 }
1365
1366 fn view_shape(&self) -> [(Dimension, usize); D] {
1367 view_shape_impl(
1368 self.sources.0.view_shape(),
1369 [
1370 self.sources.0.view_shape(),
1371 self.sources.1.view_shape(),
1372 self.sources.2.view_shape(),
1373 ]
1374 .into_iter(),
1375 self.along,
1376 )
1377 }
1378
1379 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1380 unsafe {
1381 let (source, indexes) = indexing(
1382 indexes,
1383 [
1384 self.sources.0.view_shape(),
1385 self.sources.1.view_shape(),
1386 self.sources.2.view_shape(),
1387 ]
1388 .into_iter(),
1389 self.along,
1390 )
1391 .unwrap_unchecked();
1394 match source {
1395 0 => self.sources.0.get_reference_unchecked(indexes),
1396 1 => self.sources.1.get_reference_unchecked(indexes),
1397 2 => self.sources.2.get_reference_unchecked(indexes),
1398 _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1400 }
1401 }
1402 }
1403
1404 fn data_layout(&self) -> DataLayout<D> {
1405 DataLayout::NonLinear
1408 }
1409}
1410
1411unsafe impl<T, S1, S2, S3, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2, S3), D>
1412where
1413 S1: TensorMut<T, D>,
1414 S2: TensorMut<T, D>,
1415 S3: TensorMut<T, D>,
1416{
1417 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1418 let (source, indexes) = indexing(
1419 indexes,
1420 [
1421 self.sources.0.view_shape(),
1422 self.sources.1.view_shape(),
1423 self.sources.2.view_shape(),
1424 ]
1425 .into_iter(),
1426 self.along,
1427 )?;
1428 match source {
1429 0 => self.sources.0.get_reference_mut(indexes),
1430 1 => self.sources.1.get_reference_mut(indexes),
1431 2 => self.sources.2.get_reference_mut(indexes),
1432 _ => None,
1433 }
1434 }
1435
1436 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1437 unsafe {
1438 let (source, indexes) = indexing(
1439 indexes,
1440 [
1441 self.sources.0.view_shape(),
1442 self.sources.1.view_shape(),
1443 self.sources.2.view_shape(),
1444 ]
1445 .into_iter(),
1446 self.along,
1447 )
1448 .unwrap_unchecked();
1451 match source {
1452 0 => self.sources.0.get_reference_unchecked_mut(indexes),
1453 1 => self.sources.1.get_reference_unchecked_mut(indexes),
1454 2 => self.sources.2.get_reference_unchecked_mut(indexes),
1455 _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1457 }
1458 }
1459 }
1460}
1461
1462unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorRef<T, D>
1463 for TensorChain<T, (S1, S2, S3, S4), D>
1464where
1465 S1: TensorRef<T, D>,
1466 S2: TensorRef<T, D>,
1467 S3: TensorRef<T, D>,
1468 S4: TensorRef<T, D>,
1469{
1470 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1471 let (source, indexes) = indexing(
1472 indexes,
1473 [
1474 self.sources.0.view_shape(),
1475 self.sources.1.view_shape(),
1476 self.sources.2.view_shape(),
1477 self.sources.3.view_shape(),
1478 ]
1479 .into_iter(),
1480 self.along,
1481 )?;
1482 match source {
1483 0 => self.sources.0.get_reference(indexes),
1484 1 => self.sources.1.get_reference(indexes),
1485 2 => self.sources.2.get_reference(indexes),
1486 3 => self.sources.3.get_reference(indexes),
1487 _ => None,
1488 }
1489 }
1490
1491 fn view_shape(&self) -> [(Dimension, usize); D] {
1492 view_shape_impl(
1493 self.sources.0.view_shape(),
1494 [
1495 self.sources.0.view_shape(),
1496 self.sources.1.view_shape(),
1497 self.sources.2.view_shape(),
1498 self.sources.3.view_shape(),
1499 ]
1500 .into_iter(),
1501 self.along,
1502 )
1503 }
1504
1505 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1506 unsafe {
1507 let (source, indexes) = indexing(
1508 indexes,
1509 [
1510 self.sources.0.view_shape(),
1511 self.sources.1.view_shape(),
1512 self.sources.2.view_shape(),
1513 self.sources.3.view_shape(),
1514 ]
1515 .into_iter(),
1516 self.along,
1517 )
1518 .unwrap_unchecked();
1521 match source {
1522 0 => self.sources.0.get_reference_unchecked(indexes),
1523 1 => self.sources.1.get_reference_unchecked(indexes),
1524 2 => self.sources.2.get_reference_unchecked(indexes),
1525 3 => self.sources.3.get_reference_unchecked(indexes),
1526 _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1528 }
1529 }
1530 }
1531
1532 fn data_layout(&self) -> DataLayout<D> {
1533 DataLayout::NonLinear
1536 }
1537}
1538
1539unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorMut<T, D>
1540 for TensorChain<T, (S1, S2, S3, S4), D>
1541where
1542 S1: TensorMut<T, D>,
1543 S2: TensorMut<T, D>,
1544 S3: TensorMut<T, D>,
1545 S4: TensorMut<T, D>,
1546{
1547 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1548 let (source, indexes) = indexing(
1549 indexes,
1550 [
1551 self.sources.0.view_shape(),
1552 self.sources.1.view_shape(),
1553 self.sources.2.view_shape(),
1554 self.sources.3.view_shape(),
1555 ]
1556 .into_iter(),
1557 self.along,
1558 )?;
1559 match source {
1560 0 => self.sources.0.get_reference_mut(indexes),
1561 1 => self.sources.1.get_reference_mut(indexes),
1562 2 => self.sources.2.get_reference_mut(indexes),
1563 3 => self.sources.3.get_reference_mut(indexes),
1564 _ => None,
1565 }
1566 }
1567
1568 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1569 unsafe {
1570 let (source, indexes) = indexing(
1571 indexes,
1572 [
1573 self.sources.0.view_shape(),
1574 self.sources.1.view_shape(),
1575 self.sources.2.view_shape(),
1576 self.sources.3.view_shape(),
1577 ]
1578 .into_iter(),
1579 self.along,
1580 )
1581 .unwrap_unchecked();
1584 match source {
1585 0 => self.sources.0.get_reference_unchecked_mut(indexes),
1586 1 => self.sources.1.get_reference_unchecked_mut(indexes),
1587 2 => self.sources.2.get_reference_unchecked_mut(indexes),
1588 3 => self.sources.3.get_reference_unchecked_mut(indexes),
1589 _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1591 }
1592 }
1593 }
1594}
1595
#[test]
fn test_chaining() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;

    #[rustfmt::skip]
    let top = Tensor::from(
        [("a", 3), ("b", 2)],
        vec![
            9, 5,
            2, 1,
            3, 5,
        ],
    );
    #[rustfmt::skip]
    let bottom = Tensor::from(
        [("a", 4), ("b", 2)],
        vec![
            0, 1,
            8, 4,
            1, 7,
            6, 3,
        ],
    );

    // Chaining along "a" places the rows of `bottom` after those of `top`.
    let tall = TensorView::from(TensorChain::<_, (_, _), 2>::from((&top, &bottom), "a"));
    #[rustfmt::skip]
    assert_eq!(
        tall,
        Tensor::from([("a", 7), ("b", 2)], vec![
            9, 5,
            2, 1,
            3, 5,
            0, 1,
            8, 4,
            1, 7,
            6, 3,
        ])
    );

    // Chaining along "b" appends a column holding 0..7 to the right.
    let tall_erased: Box<dyn TensorMut<_, 2>> = Box::new(tall.map(|x| x));
    let extra_column = Tensor::from([("a", 7), ("b", 1)], (0..7).collect());
    let extra_column_erased: Box<dyn TensorMut<_, 2>> = Box::new(extra_column);
    let wide = TensorView::from(TensorChain::<_, [_; 2], 2>::from(
        [tall_erased, extra_column_erased],
        "b",
    ));
    #[rustfmt::skip]
    assert!(wide.eq(&Tensor::from([("a", 7), ("b", 3)], vec![
        9, 5, 0,
        2, 1, 1,
        3, 5, 2,
        0, 1, 3,
        8, 4, 4,
        1, 7, 5,
        6, 3, 6,
    ])));
}