Skip to main content

easy_ml/tensors/views/
zip.rs

1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
/**
 * Combines two or more tensors with the same shape along a new dimension to create a Tensor
 * with one additional dimension which stacks the sources together along that dimension.
 *
 * Note: due to limitations in Rust's const generics support, TensorStack only implements
 * TensorRef for D from `1` to `6` (from sources of `0` to `5` dimensions respectively), and
 * only supports tuple combinations for `2` to `4`. If you need to stack more than four tensors
 * together, you can stack any number with the `[S; N]` implementation, though note this requires
 * that all the tensors are the same type so you may need to box and erase the types to
 * `Box<dyn TensorRef<T, D>>`.
 *
 * The `along` argument taken by every constructor is a tuple of the index at which the new
 * dimension is inserted into the sources' shape (in the range 0 <= `d` <= D) and the name of
 * the new dimension. The new dimension's length is the number of sources.
 *
 * If you want to combine two or more tensors along an *existing* dimension see
 * [TensorChain](TensorChain).
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorStack, TensorRef};
 * let vector1 = Tensor::from([("data", 5)], vec![0, 1, 2, 3, 4]);
 * let vector2 = Tensor::from([("data", 5)], vec![2, 4, 8, 16, 32]);
 * // Because there are 4 variants of `TensorStack::from` you may need to use the turbofish
 * // to tell the Rust compiler which variant you're using, but the actual type of `S` can be
 * // left unspecified by using an underscore.
 * let matrix = TensorStack::<i32, [_; 2], 1>::from([&vector1, &vector2], (0, "sample"));
 * let equal_matrix = Tensor::from([("sample", 2), ("data", 5)], vec![
 *   0, 1, 2, 3, 4,
 *   2, 4, 8, 16, 32
 * ]);
 * assert_eq!(equal_matrix, TensorView::from(matrix));
 *
 * let also_matrix = TensorStack::<i32, (_, _), 1>::from((vector1, vector2), (0, "sample"));
 * assert_eq!(equal_matrix, TensorView::from(&also_matrix));
 *
 * // To stack `equal_matrix` and `also_matrix` using the `[S; N]` implementation we have to first
 * // make them the same type, which we can do by boxing and erasing.
 * let matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(also_matrix);
 * let equal_matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(equal_matrix);
 * let tensor = TensorStack::<i32, [_; 2], 2>::from(
 *     [matrix_erased, equal_matrix_erased], (0, "experiment")
 * );
 * assert!(
 *     TensorView::from(tensor).eq(
 *         &Tensor::from([("experiment", 2), ("sample", 2), ("data", 5)], vec![
 *             0, 1, 2, 3, 4,
 *             2, 4, 8, 16, 32,
 *
 *             0, 1, 2, 3, 4,
 *             2, 4, 8, 16, 32
 *         ])
 *     ),
 * );
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorStack<T, S, const D: usize> {
    // The stacked sources: the constructors accept either an array `[S; N]` of same-typed
    // sources or a tuple of 2 to 4 sources, all of which share one shape.
    sources: S,
    // Marker for the element type `T`, which no field otherwise stores.
    _type: PhantomData<T>,
    // (index the new dimension is inserted at in the sources' shape, name of the new dimension)
    along: (usize, Dimension),
}
64
65fn validate_shapes_equal<const D: usize, I>(mut shapes: I)
66where
67    I: Iterator<Item = [(Dimension, usize); D]>,
68{
69    // We'll reject fewer than one tensors in the constructors before getting here, so first unwrap
70    // is always going to succeed.
71    let first_shape = shapes.next().unwrap();
72    for (i, shape) in shapes.enumerate() {
73        if shape != first_shape {
74            panic!(
75                "The shapes of each tensor in the sources to stack along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
76                i + 1,
77                shape,
78                first_shape
79            );
80        }
81    }
82}
83
84impl<T, S, const D: usize, const N: usize> TensorStack<T, [S; N], D>
85where
86    S: TensorRef<T, D>,
87{
88    /**
89     * Creates a TensorStack from an array of sources of the same type and a tuple of which
90     * dimension and name to stack the sources along in the range of 0 <= `d` <= D. The sources
91     * must all have an identical shape, and the dimension name to add must not be in the sources'
92     * shape already.
93     *
94     * # Panics
95     *
96     * If N == 0, the shapes of the sources are not identical, the dimension for stacking is out
97     * of bounds, or the name is already in the sources' shape.
98     *
99     * While N == 1 arguments may be valid [TensorExpansion](crate::tensors::views::TensorExpansion)
100     * is a more general way to add dimensions with no additional data.
101     */
102    #[track_caller]
103    pub fn from(sources: [S; N], along: (usize, Dimension)) -> Self {
104        if N == 0 {
105            panic!("No sources provided");
106        }
107        if along.0 > D {
108            panic!(
109                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
110                along
111            );
112        }
113        let shape = sources[0].view_shape();
114        if dimensions::contains(&shape, along.1) {
115            panic!(
116                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
117                along, shape
118            );
119        }
120        validate_shapes_equal(sources.iter().map(|tensor| tensor.view_shape()));
121        Self {
122            sources,
123            along,
124            _type: PhantomData,
125        }
126    }
127
128    /**
129     * Consumes the TensorStack, yielding the sources it was created from in the same order.
130     */
131    pub fn sources(self) -> [S; N] {
132        self.sources
133    }
134
135    // # Safety
136    //
137    // Giving out a mutable reference to our sources could allow then to be changed out from under
138    // us and make our shape invalid. However, since the sources implement TensorRef interior
139    // mutability is not allowed, so we can give out shared references without breaking our own
140    // integrity.
141    /**
142     * Gives a reference to all the TensorStack's sources it was created from in the same order
143     */
144    pub fn sources_ref(&self) -> &[S; N] {
145        &self.sources
146    }
147
148    /**
149     * Returns the shape of each of the matching sources the TensorStack was created from.
150     */
151    fn source_view_shape(&self) -> [(Dimension, usize); D] {
152        self.sources[0].view_shape()
153    }
154
155    fn number_of_sources() -> usize {
156        N
157    }
158}
159
160impl<T, S1, S2, const D: usize> TensorStack<T, (S1, S2), D>
161where
162    S1: TensorRef<T, D>,
163    S2: TensorRef<T, D>,
164{
165    /**
166     * Creates a TensorStack from two sources and a tuple of which dimension and name to stack
167     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
168     * shape, and the dimension name to add must not be in the sources' shape already.
169     *
170     * # Panics
171     *
172     * If the shapes of the sources are not identical, the dimension for stacking is out
173     * of bounds, or the name is already in the sources' shape.
174     */
175    #[track_caller]
176    pub fn from(sources: (S1, S2), along: (usize, Dimension)) -> Self {
177        if along.0 > D {
178            panic!(
179                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
180                along
181            );
182        }
183        let shape = sources.0.view_shape();
184        if dimensions::contains(&shape, along.1) {
185            panic!(
186                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
187                along, shape
188            );
189        }
190        validate_shapes_equal([sources.0.view_shape(), sources.1.view_shape()].into_iter());
191        Self {
192            sources,
193            along,
194            _type: PhantomData,
195        }
196    }
197
198    /**
199     * Consumes the TensorStack, yielding the sources it was created from in the same order.
200     */
201    pub fn sources(self) -> (S1, S2) {
202        self.sources
203    }
204
205    // # Safety
206    //
207    // Giving out a mutable reference to our sources could allow then to be changed out from under
208    // us and make our shape invalid. However, since the sources implement TensorRef interior
209    // mutability is not allowed, so we can give out shared references without breaking our own
210    // integrity.
211    /**
212     * Gives a reference to all the TensorStack's sources it was created from in the same order
213     */
214    pub fn sources_ref(&self) -> &(S1, S2) {
215        &self.sources
216    }
217
218    /**
219     * Returns the shape of each of the matching sources the TensorStack was created from.
220     */
221    fn source_view_shape(&self) -> [(Dimension, usize); D] {
222        self.sources.0.view_shape()
223    }
224
225    fn number_of_sources() -> usize {
226        2
227    }
228}
229
230impl<T, S1, S2, S3, const D: usize> TensorStack<T, (S1, S2, S3), D>
231where
232    S1: TensorRef<T, D>,
233    S2: TensorRef<T, D>,
234    S3: TensorRef<T, D>,
235{
236    /**
237     * Creates a TensorStack from three sources and a tuple of which dimension and name to stack
238     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
239     * shape, and the dimension name to add must not be in the sources' shape already.
240     *
241     * # Panics
242     *
243     * If the shapes of the sources are not identical, the dimension for stacking is out
244     * of bounds, or the name is already in the sources' shape.
245     */
246    #[track_caller]
247    pub fn from(sources: (S1, S2, S3), along: (usize, Dimension)) -> Self {
248        if along.0 > D {
249            panic!(
250                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
251                along
252            );
253        }
254        let shape = sources.0.view_shape();
255        if dimensions::contains(&shape, along.1) {
256            panic!(
257                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
258                along, shape
259            );
260        }
261        validate_shapes_equal(
262            [
263                sources.0.view_shape(),
264                sources.1.view_shape(),
265                sources.2.view_shape(),
266            ]
267            .into_iter(),
268        );
269        Self {
270            sources,
271            along,
272            _type: PhantomData,
273        }
274    }
275
276    /**
277     * Consumes the TensorStack, yielding the sources it was created from in the same order.
278     */
279    pub fn sources(self) -> (S1, S2, S3) {
280        self.sources
281    }
282
283    // # Safety
284    //
285    // Giving out a mutable reference to our sources could allow then to be changed out from under
286    // us and make our shape invalid. However, since the sources implement TensorRef interior
287    // mutability is not allowed, so we can give out shared references without breaking our own
288    // integrity.
289    /**
290     * Gives a reference to all the TensorStack's sources it was created from in the same order
291     */
292    pub fn sources_ref(&self) -> &(S1, S2, S3) {
293        &self.sources
294    }
295
296    /**
297     * Returns the shape of each of the matching sources the TensorStack was created from.
298     */
299    fn source_view_shape(&self) -> [(Dimension, usize); D] {
300        self.sources.0.view_shape()
301    }
302
303    fn number_of_sources() -> usize {
304        3
305    }
306}
307
308impl<T, S1, S2, S3, S4, const D: usize> TensorStack<T, (S1, S2, S3, S4), D>
309where
310    S1: TensorRef<T, D>,
311    S2: TensorRef<T, D>,
312    S3: TensorRef<T, D>,
313    S4: TensorRef<T, D>,
314{
315    /**
316     * Creates a TensorStack from four sources and a tuple of which dimension and name to stack
317     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
318     * shape, and the dimension name to add must not be in the sources' shape already.
319     *
320     * # Panics
321     *
322     * If the shapes of the sources are not identical, the dimension for stacking is out
323     * of bounds, or the name is already in the sources' shape.
324     */
325    #[track_caller]
326    pub fn from(sources: (S1, S2, S3, S4), along: (usize, Dimension)) -> Self {
327        if along.0 > D {
328            panic!(
329                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
330                along
331            );
332        }
333        let shape = sources.0.view_shape();
334        if dimensions::contains(&shape, along.1) {
335            panic!(
336                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
337                along, shape
338            );
339        }
340        validate_shapes_equal(
341            [
342                sources.0.view_shape(),
343                sources.1.view_shape(),
344                sources.2.view_shape(),
345                sources.3.view_shape(),
346            ]
347            .into_iter(),
348        );
349        Self {
350            sources,
351            along,
352            _type: PhantomData,
353        }
354    }
355
356    /**
357     * Consumes the TensorStack, yielding the sources it was created from in the same order.
358     */
359    pub fn sources(self) -> (S1, S2, S3, S4) {
360        self.sources
361    }
362
363    // # Safety
364    //
365    // Giving out a mutable reference to our sources could allow then to be changed out from under
366    // us and make our shape invalid. However, since the sources implement TensorRef interior
367    // mutability is not allowed, so we can give out shared references without breaking our own
368    // integrity.
369    /**
370     * Gives a reference to all the TensorStack's sources it was created from in the same order
371     */
372    pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
373        &self.sources
374    }
375
376    /**
377     * Returns the shape of each of the matching sources the TensorStack was created from.
378     */
379    fn source_view_shape(&self) -> [(Dimension, usize); D] {
380        self.sources.0.view_shape()
381    }
382
383    fn number_of_sources() -> usize {
384        4
385    }
386}
387
// Generates the TensorRef and TensorMut implementations for a TensorStack whose sources each
// have `$d` dimensions, giving a view with `$d + 1` dimensions. Implementations are generated
// for the `[S; N]` array form and for the 2, 3 and 4 element tuple forms. Everything is placed
// in a private module named `$mod`, so the helper functions generated for each `$d` don't clash.
macro_rules! tensor_stack_ref_impl {
    (unsafe impl TensorRef for TensorStack $d:literal $mod:ident) => {
        // To avoid helper name clashes we use a different module per macro invocation
        mod $mod {
            use crate::tensors::views::{TensorRef, TensorMut, DataLayout, TensorStack};
            use crate::tensors::Dimension;

            // Builds the `$d + 1` dimensional shape of the stack. The new dimension
            // `(along.1, sources)` is placed at index `along.0`, and the source dimensions are
            // copied in their original order into the remaining positions.
            fn view_shape_impl(
                shape: [(Dimension, usize); $d],
                along: (usize, Dimension),
                sources: usize,
            ) -> [(Dimension, usize); $d + 1] {
                // Placeholder values; every entry is overwritten by the loop below.
                let mut extra_shape = [("", 0); $d + 1];
                // Index of the next source dimension to copy.
                let mut i = 0;
                for (d, dimension) in extra_shape.iter_mut().enumerate() {
                    match d == along.0 {
                        true => {
                            *dimension = (along.1, sources);
                            // Do not increment i, this is the added dimension
                        },
                        false => {
                            *dimension = shape[i];
                            i += 1;
                        }
                    }
                }
                extra_shape
            }

            // Splits a `$d + 1` dimensional index into a pair. The first element is the index
            // along the added dimension, which selects the source. The second is the remaining
            // `$d` indexes, in order, to use within that source.
            fn indexing(
                indexes: [usize; $d + 1],
                along: (usize, Dimension)
            ) -> (usize, [usize; $d]) {
                let mut indexes_into_source = [0; $d];
                let mut i = 0;
                for (d, &index) in indexes.iter().enumerate() {
                    if d != along.0 {
                        indexes_into_source[i] = index;
                        i += 1;
                    }
                }
                // along.0 <= $d is enforced by the constructors, so this index is in bounds.
                (indexes[along.0], indexes_into_source)
            }

            unsafe impl<T, S, const N: usize> TensorRef<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorRef<T, $d>
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // An out of range source index returns None through `?`. The remaining
                    // indexes are bounds checked by the source itself.
                    self.sources.get(source)?.get_reference(indexes)
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    // SAFETY (assumed contract, confirm against the TensorRef docs): callers
                    // only pass indexes within view_shape(). view_shape_impl gives the `along`
                    // dimension length N, so `source < N`. The remaining indexes are within the
                    // shape shared by all sources.
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_unchecked(source).get_reference_unchecked(indexes)
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S, const N: usize> TensorMut<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorMut<T, $d>
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_mut(source)?.get_reference_mut(indexes)
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    // SAFETY: same assumed contract as get_reference_unchecked.
                    let (source, indexes) = indexing(indexes, self.along);
                    self.sources.get_unchecked_mut(source).get_reference_unchecked_mut(indexes)
                }}
            }

            unsafe impl<T, S1, S2> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        // Assuming callers keep indexes within view_shape() (the along dimension
                        // has length 2), this arm is unreachable. Panicking, rather than using
                        // unreachable_unchecked, keeps a contract violation from becoming UB.
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        // Unreachable for in-bounds indexes; see get_reference_unchecked.
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        // Assuming callers keep indexes within view_shape() (the along dimension
                        // has length 3), this arm is unreachable. Panicking avoids turning a
                        // contract violation into UB.
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        // Unreachable for in-bounds indexes; see get_reference_unchecked.
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3, S4> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
                S4: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        3 => self.sources.3.get_reference(indexes),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        3 => self.sources.3.get_reference_unchecked(indexes),
                        // Assuming callers keep indexes within view_shape() (the along dimension
                        // has length 4), this arm is unreachable. Panicking avoids turning a
                        // contract violation into UB.
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3, S4> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
                S4: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        3 => self.sources.3.get_reference_mut(indexes),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        3 => self.sources.3.get_reference_unchecked_mut(indexes),
                        // Unreachable for in-bounds indexes; see get_reference_unchecked.
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}
            }
        }
    }
}
684
// Implement TensorRef/TensorMut for stacks of 0 to 5 dimensional sources, giving views of
// 1 to 6 dimensions. Each invocation needs a distinct module name for its generated helpers.
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 0 zero);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 1 one);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 2 two);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 3 three);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 4 four);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 5 five);
691
#[test]
fn test_stacking() {
    use crate::tensors::views::{TensorMut, TensorView};
    use crate::tensors::Tensor;

    // Three length-3 vectors sharing the single dimension "a".
    let first = Tensor::from([("a", 3)], vec![9, 5, 2]);
    let second = Tensor::from([("a", 3)], vec![3, 6, 0]);
    let third = Tensor::from([("a", 3)], vec![8, 7, 1]);

    // Inserting the new dimension "b" after "a" turns each vector into a column.
    let columns = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&first, &second, &third),
        (1, "b"),
    ));
    #[rustfmt::skip]
    let expected_columns = Tensor::from([("a", 3), ("b", 3)], vec![
        9, 3, 8,
        5, 6, 7,
        2, 0, 1,
    ]);
    assert_eq!(columns, expected_columns);

    // Inserting the new dimension "b" before "a" turns each vector into a row.
    let rows = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&first, &second, &third),
        (0, "b"),
    ));
    #[rustfmt::skip]
    let expected_rows = Tensor::from([("b", 3), ("a", 3)], vec![
        9, 5, 2,
        3, 6, 0,
        8, 7, 1,
    ]);
    assert_eq!(rows, expected_rows);

    // Stack the two matrices along a third dimension "c" at each possible position. The rows
    // matrix is renamed so both share the shape [("a", 3), ("b", 3)], and both are boxed and
    // erased so they have the single type the `[S; N]` form requires.
    let erased_columns: Box<dyn TensorMut<_, 2>> = Box::new(columns.map(|x| x));
    let erased_rows: Box<dyn TensorMut<_, 2>> =
        Box::new(rows.rename_view(["a", "b"]).map(|x| x));
    let stacked_last = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [erased_columns, erased_rows],
        (2, "c"),
    ));
    #[rustfmt::skip]
    let expected_last = Tensor::from([("a", 3), ("b", 3), ("c", 2)], vec![
        9, 9,
        3, 5,
        8, 2,

        5, 3,
        6, 6,
        7, 0,

        2, 8,
        0, 7,
        1, 1
    ]);
    assert!(stacked_last.eq(&expected_last));

    let erased_columns: Box<dyn TensorMut<_, 2>> = Box::new(columns.map(|x| x));
    let erased_rows: Box<dyn TensorMut<_, 2>> =
        Box::new(rows.rename_view(["a", "b"]).map(|x| x));
    let stacked_middle = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [erased_columns, erased_rows],
        (1, "c"),
    ));
    #[rustfmt::skip]
    let expected_middle = Tensor::from([("a", 3), ("c", 2), ("b", 3)], vec![
        9, 3, 8,
        9, 5, 2,

        5, 6, 7,
        3, 6, 0,

        2, 0, 1,
        8, 7, 1
    ]);
    assert!(stacked_middle.eq(&expected_middle));

    // Read-only erasure is enough for the final case.
    let erased_columns: Box<dyn TensorRef<_, 2>> = Box::new(columns.map(|x| x));
    let erased_rows: Box<dyn TensorRef<_, 2>> =
        Box::new(rows.rename_view(["a", "b"]).map(|x| x));
    let stacked_first = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [erased_columns, erased_rows],
        (0, "c"),
    ));
    #[rustfmt::skip]
    let expected_first = Tensor::from([("c", 2), ("a", 3), ("b", 3)], vec![
        9, 3, 8,
        5, 6, 7,
        2, 0, 1,

        9, 5, 2,
        3, 6, 0,
        8, 7, 1,
    ]);
    assert!(stacked_first.eq(&expected_first));
}
794
795/**
796 * Combines two or more tensors along an existing dimension in their shapes to create a Tensor
 * with a length in that dimension equal to the sum of the sources' lengths along that dimension.
798 * All other dimensions in the tensors' shapes must be the same.
799 *
 * This can be framed as a D dimensional version of
801 * [std::iter::Iterator::chain](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
802 *
 * Note: TensorChain only supports tuple combinations for `2` to `4`. If you need to chain more
 * than four tensors together, you can chain any number with the `[S; N]` implementation, though
 * note this requires that all the tensors are the same type so you may need to box and erase
 * the types to `Box<dyn TensorRef<T, D>>`.
807 *
 * If you want to combine two or more tensors along a *new* dimension, see
809 * [TensorStack](TensorStack).
810 *
811 * ```
812 * use easy_ml::tensors::Tensor;
813 * use easy_ml::tensors::views::{TensorView, TensorChain, TensorRef};
814 * let sample1 = Tensor::from([("sample", 1), ("data", 5)], vec![0, 1, 2, 3, 4]);
815 * let sample2 = Tensor::from([("sample", 1), ("data", 5)], vec![2, 4, 8, 16, 32]);
816 * // Because there are 4 variants of `TensorChain::from` you may need to use the turbofish
817 * // to tell the Rust compiler which variant you're using, but the actual type of `S` can be
818 * // left unspecified by using an underscore.
819 * let matrix = TensorChain::<i32, [_; 2], 2>::from([&sample1, &sample2], "sample");
820 * let equal_matrix = Tensor::from([("sample", 2), ("data", 5)], vec![
821 *     0, 1, 2, 3, 4,
822 *     2, 4, 8, 16, 32
823 *  ]);
824 * assert_eq!(equal_matrix, TensorView::from(matrix));
825 *
826 * let also_matrix = TensorChain::<i32, (_, _), 2>::from((sample1, sample2), "sample");
827 * assert_eq!(equal_matrix, TensorView::from(&also_matrix));
828 *
 * // To chain `equal_matrix` and `also_matrix` using the `[S; N]` implementation we have to first
830 * // make them the same type, which we can do by boxing and erasing.
831 * let matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(also_matrix);
832 * let equal_matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(equal_matrix);
833 * let repeated_data = TensorChain::<i32, [_; 2], 2>::from(
834 *     [matrix_erased, equal_matrix_erased], "data"
835 * );
836 * assert!(
837 *     TensorView::from(repeated_data).eq(
838 *         &Tensor::from([("sample", 2), ("data", 10)], vec![
839 *             0, 1, 2,  3,  4, 0, 1, 2,  3,  4,
840 *             2, 4, 8, 16, 32, 2, 4, 8, 16, 32
841 *         ])
842 *     ),
843 * );
844 * ```
845 */
#[derive(Clone, Debug)]
pub struct TensorChain<T, S, const D: usize> {
    // The tensors being chained together: either an array `[S; N]` or a tuple of 2 to 4
    // sources, depending on which constructor was used.
    sources: S,
    // Records the element type `T`, which no field otherwise stores.
    _type: PhantomData<T>,
    // Index into the shape (not the name) of the dimension the sources are chained along,
    // resolved from the dimension name when the chain is constructed.
    along: usize,
}
852
853fn validate_shapes_similar<const D: usize, I>(mut shapes: I, along: usize)
854where
855    I: Iterator<Item = [(Dimension, usize); D]>,
856{
857    // We'll reject fewer than one tensors in the constructors before getting here, so first unwrap
858    // is always going to succeed.
859    let first_shape = shapes.next().unwrap();
860    for (i, shape) in shapes.enumerate() {
861        for d in 0..D {
862            let similar = if d == along {
863                // don't need match for dimension lengths in the `along` dimension
864                shape[d].0 == first_shape[d].0
865            } else {
866                shape[d] == first_shape[d]
867            };
868            if !similar {
869                panic!(
870                    "The shapes of each tensor in the sources to chain along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
871                    i + 1,
872                    shape,
873                    first_shape
874                );
875            }
876        }
877    }
878}
879
880impl<T, S, const D: usize, const N: usize> TensorChain<T, [S; N], D>
881where
882    S: TensorRef<T, D>,
883{
884    /**
885     * Creates a TensorChain from an array of sources of the same type and the dimension name to
886     * chain the sources along. The sources must all have an identical shape, including the
887     * provided dimension, except for the dimension lengths of the provided dimension name which
888     * may be different.
889     *
890     * # Panics
891     *
892     * If N == 0, D == 0, the shapes of the sources are not identical*, or the dimension for
893     * chaining is not in sources' shape.
894     *
895     * *except for the lengths along the provided dimension.
896     */
897    #[track_caller]
898    pub fn from(sources: [S; N], along: Dimension) -> Self {
899        if N == 0 {
900            panic!("No sources provided");
901        }
902        if D == 0 {
903            panic!("Can't chain along 0 dimensional tensors");
904        }
905        let shape = sources[0].view_shape();
906        let along = match dimensions::position_of(&shape, along) {
907            Some(d) => d,
908            None => panic!(
909                "The dimension {:?} is not in the source's shapes: {:?}",
910                along, shape
911            ),
912        };
913        validate_shapes_similar(sources.iter().map(|tensor| tensor.view_shape()), along);
914        Self {
915            sources,
916            along,
917            _type: PhantomData,
918        }
919    }
920
921    /**
922     * Consumes the TensorChain, yielding the sources it was created from in the same order.
923     */
924    pub fn sources(self) -> [S; N] {
925        self.sources
926    }
927
928    // # Safety
929    //
930    // Giving out a mutable reference to our sources could allow then to be changed out from under
931    // us and make our shape invalid. However, since the sources implement TensorRef interior
932    // mutability is not allowed, so we can give out shared references without breaking our own
933    // integrity.
934    /**
935     * Gives a reference to all the TensorChain's sources it was created from in the same order
936     */
937    pub fn sources_ref(&self) -> &[S; N] {
938        &self.sources
939    }
940}
941
942impl<T, S1, S2, const D: usize> TensorChain<T, (S1, S2), D>
943where
944    S1: TensorRef<T, D>,
945    S2: TensorRef<T, D>,
946{
947    /**
948     * Creates a TensorChain from two sources and the dimension name to chain the sources along.
949     * The sources must all have an identical shape, including the provided dimension, except for
950     * the dimension lengths of the provided dimension name which may be different.
951     *
952     * # Panics
953     *
954     * If D == 0, the shapes of the sources are not identical*, or the dimension for
955     * chaining is not in sources' shape.
956     *
957     * *except for the lengths along the provided dimension.
958     */
959    #[track_caller]
960    pub fn from(sources: (S1, S2), along: Dimension) -> Self {
961        if D == 0 {
962            panic!("Can't chain along 0 dimensional tensors");
963        }
964        let shape = sources.0.view_shape();
965        let along = match dimensions::position_of(&shape, along) {
966            Some(d) => d,
967            None => panic!(
968                "The dimension {:?} is not in the source's shapes: {:?}",
969                along, shape
970            ),
971        };
972        validate_shapes_similar(
973            [sources.0.view_shape(), sources.1.view_shape()].into_iter(),
974            along,
975        );
976        Self {
977            sources,
978            along,
979            _type: PhantomData,
980        }
981    }
982
983    /**
984     * Consumes the TensorChain, yielding the sources it was created from in the same order.
985     */
986    pub fn sources(self) -> (S1, S2) {
987        self.sources
988    }
989
990    // # Safety
991    //
992    // Giving out a mutable reference to our sources could allow then to be changed out from under
993    // us and make our shape invalid. However, since the sources implement TensorRef interior
994    // mutability is not allowed, so we can give out shared references without breaking our own
995    // integrity.
996    /**
997     * Gives a reference to all the TensorChain's sources it was created from in the same order
998     */
999    pub fn sources_ref(&self) -> &(S1, S2) {
1000        &self.sources
1001    }
1002}
1003
1004impl<T, S1, S2, S3, const D: usize> TensorChain<T, (S1, S2, S3), D>
1005where
1006    S1: TensorRef<T, D>,
1007    S2: TensorRef<T, D>,
1008    S3: TensorRef<T, D>,
1009{
1010    /**
1011     * Creates a TensorChain from three sources and the dimension name to chain the sources along.
1012     * The sources must all have an identical shape, including the provided dimension, except for
1013     * the dimension lengths of the provided dimension name which may be different.
1014     *
1015     * # Panics
1016     *
1017     * If D == 0, the shapes of the sources are not identical*, or the dimension for
1018     * chaining is not in sources' shape.
1019     *
1020     * *except for the lengths along the provided dimension.
1021     */
1022    #[track_caller]
1023    pub fn from(sources: (S1, S2, S3), along: Dimension) -> Self {
1024        if D == 0 {
1025            panic!("Can't chain along 0 dimensional tensors");
1026        }
1027        let shape = sources.0.view_shape();
1028        let along = match dimensions::position_of(&shape, along) {
1029            Some(d) => d,
1030            None => panic!(
1031                "The dimension {:?} is not in the source's shapes: {:?}",
1032                along, shape
1033            ),
1034        };
1035        validate_shapes_similar(
1036            [
1037                sources.0.view_shape(),
1038                sources.1.view_shape(),
1039                sources.2.view_shape(),
1040            ]
1041            .into_iter(),
1042            along,
1043        );
1044        Self {
1045            sources,
1046            along,
1047            _type: PhantomData,
1048        }
1049    }
1050
1051    /**
1052     * Consumes the TensorChain, yielding the sources it was created from in the same order.
1053     */
1054    pub fn sources(self) -> (S1, S2, S3) {
1055        self.sources
1056    }
1057
1058    // # Safety
1059    //
1060    // Giving out a mutable reference to our sources could allow then to be changed out from under
1061    // us and make our shape invalid. However, since the sources implement TensorRef interior
1062    // mutability is not allowed, so we can give out shared references without breaking our own
1063    // integrity.
1064    /**
1065     * Gives a reference to all the TensorChain's sources it was created from in the same order
1066     */
1067    pub fn sources_ref(&self) -> &(S1, S2, S3) {
1068        &self.sources
1069    }
1070}
1071
1072impl<T, S1, S2, S3, S4, const D: usize> TensorChain<T, (S1, S2, S3, S4), D>
1073where
1074    S1: TensorRef<T, D>,
1075    S2: TensorRef<T, D>,
1076    S3: TensorRef<T, D>,
1077    S4: TensorRef<T, D>,
1078{
1079    /**
1080     * Creates a TensorChain from four sources and the dimension name to chain the sources along.
1081     * The sources must all have an identical shape, including the provided dimension, except for
1082     * the dimension lengths of the provided dimension name which may be different.
1083     *
1084     * # Panics
1085     *
1086     * If D == 0, the shapes of the sources are not identical*, or the dimension for
1087     * chaining is not in sources' shape.
1088     *
1089     * *except for the lengths along the provided dimension.
1090     */
1091    #[track_caller]
1092    pub fn from(sources: (S1, S2, S3, S4), along: Dimension) -> Self {
1093        if D == 0 {
1094            panic!("Can't chain along 0 dimensional tensors");
1095        }
1096        let shape = sources.0.view_shape();
1097        let along = match dimensions::position_of(&shape, along) {
1098            Some(d) => d,
1099            None => panic!(
1100                "The dimension {:?} is not in the source's shapes: {:?}",
1101                along, shape
1102            ),
1103        };
1104        validate_shapes_similar(
1105            [
1106                sources.0.view_shape(),
1107                sources.1.view_shape(),
1108                sources.2.view_shape(),
1109                sources.3.view_shape(),
1110            ]
1111            .into_iter(),
1112            along,
1113        );
1114        Self {
1115            sources,
1116            along,
1117            _type: PhantomData,
1118        }
1119    }
1120
1121    /**
1122     * Consumes the TensorChain, yielding the sources it was created from in the same order.
1123     */
1124    pub fn sources(self) -> (S1, S2, S3, S4) {
1125        self.sources
1126    }
1127
1128    // # Safety
1129    //
1130    // Giving out a mutable reference to our sources could allow then to be changed out from under
1131    // us and make our shape invalid. However, since the sources implement TensorRef interior
1132    // mutability is not allowed, so we can give out shared references without breaking our own
1133    // integrity.
1134    /**
1135     * Gives a reference to all the TensorChain's sources it was created from in the same order
1136     */
1137    pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
1138        &self.sources
1139    }
1140}
1141
1142fn view_shape_impl<I, const D: usize>(
1143    first_shape: [(Dimension, usize); D],
1144    shapes: I,
1145    along: usize,
1146) -> [(Dimension, usize); D]
1147where
1148    I: Iterator<Item = [(Dimension, usize); D]>,
1149{
1150    let mut shape = first_shape;
1151    shape[along].1 = shapes.into_iter().map(|shape| shape[along].1).sum();
1152    shape
1153}
1154
1155fn indexing<I, const D: usize>(
1156    indexes: [usize; D],
1157    shapes: I,
1158    along: usize,
1159) -> Option<(usize, [usize; D])>
1160where
1161    I: Iterator<Item = [(Dimension, usize); D]>,
1162{
1163    let mut shapes = shapes.enumerate();
1164    // Keep trying to index the next shape in the chain, if i is still greater
1165    // than the available length we know it's for a later shape, and can subtract
1166    // that available length till we find one.
1167    let mut i = indexes[along];
1168    loop {
1169        let (source, next_shape) = shapes.next()?;
1170        let length_along_chained_dimension = next_shape[along].1;
1171        if i < length_along_chained_dimension {
1172            #[allow(clippy::clone_on_copy)]
1173            let mut indexes = indexes.clone();
1174            indexes[along] = i;
1175            return Some((source, indexes));
1176        }
1177        i -= length_along_chained_dimension;
1178    }
1179}
1180
// Chains any number of sources of a single type `S`. Along the chained dimension the combined
// view is each source's extent in that dimension, one after another in array order.
unsafe impl<T, S, const D: usize, const N: usize> TensorRef<T, D> for TensorChain<T, [S; N], D>
where
    S: TensorRef<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )?;
        self.sources.get(source)?.get_reference(indexes)
    }

    // The first source's shape with the length of the `along` dimension replaced by the sum
    // over all sources. Indexing `[0]` cannot panic because `from` rejects N == 0.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources[0].view_shape(),
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. Every source was
        // checked at construction to match the first source's shape outside the `along`
        // dimension, and `indexing` makes the `along` index relative to the chosen source, so
        // the translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                self.sources.iter().map(|s| s.view_shape()),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            self.sources
                .get(source)
                // `indexing` only yields positions it enumerated from `self.sources`, so this
                // lookup always succeeds.
                .unwrap()
                .get_reference_unchecked(indexes)
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our chained shapes means the view shape no longer matches up to a single
        // line of data in memory in the general case.
        DataLayout::NonLinear
    }
}
1225
// Mutable access for an array of chained sources, routed the same way as the TensorRef impl.
unsafe impl<T, S, const D: usize, const N: usize> TensorMut<T, D> for TensorChain<T, [S; N], D>
where
    S: TensorMut<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            self.sources.iter().map(|s| s.view_shape()),
            self.along,
        )?;
        self.sources.get_mut(source)?.get_reference_mut(indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. Every source was
        // checked at construction to match the first source's shape outside the `along`
        // dimension, and `indexing` makes the `along` index relative to the chosen source, so
        // the translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                self.sources.iter().map(|s| s.view_shape()),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            self.sources
                .get_mut(source)
                // `indexing` only yields positions it enumerated from `self.sources`, so this
                // lookup always succeeds.
                .unwrap()
                .get_reference_unchecked_mut(indexes)
        }
    }
}
1256
// Chains exactly two sources, which may be different types. Source 0 covers the start of the
// `along` dimension and source 1 follows it.
unsafe impl<T, S1, S2, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            // `indexing` was given two shapes, so it only yields 0 or 1.
            _ => None,
        }
    }

    // The first source's shape with the length of the `along` dimension replaced by the sum
    // over both sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. Both sources were
        // checked at construction to match in every dimension other than the `along` length,
        // and `indexing` makes the `along` index relative to the chosen source, so the
        // translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                // `indexing` was given two shapes, so this arm is unreachable; a safe panic
                // is used rather than undefined behaviour.
                // TODO: Can we use unreachable_unchecked here?
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our chained shapes means the view shape no longer matches up to a single
        // line of data in memory in the general case.
        DataLayout::NonLinear
    }
}
1308
// Mutable access for two chained sources, routed the same way as the TensorRef impl.
unsafe impl<T, S1, S2, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            // `indexing` was given two shapes, so it only yields 0 or 1.
            _ => None,
        }
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. Both sources were
        // checked at construction to match in every dimension other than the `along` length,
        // and `indexing` makes the `along` index relative to the chosen source, so the
        // translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                // `indexing` was given two shapes, so this arm is unreachable; a safe panic
                // is used rather than undefined behaviour.
                // TODO: Can we use unreachable_unchecked here?
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1346
// Chains exactly three sources, which may be different types, in tuple order along the
// `along` dimension.
unsafe impl<T, S1, S2, S3, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2, S3), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
    S3: TensorRef<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            2 => self.sources.2.get_reference(indexes),
            // `indexing` was given three shapes, so it only yields 0, 1 or 2.
            _ => None,
        }
    }

    // The first source's shape with the length of the `along` dimension replaced by the sum
    // over all three sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. All sources were
        // checked at construction to match in every dimension other than the `along` length,
        // and `indexing` makes the `along` index relative to the chosen source, so the
        // translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                2 => self.sources.2.get_reference_unchecked(indexes),
                // `indexing` was given three shapes, so this arm is unreachable; a safe panic
                // is used rather than undefined behaviour.
                // TODO: Can we use unreachable_unchecked here?
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our chained shapes means the view shape no longer matches up to a single
        // line of data in memory in the general case.
        DataLayout::NonLinear
    }
}
1416
// Mutable access for three chained sources, routed the same way as the TensorRef impl.
unsafe impl<T, S1, S2, S3, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2, S3), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
    S3: TensorMut<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            2 => self.sources.2.get_reference_mut(indexes),
            // `indexing` was given three shapes, so it only yields 0, 1 or 2.
            _ => None,
        }
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. All sources were
        // checked at construction to match in every dimension other than the `along` length,
        // and `indexing` makes the `along` index relative to the chosen source, so the
        // translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                2 => self.sources.2.get_reference_unchecked_mut(indexes),
                // `indexing` was given three shapes, so this arm is unreachable; a safe panic
                // is used rather than undefined behaviour.
                // TODO: Can we use unreachable_unchecked here?
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1467
// Chains exactly four sources, which may be different types, in tuple order along the
// `along` dimension.
unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorRef<T, D>
    for TensorChain<T, (S1, S2, S3, S4), D>
where
    S1: TensorRef<T, D>,
    S2: TensorRef<T, D>,
    S3: TensorRef<T, D>,
    S4: TensorRef<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference(indexes),
            1 => self.sources.1.get_reference(indexes),
            2 => self.sources.2.get_reference(indexes),
            3 => self.sources.3.get_reference(indexes),
            // `indexing` was given four shapes, so it only yields 0 to 3.
            _ => None,
        }
    }

    // The first source's shape with the length of the `along` dimension replaced by the sum
    // over all four sources.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        view_shape_impl(
            self.sources.0.view_shape(),
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. All sources were
        // checked at construction to match in every dimension other than the `along` length,
        // and `indexing` makes the `along` index relative to the chosen source, so the
        // translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                    self.sources.3.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked(indexes),
                1 => self.sources.1.get_reference_unchecked(indexes),
                2 => self.sources.2.get_reference_unchecked(indexes),
                3 => self.sources.3.get_reference_unchecked(indexes),
                // `indexing` was given four shapes, so this arm is unreachable; a safe panic
                // is used rather than undefined behaviour.
                // TODO: Can we use unreachable_unchecked here?
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our chained shapes means the view shape no longer matches up to a single
        // line of data in memory in the general case.
        DataLayout::NonLinear
    }
}
1544
// Mutable access for four chained sources, routed the same way as the TensorRef impl.
unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorMut<T, D>
    for TensorChain<T, (S1, S2, S3, S4), D>
where
    S1: TensorMut<T, D>,
    S2: TensorMut<T, D>,
    S3: TensorMut<T, D>,
    S4: TensorMut<T, D>,
{
    // Finds the source owning `indexes`, makes the `along` index relative to that source, then
    // delegates. None if the index is past the end of the chained dimension or the source
    // rejects the translated indexes.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let (source, indexes) = indexing(
            indexes,
            [
                self.sources.0.view_shape(),
                self.sources.1.view_shape(),
                self.sources.2.view_shape(),
                self.sources.3.view_shape(),
            ]
            .into_iter(),
            self.along,
        )?;
        match source {
            0 => self.sources.0.get_reference_mut(indexes),
            1 => self.sources.1.get_reference_mut(indexes),
            2 => self.sources.2.get_reference_mut(indexes),
            3 => self.sources.3.get_reference_mut(indexes),
            // `indexing` was given four shapes, so it only yields 0 to 3.
            _ => None,
        }
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: the caller guarantees `indexes` is within `view_shape()`. All sources were
        // checked at construction to match in every dimension other than the `along` length,
        // and `indexing` makes the `along` index relative to the chosen source, so the
        // translated indexes are valid for that source's unchecked access too.
        unsafe {
            let (source, indexes) = indexing(
                indexes,
                [
                    self.sources.0.view_shape(),
                    self.sources.1.view_shape(),
                    self.sources.2.view_shape(),
                    self.sources.3.view_shape(),
                ]
                .into_iter(),
                self.along,
            )
            // The caller is already responsible for providing valid indexes to us
            // so `indexing` will always return Some
            .unwrap_unchecked();
            match source {
                0 => self.sources.0.get_reference_unchecked_mut(indexes),
                1 => self.sources.1.get_reference_unchecked_mut(indexes),
                2 => self.sources.2.get_reference_unchecked_mut(indexes),
                3 => self.sources.3.get_reference_unchecked_mut(indexes),
                // `indexing` was given four shapes, so this arm is unreachable; a safe panic
                // is used rather than undefined behaviour.
                // TODO: Can we use unreachable_unchecked here?
                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
            }
        }
    }
}
1601
#[test]
fn test_chaining() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;

    // Two matrices which agree on "b" but have different lengths along "a".
    #[rustfmt::skip]
    let upper = Tensor::from([("a", 3), ("b", 2)], vec![
        9, 5,
        2, 1,
        3, 5,
    ]);
    #[rustfmt::skip]
    let lower = Tensor::from([("a", 4), ("b", 2)], vec![
        0, 1,
        8, 4,
        1, 7,
        6, 3,
    ]);

    // Chaining along "a" places `lower`'s rows after `upper`'s.
    let tall = TensorView::from(TensorChain::<_, (_, _), 2>::from((&upper, &lower), "a"));
    #[rustfmt::skip]
    let expected_tall = Tensor::from([("a", 7), ("b", 2)], vec![
        9, 5,
        2, 1,
        3, 5,
        0, 1,
        8, 4,
        1, 7,
        6, 3,
    ]);
    assert_eq!(tall, expected_tall);

    // Chaining type erased sources along "b" appends a column of 0 to 6.
    let tall_erased: Box<dyn TensorMut<_, 2>> = Box::new(tall.map(|x| x));
    let counting = Tensor::from([("a", 7), ("b", 1)], (0..7).collect());
    let counting_erased: Box<dyn TensorMut<_, 2>> = Box::new(counting);
    let wide = TensorView::from(TensorChain::<_, [_; 2], 2>::from(
        [tall_erased, counting_erased],
        "b",
    ));
    #[rustfmt::skip]
    let expected_wide = Tensor::from([("a", 7), ("b", 3)], vec![
        9, 5, 0,
        2, 1, 1,
        3, 5, 2,
        0, 1, 3,
        8, 4, 4,
        1, 7, 5,
        6, 3, 6,
    ]);
    assert!(wide.eq(&expected_wide));
}