easy_ml/tensors/views/
zip.rs

1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
6/**
7 * Combines two or more tensors with the same shape along a new dimension to create a Tensor
8 * with one additional dimension which stacks the sources together along that dimension.
9 *
10 * Note: due to limitations in Rust's const generics support, TensorStack only implements
11 * TensorRef for D from `1` to `6` (from sources of `0` to `5` dimensions respectively), and
12 * only supports tuple combinations for `2` to `4`. If you need to stack more than four tensors
13 * together, you can stack any number with the `[S; N]` implementation, though note this requires
14 * that all the tensors are the same type so you may need to box and erase the types to
15 * `Box<dyn TensorRef<T, D>>`.
16 *
17 * ```
18 * use easy_ml::tensors::Tensor;
19 * use easy_ml::tensors::views::{TensorView, TensorStack, TensorRef};
20 * let vector1 = Tensor::from([("data", 5)], vec![0, 1, 2, 3, 4]);
21 * let vector2 = Tensor::from([("data", 5)], vec![2, 4, 8, 16, 32]);
22 * // Because there are 4 variants of `TensorStack::from` you may need to use the turbofish
23 * // to tell the Rust compiler which variant you're using, but the actual type of `S` can be
24 * // left unspecified by using an underscore.
25 * let matrix = TensorStack::<i32, [_; 2], 1>::from([&vector1, &vector2], (0, "sample"));
26 * let equal_matrix = Tensor::from([("sample", 2), ("data", 5)], vec![
27 *   0, 1, 2, 3, 4,
28 *   2, 4, 8, 16, 32
29 * ]);
30 * assert_eq!(equal_matrix, TensorView::from(matrix));
31 *
32 * let also_matrix = TensorStack::<i32, (_, _), 1>::from((vector1, vector2), (0, "sample"));
33 * assert_eq!(equal_matrix, TensorView::from(&also_matrix));
34 *
35 * // To stack `equal_matrix` and `also_matrix` using the `[S; N]` implementation we have to first
36 * // make them the same type, which we can do by boxing and erasing.
37 * let matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(also_matrix);
38 * let equal_matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(equal_matrix);
39 * let tensor = TensorStack::<i32, [_; 2], 2>::from(
40 *     [matrix_erased, equal_matrix_erased], (0, "experiment")
41 * );
42 * assert!(
43 *     TensorView::from(tensor).eq(
44 *         &Tensor::from([("experiment", 2), ("sample", 2), ("data", 5)], vec![
45 *             0, 1, 2, 3, 4,
46 *             2, 4, 8, 16, 32,
47 *
48 *             0, 1, 2, 3, 4,
49 *             2, 4, 8, 16, 32
50 *         ])
51 *     ),
52 * );
53 * ```
54 */
55#[derive(Clone, Debug)]
56pub struct TensorStack<T, S, const D: usize> {
57    sources: S,
58    _type: PhantomData<T>,
59    along: (usize, Dimension),
60}
61
62fn validate_shapes_equal<const D: usize, I>(mut shapes: I)
63where
64    I: Iterator<Item = [(Dimension, usize); D]>,
65{
66    // We'll reject fewer than one tensors in the constructors before getting here, so first unwrap
67    // is always going to succeed.
68    let first_shape = shapes.next().unwrap();
69    for (i, shape) in shapes.enumerate() {
70        if shape != first_shape {
71            panic!(
72                "The shapes of each tensor in the sources to stack along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
73                i + 1,
74                shape,
75                first_shape
76            );
77        }
78    }
79}
80
81impl<T, S, const D: usize, const N: usize> TensorStack<T, [S; N], D>
82where
83    S: TensorRef<T, D>,
84{
85    /**
86     * Creates a TensorStack from an array of sources of the same type and a tuple of which
87     * dimension and name to stack the sources along in the range of 0 <= `d` <= D. The sources
88     * must all have an identical shape, and the dimension name to add must not be in the sources'
89     * shape already.
90     *
91     * # Panics
92     *
93     * If N == 0, the shapes of the sources are not identical, the dimension for stacking is out
94     * of bounds, or the name is already in the sources' shape.
95     *
96     * While N == 1 arguments may be valid [TensorExpansion](crate::tensors::views::TensorExpansion)
97     * is a more general way to add dimensions with no additional data.
98     */
99    #[track_caller]
100    pub fn from(sources: [S; N], along: (usize, Dimension)) -> Self {
101        if N == 0 {
102            panic!("No sources provided");
103        }
104        if along.0 > D {
105            panic!(
106                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
107                along
108            );
109        }
110        let shape = sources[0].view_shape();
111        if dimensions::contains(&shape, along.1) {
112            panic!(
113                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
114                along, shape
115            );
116        }
117        validate_shapes_equal(sources.iter().map(|tensor| tensor.view_shape()));
118        Self {
119            sources,
120            along,
121            _type: PhantomData,
122        }
123    }
124
125    /**
126     * Consumes the TensorStack, yielding the sources it was created from in the same order.
127     */
128    pub fn sources(self) -> [S; N] {
129        self.sources
130    }
131
132    // # Safety
133    //
134    // Giving out a mutable reference to our sources could allow then to be changed out from under
135    // us and make our shape invalid. However, since the sources implement TensorRef interior
136    // mutability is not allowed, so we can give out shared references without breaking our own
137    // integrity.
138    /**
139     * Gives a reference to all the TensorStack's sources it was created from in the same order
140     */
141    pub fn sources_ref(&self) -> &[S; N] {
142        &self.sources
143    }
144
145    /**
146     * Returns the shape of each of the matching sources the TensorStack was created from.
147     */
148    fn source_view_shape(&self) -> [(Dimension, usize); D] {
149        self.sources[0].view_shape()
150    }
151
152    fn number_of_sources() -> usize {
153        N
154    }
155}
156
157impl<T, S1, S2, const D: usize> TensorStack<T, (S1, S2), D>
158where
159    S1: TensorRef<T, D>,
160    S2: TensorRef<T, D>,
161{
162    /**
163     * Creates a TensorStack from two sources and a tuple of which dimension and name to stack
164     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
165     * shape, and the dimension name to add must not be in the sources' shape already.
166     *
167     * # Panics
168     *
169     * If the shapes of the sources are not identical, the dimension for stacking is out
170     * of bounds, or the name is already in the sources' shape.
171     */
172    #[track_caller]
173    pub fn from(sources: (S1, S2), along: (usize, Dimension)) -> Self {
174        if along.0 > D {
175            panic!(
176                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
177                along
178            );
179        }
180        let shape = sources.0.view_shape();
181        if dimensions::contains(&shape, along.1) {
182            panic!(
183                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
184                along, shape
185            );
186        }
187        validate_shapes_equal([sources.0.view_shape(), sources.1.view_shape()].into_iter());
188        Self {
189            sources,
190            along,
191            _type: PhantomData,
192        }
193    }
194
195    /**
196     * Consumes the TensorStack, yielding the sources it was created from in the same order.
197     */
198    pub fn sources(self) -> (S1, S2) {
199        self.sources
200    }
201
202    // # Safety
203    //
204    // Giving out a mutable reference to our sources could allow then to be changed out from under
205    // us and make our shape invalid. However, since the sources implement TensorRef interior
206    // mutability is not allowed, so we can give out shared references without breaking our own
207    // integrity.
208    /**
209     * Gives a reference to all the TensorStack's sources it was created from in the same order
210     */
211    pub fn sources_ref(&self) -> &(S1, S2) {
212        &self.sources
213    }
214
215    /**
216     * Returns the shape of each of the matching sources the TensorStack was created from.
217     */
218    fn source_view_shape(&self) -> [(Dimension, usize); D] {
219        self.sources.0.view_shape()
220    }
221
222    fn number_of_sources() -> usize {
223        2
224    }
225}
226
227impl<T, S1, S2, S3, const D: usize> TensorStack<T, (S1, S2, S3), D>
228where
229    S1: TensorRef<T, D>,
230    S2: TensorRef<T, D>,
231    S3: TensorRef<T, D>,
232{
233    /**
234     * Creates a TensorStack from three sources and a tuple of which dimension and name to stack
235     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
236     * shape, and the dimension name to add must not be in the sources' shape already.
237     *
238     * # Panics
239     *
240     * If the shapes of the sources are not identical, the dimension for stacking is out
241     * of bounds, or the name is already in the sources' shape.
242     */
243    #[track_caller]
244    pub fn from(sources: (S1, S2, S3), along: (usize, Dimension)) -> Self {
245        if along.0 > D {
246            panic!(
247                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
248                along
249            );
250        }
251        let shape = sources.0.view_shape();
252        if dimensions::contains(&shape, along.1) {
253            panic!(
254                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
255                along, shape
256            );
257        }
258        validate_shapes_equal(
259            [
260                sources.0.view_shape(),
261                sources.1.view_shape(),
262                sources.2.view_shape(),
263            ]
264            .into_iter(),
265        );
266        Self {
267            sources,
268            along,
269            _type: PhantomData,
270        }
271    }
272
273    /**
274     * Consumes the TensorStack, yielding the sources it was created from in the same order.
275     */
276    pub fn sources(self) -> (S1, S2, S3) {
277        self.sources
278    }
279
280    // # Safety
281    //
282    // Giving out a mutable reference to our sources could allow then to be changed out from under
283    // us and make our shape invalid. However, since the sources implement TensorRef interior
284    // mutability is not allowed, so we can give out shared references without breaking our own
285    // integrity.
286    /**
287     * Gives a reference to all the TensorStack's sources it was created from in the same order
288     */
289    pub fn sources_ref(&self) -> &(S1, S2, S3) {
290        &self.sources
291    }
292
293    /**
294     * Returns the shape of each of the matching sources the TensorStack was created from.
295     */
296    fn source_view_shape(&self) -> [(Dimension, usize); D] {
297        self.sources.0.view_shape()
298    }
299
300    fn number_of_sources() -> usize {
301        3
302    }
303}
304
305impl<T, S1, S2, S3, S4, const D: usize> TensorStack<T, (S1, S2, S3, S4), D>
306where
307    S1: TensorRef<T, D>,
308    S2: TensorRef<T, D>,
309    S3: TensorRef<T, D>,
310    S4: TensorRef<T, D>,
311{
312    /**
313     * Creates a TensorStack from four sources and a tuple of which dimension and name to stack
314     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
315     * shape, and the dimension name to add must not be in the sources' shape already.
316     *
317     * # Panics
318     *
319     * If the shapes of the sources are not identical, the dimension for stacking is out
320     * of bounds, or the name is already in the sources' shape.
321     */
322    #[track_caller]
323    pub fn from(sources: (S1, S2, S3, S4), along: (usize, Dimension)) -> Self {
324        if along.0 > D {
325            panic!(
326                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
327                along
328            );
329        }
330        let shape = sources.0.view_shape();
331        if dimensions::contains(&shape, along.1) {
332            panic!(
333                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
334                along, shape
335            );
336        }
337        validate_shapes_equal(
338            [
339                sources.0.view_shape(),
340                sources.1.view_shape(),
341                sources.2.view_shape(),
342                sources.3.view_shape(),
343            ]
344            .into_iter(),
345        );
346        Self {
347            sources,
348            along,
349            _type: PhantomData,
350        }
351    }
352
353    /**
354     * Consumes the TensorStack, yielding the sources it was created from in the same order.
355     */
356    pub fn sources(self) -> (S1, S2, S3, S4) {
357        self.sources
358    }
359
360    // # Safety
361    //
362    // Giving out a mutable reference to our sources could allow then to be changed out from under
363    // us and make our shape invalid. However, since the sources implement TensorRef interior
364    // mutability is not allowed, so we can give out shared references without breaking our own
365    // integrity.
366    /**
367     * Gives a reference to all the TensorStack's sources it was created from in the same order
368     */
369    pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
370        &self.sources
371    }
372
373    /**
374     * Returns the shape of each of the matching sources the TensorStack was created from.
375     */
376    fn source_view_shape(&self) -> [(Dimension, usize); D] {
377        self.sources.0.view_shape()
378    }
379
380    fn number_of_sources() -> usize {
381        4
382    }
383}
384
// Implements TensorRef and TensorMut on every TensorStack source combination (`[S; N]` and the
// 2, 3 and 4 element tuples) whose sources have `$d` dimensions, producing a view with `$d + 1`
// dimensions. Each invocation must be given its own module name `$mod` so the private helper
// functions generated for different values of `$d` do not clash.
macro_rules! tensor_stack_ref_impl {
    (unsafe impl TensorRef for TensorStack $d:literal $mod:ident) => {
        // To avoid helper name clashes we use a different module per macro invocation
        mod $mod {
            use crate::tensors::views::{TensorRef, TensorMut, DataLayout, TensorStack};
            use crate::tensors::Dimension;

            // Builds the stacked view's shape. `(along.1, sources)` is inserted at index
            // `along.0`, and the sources' shared `shape` fills the remaining positions in its
            // original order.
            fn view_shape_impl(
                shape: [(Dimension, usize); $d],
                along: (usize, Dimension),
                sources: usize,
            ) -> [(Dimension, usize); $d + 1] {
                let mut extra_shape = [("", 0); $d + 1];
                let mut i = 0;
                for (d, dimension) in extra_shape.iter_mut().enumerate() {
                    match d == along.0 {
                        true => {
                            *dimension = (along.1, sources);
                            // Do not increment i, this is the added dimension
                        },
                        false => {
                            *dimension = shape[i];
                            i += 1;
                        }
                    }
                }
                extra_shape
            }

            // Splits an index into the stacked view into two parts: the position along the
            // stacked dimension (which source to read) and the remaining `$d` indexes into that
            // source. The constructors check `along.0 <= $d`, so `indexes[along.0]` is in
            // bounds of the `$d + 1` indexes.
            fn indexing(
                indexes: [usize; $d + 1],
                along: (usize, Dimension)
            ) -> (usize, [usize; $d]) {
                let mut indexes_into_source = [0; $d];
                let mut i = 0;
                for (d, &index) in indexes.iter().enumerate() {
                    if d != along.0 {
                        indexes_into_source[i] = index;
                        i += 1;
                    }
                }
                (indexes[along.0], indexes_into_source)
            }

            // NOTE(review): soundness of these unsafe impls relies on the sources upholding
            // their own TensorRef/TensorMut contracts and on the constructors having checked
            // that every source has the same shape.
            unsafe impl<T, S, const N: usize> TensorRef<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorRef<T, $d>
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // A source index past the end of the array yields None.
                    self.sources.get(source)?.get_reference(indexes)
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    // TODO: Can we use get_unchecked here?
                    self.sources.get(source).unwrap().get_reference_unchecked(indexes)
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S, const N: usize> TensorMut<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorMut<T, $d>
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    // A source index past the end of the array yields None.
                    self.sources.get_mut(source)?.get_reference_mut(indexes)
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    // TODO: Can we use get_unchecked here?
                    self.sources.get_mut(source).unwrap().get_reference_unchecked_mut(indexes)
                }}
            }

            unsafe impl<T, S1, S2> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        // Past the last source: out of bounds.
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        // Past the last source: out of bounds.
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked_mut"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        // Past the last source: out of bounds.
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        // Past the last source: out of bounds.
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked_mut"
                        )
                    }
                }}
            }

            unsafe impl<T, S1, S2, S3, S4> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
                S4: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference(indexes),
                        1 => self.sources.1.get_reference(indexes),
                        2 => self.sources.2.get_reference(indexes),
                        3 => self.sources.3.get_reference(indexes),
                        // Past the last source: out of bounds.
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    view_shape_impl(self.source_view_shape(), self.along, Self::number_of_sources())
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked(indexes),
                        1 => self.sources.1.get_reference_unchecked(indexes),
                        2 => self.sources.2.get_reference_unchecked(indexes),
                        3 => self.sources.3.get_reference_unchecked(indexes),
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked"
                        )
                    }
                }}

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // Our stacked shapes means the view shape no longer matches up to a single
                    // line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3, S4> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
                S4: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_mut(indexes),
                        1 => self.sources.1.get_reference_mut(indexes),
                        2 => self.sources.2.get_reference_mut(indexes),
                        3 => self.sources.3.get_reference_mut(indexes),
                        // Past the last source: out of bounds.
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T { unsafe {
                    let (source, indexes) = indexing(indexes, self.along);
                    match source {
                        0 => self.sources.0.get_reference_unchecked_mut(indexes),
                        1 => self.sources.1.get_reference_unchecked_mut(indexes),
                        2 => self.sources.2.get_reference_unchecked_mut(indexes),
                        3 => self.sources.3.get_reference_unchecked_mut(indexes),
                        // TODO: Can we use unreachable_unchecked here?
                        _ => panic!(
                            "Invalid index should never be given to get_reference_unchecked_mut"
                        )
                    }
                }}
            }
        }
    }
}
683
// Generate the TensorRef/TensorMut impls for sources of 0 to 5 dimensions, which produce stacked
// views of 1 to 6 dimensions. Each invocation gets a distinct module name for its helpers.
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 0 zero);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 1 one);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 2 two);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 3 three);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 4 four);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 5 five);
690
#[test]
fn test_stacking() {
    use crate::tensors::Tensor;
    use crate::tensors::views::{TensorMut, TensorView};

    let first = Tensor::from([("a", 3)], vec![9, 5, 2]);
    let second = Tensor::from([("a", 3)], vec![3, 6, 0]);
    let third = Tensor::from([("a", 3)], vec![8, 7, 1]);
    // A tuple of shared references is Copy, so it can be stacked more than once.
    let vectors = (&first, &second, &third);

    // Stacking along the new last dimension turns each vector into a column.
    let columns = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(vectors, (1, "b")));
    #[rustfmt::skip]
    let expected_columns = Tensor::from([("a", 3), ("b", 3)], vec![
        9, 3, 8,
        5, 6, 7,
        2, 0, 1,
    ]);
    assert_eq!(columns, expected_columns);

    // Stacking along the new first dimension turns each vector into a row.
    let rows = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(vectors, (0, "b")));
    #[rustfmt::skip]
    let expected_rows = Tensor::from([("b", 3), ("a", 3)], vec![
        9, 5, 2,
        3, 6, 0,
        8, 7, 1,
    ]);
    assert_eq!(rows, expected_rows);

    // The `[S; N]` implementation needs one source type, so both matrices are copied into owned
    // tensors and erased behind trait objects. `rows` is renamed so it shares the ("a", "b")
    // dimension names with `columns`.
    let along_last = {
        let left: Box<dyn TensorMut<_, 2>> = Box::new(columns.map(|x| x));
        let right: Box<dyn TensorMut<_, 2>> = Box::new(rows.rename_view(["a", "b"]).map(|x| x));
        TensorView::from(TensorStack::<_, [_; 2], 2>::from([left, right], (2, "c")))
    };
    #[rustfmt::skip]
    let expected_along_last = Tensor::from([("a", 3), ("b", 3), ("c", 2)], vec![
        9, 9,
        3, 5,
        8, 2,

        5, 3,
        6, 6,
        7, 0,

        2, 8,
        0, 7,
        1, 1
    ]);
    assert!(along_last.eq(&expected_along_last));

    let along_middle = {
        let left: Box<dyn TensorMut<_, 2>> = Box::new(columns.map(|x| x));
        let right: Box<dyn TensorMut<_, 2>> = Box::new(rows.rename_view(["a", "b"]).map(|x| x));
        TensorView::from(TensorStack::<_, [_; 2], 2>::from([left, right], (1, "c")))
    };
    #[rustfmt::skip]
    let expected_along_middle = Tensor::from([("a", 3), ("c", 2), ("b", 3)], vec![
        9, 3, 8,
        9, 5, 2,

        5, 6, 7,
        3, 6, 0,

        2, 0, 1,
        8, 7, 1
    ]);
    assert!(along_middle.eq(&expected_along_middle));

    // The last case erases to TensorRef rather than TensorMut.
    let along_first = {
        let left: Box<dyn TensorRef<_, 2>> = Box::new(columns.map(|x| x));
        let right: Box<dyn TensorRef<_, 2>> = Box::new(rows.rename_view(["a", "b"]).map(|x| x));
        TensorView::from(TensorStack::<_, [_; 2], 2>::from([left, right], (0, "c")))
    };
    #[rustfmt::skip]
    let expected_along_first = Tensor::from([("c", 2), ("a", 3), ("b", 3)], vec![
        9, 3, 8,
        5, 6, 7,
        2, 0, 1,

        9, 5, 2,
        3, 6, 0,
        8, 7, 1,
    ]);
    assert!(along_first.eq(&expected_along_first));
}
793
794/**
795 * Combines two or more tensors along an existing dimension in their shapes to create a Tensor
796 * with a length in that dimension equal to the sum of the sources together along that dimension.
797 * All other dimensions in the tensors' shapes must be the same.
798 *
799 * This can be framed as an D dimensional version of
800 * [std::iter::Iterator::chain](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
801 *
802 * Note: TensorChain only supports tuple combinations for `2` to `4`. If you need to stack more
803 * than four tensors together, you can stack any number with the `[S; N]` implementation, though
804 * note this requires that all the tensors are the same type so you may need to box and erase
805 * the types to `Box<dyn TensorRef<T, D>>`.
806 *
807 * ```
808 * use easy_ml::tensors::Tensor;
809 * use easy_ml::tensors::views::{TensorView, TensorChain, TensorRef};
810 * let sample1 = Tensor::from([("sample", 1), ("data", 5)], vec![0, 1, 2, 3, 4]);
811 * let sample2 = Tensor::from([("sample", 1), ("data", 5)], vec![2, 4, 8, 16, 32]);
812 * // Because there are 4 variants of `TensorChain::from` you may need to use the turbofish
813 * // to tell the Rust compiler which variant you're using, but the actual type of `S` can be
814 * // left unspecified by using an underscore.
815 * let matrix = TensorChain::<i32, [_; 2], 2>::from([&sample1, &sample2], "sample");
816 * let equal_matrix = Tensor::from([("sample", 2), ("data", 5)], vec![
817 *     0, 1, 2, 3, 4,
818 *     2, 4, 8, 16, 32
819 *  ]);
820 * assert_eq!(equal_matrix, TensorView::from(matrix));
821 *
822 * let also_matrix = TensorChain::<i32, (_, _), 2>::from((sample1, sample2), "sample");
823 * assert_eq!(equal_matrix, TensorView::from(&also_matrix));
824 *
825 * // To stack `equal_matrix` and `also_matrix` using the `[S; N]` implementation we have to first
826 * // make them the same type, which we can do by boxing and erasing.
827 * let matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(also_matrix);
828 * let equal_matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(equal_matrix);
829 * let repeated_data = TensorChain::<i32, [_; 2], 2>::from(
830 *     [matrix_erased, equal_matrix_erased], "data"
831 * );
832 * assert!(
833 *     TensorView::from(repeated_data).eq(
834 *         &Tensor::from([("sample", 2), ("data", 10)], vec![
835 *             0, 1, 2,  3,  4, 0, 1, 2,  3,  4,
836 *             2, 4, 8, 16, 32, 2, 4, 8, 16, 32
837 *         ])
838 *     ),
839 * );
840 * ```
841 */
842#[derive(Clone, Debug)]
843pub struct TensorChain<T, S, const D: usize> {
844    sources: S,
845    _type: PhantomData<T>,
846    along: usize,
847}
848
849fn validate_shapes_similar<const D: usize, I>(mut shapes: I, along: usize)
850where
851    I: Iterator<Item = [(Dimension, usize); D]>,
852{
853    // We'll reject fewer than one tensors in the constructors before getting here, so first unwrap
854    // is always going to succeed.
855    let first_shape = shapes.next().unwrap();
856    for (i, shape) in shapes.enumerate() {
857        for d in 0..D {
858            let similar = if d == along {
859                // don't need match for dimension lengths in the `along` dimension
860                shape[d].0 == first_shape[d].0
861            } else {
862                shape[d] == first_shape[d]
863            };
864            if !similar {
865                panic!(
866                    "The shapes of each tensor in the sources to chain along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
867                    i + 1,
868                    shape,
869                    first_shape
870                );
871            }
872        }
873    }
874}
875
876impl<T, S, const D: usize, const N: usize> TensorChain<T, [S; N], D>
877where
878    S: TensorRef<T, D>,
879{
880    /**
881     * Creates a TensorChain from an array of sources of the same type and the dimension name to
882     * chain the sources along. The sources must all have an identical shape, including the
883     * provided dimension, except for the dimension lengths of the provided dimension name which
884     * may be different.
885     *
886     * # Panics
887     *
888     * If N == 0, D == 0, the shapes of the sources are not identical*, or the dimension for
889     * chaining is not in sources' shape.
890     *
891     * *except for the lengths along the provided dimension.
892     */
893    #[track_caller]
894    pub fn from(sources: [S; N], along: Dimension) -> Self {
895        if N == 0 {
896            panic!("No sources provided");
897        }
898        if D == 0 {
899            panic!("Can't chain along 0 dimensional tensors");
900        }
901        let shape = sources[0].view_shape();
902        let along = match dimensions::position_of(&shape, along) {
903            Some(d) => d,
904            None => panic!(
905                "The dimension {:?} is not in the source's shapes: {:?}",
906                along, shape
907            ),
908        };
909        validate_shapes_similar(sources.iter().map(|tensor| tensor.view_shape()), along);
910        Self {
911            sources,
912            along,
913            _type: PhantomData,
914        }
915    }
916
917    /**
918     * Consumes the TensorChain, yielding the sources it was created from in the same order.
919     */
920    pub fn sources(self) -> [S; N] {
921        self.sources
922    }
923
924    // # Safety
925    //
926    // Giving out a mutable reference to our sources could allow then to be changed out from under
927    // us and make our shape invalid. However, since the sources implement TensorRef interior
928    // mutability is not allowed, so we can give out shared references without breaking our own
929    // integrity.
930    /**
931     * Gives a reference to all the TensorChain's sources it was created from in the same order
932     */
933    pub fn sources_ref(&self) -> &[S; N] {
934        &self.sources
935    }
936}
937
938impl<T, S1, S2, const D: usize> TensorChain<T, (S1, S2), D>
939where
940    S1: TensorRef<T, D>,
941    S2: TensorRef<T, D>,
942{
943    /**
944     * Creates a TensorChain from two sources and the dimension name to chain the sources along.
945     * The sources must all have an identical shape, including the provided dimension, except for
946     * the dimension lengths of the provided dimension name which may be different.
947     *
948     * # Panics
949     *
950     * If D == 0, the shapes of the sources are not identical*, or the dimension for
951     * chaining is not in sources' shape.
952     *
953     * *except for the lengths along the provided dimension.
954     */
955    #[track_caller]
956    pub fn from(sources: (S1, S2), along: Dimension) -> Self {
957        if D == 0 {
958            panic!("Can't chain along 0 dimensional tensors");
959        }
960        let shape = sources.0.view_shape();
961        let along = match dimensions::position_of(&shape, along) {
962            Some(d) => d,
963            None => panic!(
964                "The dimension {:?} is not in the source's shapes: {:?}",
965                along, shape
966            ),
967        };
968        validate_shapes_similar(
969            [sources.0.view_shape(), sources.1.view_shape()].into_iter(),
970            along,
971        );
972        Self {
973            sources,
974            along,
975            _type: PhantomData,
976        }
977    }
978
979    /**
980     * Consumes the TensorChain, yielding the sources it was created from in the same order.
981     */
982    pub fn sources(self) -> (S1, S2) {
983        self.sources
984    }
985
986    // # Safety
987    //
988    // Giving out a mutable reference to our sources could allow then to be changed out from under
989    // us and make our shape invalid. However, since the sources implement TensorRef interior
990    // mutability is not allowed, so we can give out shared references without breaking our own
991    // integrity.
992    /**
993     * Gives a reference to all the TensorChain's sources it was created from in the same order
994     */
995    pub fn sources_ref(&self) -> &(S1, S2) {
996        &self.sources
997    }
998}
999
1000impl<T, S1, S2, S3, const D: usize> TensorChain<T, (S1, S2, S3), D>
1001where
1002    S1: TensorRef<T, D>,
1003    S2: TensorRef<T, D>,
1004    S3: TensorRef<T, D>,
1005{
1006    /**
1007     * Creates a TensorChain from three sources and the dimension name to chain the sources along.
1008     * The sources must all have an identical shape, including the provided dimension, except for
1009     * the dimension lengths of the provided dimension name which may be different.
1010     *
1011     * # Panics
1012     *
1013     * If D == 0, the shapes of the sources are not identical*, or the dimension for
1014     * chaining is not in sources' shape.
1015     *
1016     * *except for the lengths along the provided dimension.
1017     */
1018    #[track_caller]
1019    pub fn from(sources: (S1, S2, S3), along: Dimension) -> Self {
1020        if D == 0 {
1021            panic!("Can't chain along 0 dimensional tensors");
1022        }
1023        let shape = sources.0.view_shape();
1024        let along = match dimensions::position_of(&shape, along) {
1025            Some(d) => d,
1026            None => panic!(
1027                "The dimension {:?} is not in the source's shapes: {:?}",
1028                along, shape
1029            ),
1030        };
1031        validate_shapes_similar(
1032            [
1033                sources.0.view_shape(),
1034                sources.1.view_shape(),
1035                sources.2.view_shape(),
1036            ]
1037            .into_iter(),
1038            along,
1039        );
1040        Self {
1041            sources,
1042            along,
1043            _type: PhantomData,
1044        }
1045    }
1046
1047    /**
1048     * Consumes the TensorChain, yielding the sources it was created from in the same order.
1049     */
1050    pub fn sources(self) -> (S1, S2, S3) {
1051        self.sources
1052    }
1053
1054    // # Safety
1055    //
1056    // Giving out a mutable reference to our sources could allow then to be changed out from under
1057    // us and make our shape invalid. However, since the sources implement TensorRef interior
1058    // mutability is not allowed, so we can give out shared references without breaking our own
1059    // integrity.
1060    /**
1061     * Gives a reference to all the TensorChain's sources it was created from in the same order
1062     */
1063    pub fn sources_ref(&self) -> &(S1, S2, S3) {
1064        &self.sources
1065    }
1066}
1067
1068impl<T, S1, S2, S3, S4, const D: usize> TensorChain<T, (S1, S2, S3, S4), D>
1069where
1070    S1: TensorRef<T, D>,
1071    S2: TensorRef<T, D>,
1072    S3: TensorRef<T, D>,
1073    S4: TensorRef<T, D>,
1074{
1075    /**
1076     * Creates a TensorChain from four sources and the dimension name to chain the sources along.
1077     * The sources must all have an identical shape, including the provided dimension, except for
1078     * the dimension lengths of the provided dimension name which may be different.
1079     *
1080     * # Panics
1081     *
1082     * If D == 0, the shapes of the sources are not identical*, or the dimension for
1083     * chaining is not in sources' shape.
1084     *
1085     * *except for the lengths along the provided dimension.
1086     */
1087    #[track_caller]
1088    pub fn from(sources: (S1, S2, S3, S4), along: Dimension) -> Self {
1089        if D == 0 {
1090            panic!("Can't chain along 0 dimensional tensors");
1091        }
1092        let shape = sources.0.view_shape();
1093        let along = match dimensions::position_of(&shape, along) {
1094            Some(d) => d,
1095            None => panic!(
1096                "The dimension {:?} is not in the source's shapes: {:?}",
1097                along, shape
1098            ),
1099        };
1100        validate_shapes_similar(
1101            [
1102                sources.0.view_shape(),
1103                sources.1.view_shape(),
1104                sources.2.view_shape(),
1105                sources.3.view_shape(),
1106            ]
1107            .into_iter(),
1108            along,
1109        );
1110        Self {
1111            sources,
1112            along,
1113            _type: PhantomData,
1114        }
1115    }
1116
1117    /**
1118     * Consumes the TensorChain, yielding the sources it was created from in the same order.
1119     */
1120    pub fn sources(self) -> (S1, S2, S3, S4) {
1121        self.sources
1122    }
1123
1124    // # Safety
1125    //
1126    // Giving out a mutable reference to our sources could allow then to be changed out from under
1127    // us and make our shape invalid. However, since the sources implement TensorRef interior
1128    // mutability is not allowed, so we can give out shared references without breaking our own
1129    // integrity.
1130    /**
1131     * Gives a reference to all the TensorChain's sources it was created from in the same order
1132     */
1133    pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
1134        &self.sources
1135    }
1136}
1137
1138fn view_shape_impl<I, const D: usize>(
1139    first_shape: [(Dimension, usize); D],
1140    shapes: I,
1141    along: usize,
1142) -> [(Dimension, usize); D]
1143where
1144    I: Iterator<Item = [(Dimension, usize); D]>,
1145{
1146    let mut shape = first_shape;
1147    shape[along].1 = shapes.into_iter().map(|shape| shape[along].1).sum();
1148    shape
1149}
1150
1151fn indexing<I, const D: usize>(
1152    indexes: [usize; D],
1153    shapes: I,
1154    along: usize,
1155) -> Option<(usize, [usize; D])>
1156where
1157    I: Iterator<Item = [(Dimension, usize); D]>,
1158{
1159    let mut shapes = shapes.enumerate();
1160    // Keep trying to index the next shape in the chain, if i is still greater
1161    // than the available length we know it's for a later shape, and can subtract
1162    // that available length till we find one.
1163    let mut i = indexes[along];
1164    loop {
1165        let (source, next_shape) = shapes.next()?;
1166        let length_along_chained_dimension = next_shape[along].1;
1167        if i < length_along_chained_dimension {
1168            #[allow(clippy::clone_on_copy)]
1169            let mut indexes = indexes.clone();
1170            indexes[along] = i;
1171            return Some((source, indexes));
1172        }
1173        i -= length_along_chained_dimension;
1174    }
1175}
1176
1177unsafe impl<T, S, const D: usize, const N: usize> TensorRef<T, D> for TensorChain<T, [S; N], D>
1178where
1179    S: TensorRef<T, D>,
1180{
1181    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1182        let (source, indexes) = indexing(
1183            indexes,
1184            self.sources.iter().map(|s| s.view_shape()),
1185            self.along,
1186        )?;
1187        self.sources.get(source)?.get_reference(indexes)
1188    }
1189
1190    fn view_shape(&self) -> [(Dimension, usize); D] {
1191        view_shape_impl(
1192            self.sources[0].view_shape(),
1193            self.sources.iter().map(|s| s.view_shape()),
1194            self.along,
1195        )
1196    }
1197
1198    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1199        unsafe {
1200            // TODO: Can we use get_unchecked here?
1201            let (source, indexes) = indexing(
1202                indexes,
1203                self.sources.iter().map(|s| s.view_shape()),
1204                self.along,
1205            )
1206            .unwrap();
1207            self.sources
1208                .get(source)
1209                .unwrap()
1210                .get_reference_unchecked(indexes)
1211        }
1212    }
1213
1214    fn data_layout(&self) -> DataLayout<D> {
1215        // Our chained shapes means the view shape no longer matches up to a single
1216        // line of data in memory in the general case.
1217        DataLayout::NonLinear
1218    }
1219}
1220
1221unsafe impl<T, S, const D: usize, const N: usize> TensorMut<T, D> for TensorChain<T, [S; N], D>
1222where
1223    S: TensorMut<T, D>,
1224{
1225    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1226        let (source, indexes) = indexing(
1227            indexes,
1228            self.sources.iter().map(|s| s.view_shape()),
1229            self.along,
1230        )?;
1231        self.sources.get_mut(source)?.get_reference_mut(indexes)
1232    }
1233
1234    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1235        unsafe {
1236            // TODO: Can we use get_unchecked here?
1237            let (source, indexes) = indexing(
1238                indexes,
1239                self.sources.iter().map(|s| s.view_shape()),
1240                self.along,
1241            )
1242            .unwrap();
1243            self.sources
1244                .get_mut(source)
1245                .unwrap()
1246                .get_reference_unchecked_mut(indexes)
1247        }
1248    }
1249}
1250
1251unsafe impl<T, S1, S2, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2), D>
1252where
1253    S1: TensorRef<T, D>,
1254    S2: TensorRef<T, D>,
1255{
1256    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1257        let (source, indexes) = indexing(
1258            indexes,
1259            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1260            self.along,
1261        )?;
1262        match source {
1263            0 => self.sources.0.get_reference(indexes),
1264            1 => self.sources.1.get_reference(indexes),
1265            _ => None,
1266        }
1267    }
1268
1269    fn view_shape(&self) -> [(Dimension, usize); D] {
1270        view_shape_impl(
1271            self.sources.0.view_shape(),
1272            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1273            self.along,
1274        )
1275    }
1276
1277    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1278        unsafe {
1279            // TODO: Can we use get_unchecked here?
1280            let (source, indexes) = indexing(
1281                indexes,
1282                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1283                self.along,
1284            )
1285            .unwrap();
1286            match source {
1287                0 => self.sources.0.get_reference_unchecked(indexes),
1288                1 => self.sources.1.get_reference_unchecked(indexes),
1289                // TODO: Can we use unreachable_unchecked here?
1290                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1291            }
1292        }
1293    }
1294
1295    fn data_layout(&self) -> DataLayout<D> {
1296        // Our chained shapes means the view shape no longer matches up to a single
1297        // line of data in memory in the general case.
1298        DataLayout::NonLinear
1299    }
1300}
1301
1302unsafe impl<T, S1, S2, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2), D>
1303where
1304    S1: TensorMut<T, D>,
1305    S2: TensorMut<T, D>,
1306{
1307    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1308        let (source, indexes) = indexing(
1309            indexes,
1310            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1311            self.along,
1312        )?;
1313        match source {
1314            0 => self.sources.0.get_reference_mut(indexes),
1315            1 => self.sources.1.get_reference_mut(indexes),
1316            _ => None,
1317        }
1318    }
1319
1320    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1321        unsafe {
1322            // TODO: Can we use get_unchecked here?
1323            let (source, indexes) = indexing(
1324                indexes,
1325                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1326                self.along,
1327            )
1328            .unwrap();
1329            match source {
1330                0 => self.sources.0.get_reference_unchecked_mut(indexes),
1331                1 => self.sources.1.get_reference_unchecked_mut(indexes),
1332                // TODO: Can we use unreachable_unchecked here?
1333                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1334            }
1335        }
1336    }
1337}
1338
1339unsafe impl<T, S1, S2, S3, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2, S3), D>
1340where
1341    S1: TensorRef<T, D>,
1342    S2: TensorRef<T, D>,
1343    S3: TensorRef<T, D>,
1344{
1345    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1346        let (source, indexes) = indexing(
1347            indexes,
1348            [
1349                self.sources.0.view_shape(),
1350                self.sources.1.view_shape(),
1351                self.sources.2.view_shape(),
1352            ]
1353            .into_iter(),
1354            self.along,
1355        )?;
1356        match source {
1357            0 => self.sources.0.get_reference(indexes),
1358            1 => self.sources.1.get_reference(indexes),
1359            2 => self.sources.2.get_reference(indexes),
1360            _ => None,
1361        }
1362    }
1363
1364    fn view_shape(&self) -> [(Dimension, usize); D] {
1365        view_shape_impl(
1366            self.sources.0.view_shape(),
1367            [
1368                self.sources.0.view_shape(),
1369                self.sources.1.view_shape(),
1370                self.sources.2.view_shape(),
1371            ]
1372            .into_iter(),
1373            self.along,
1374        )
1375    }
1376
1377    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1378        unsafe {
1379            // TODO: Can we use get_unchecked here?
1380            let (source, indexes) = indexing(
1381                indexes,
1382                [
1383                    self.sources.0.view_shape(),
1384                    self.sources.1.view_shape(),
1385                    self.sources.2.view_shape(),
1386                ]
1387                .into_iter(),
1388                self.along,
1389            )
1390            .unwrap();
1391            match source {
1392                0 => self.sources.0.get_reference_unchecked(indexes),
1393                1 => self.sources.1.get_reference_unchecked(indexes),
1394                2 => self.sources.2.get_reference_unchecked(indexes),
1395                // TODO: Can we use unreachable_unchecked here?
1396                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1397            }
1398        }
1399    }
1400
1401    fn data_layout(&self) -> DataLayout<D> {
1402        // Our chained shapes means the view shape no longer matches up to a single
1403        // line of data in memory in the general case.
1404        DataLayout::NonLinear
1405    }
1406}
1407
1408unsafe impl<T, S1, S2, S3, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2, S3), D>
1409where
1410    S1: TensorMut<T, D>,
1411    S2: TensorMut<T, D>,
1412    S3: TensorMut<T, D>,
1413{
1414    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1415        let (source, indexes) = indexing(
1416            indexes,
1417            [
1418                self.sources.0.view_shape(),
1419                self.sources.1.view_shape(),
1420                self.sources.2.view_shape(),
1421            ]
1422            .into_iter(),
1423            self.along,
1424        )?;
1425        match source {
1426            0 => self.sources.0.get_reference_mut(indexes),
1427            1 => self.sources.1.get_reference_mut(indexes),
1428            2 => self.sources.2.get_reference_mut(indexes),
1429            _ => None,
1430        }
1431    }
1432
1433    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1434        unsafe {
1435            // TODO: Can we use get_unchecked here?
1436            let (source, indexes) = indexing(
1437                indexes,
1438                [
1439                    self.sources.0.view_shape(),
1440                    self.sources.1.view_shape(),
1441                    self.sources.2.view_shape(),
1442                ]
1443                .into_iter(),
1444                self.along,
1445            )
1446            .unwrap();
1447            match source {
1448                0 => self.sources.0.get_reference_unchecked_mut(indexes),
1449                1 => self.sources.1.get_reference_unchecked_mut(indexes),
1450                2 => self.sources.2.get_reference_unchecked_mut(indexes),
1451                // TODO: Can we use unreachable_unchecked here?
1452                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1453            }
1454        }
1455    }
1456}
1457
1458unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorRef<T, D>
1459    for TensorChain<T, (S1, S2, S3, S4), D>
1460where
1461    S1: TensorRef<T, D>,
1462    S2: TensorRef<T, D>,
1463    S3: TensorRef<T, D>,
1464    S4: TensorRef<T, D>,
1465{
1466    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1467        let (source, indexes) = indexing(
1468            indexes,
1469            [
1470                self.sources.0.view_shape(),
1471                self.sources.1.view_shape(),
1472                self.sources.2.view_shape(),
1473                self.sources.3.view_shape(),
1474            ]
1475            .into_iter(),
1476            self.along,
1477        )?;
1478        match source {
1479            0 => self.sources.0.get_reference(indexes),
1480            1 => self.sources.1.get_reference(indexes),
1481            2 => self.sources.2.get_reference(indexes),
1482            3 => self.sources.3.get_reference(indexes),
1483            _ => None,
1484        }
1485    }
1486
1487    fn view_shape(&self) -> [(Dimension, usize); D] {
1488        view_shape_impl(
1489            self.sources.0.view_shape(),
1490            [
1491                self.sources.0.view_shape(),
1492                self.sources.1.view_shape(),
1493                self.sources.2.view_shape(),
1494                self.sources.3.view_shape(),
1495            ]
1496            .into_iter(),
1497            self.along,
1498        )
1499    }
1500
1501    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1502        unsafe {
1503            // TODO: Can we use get_unchecked here?
1504            let (source, indexes) = indexing(
1505                indexes,
1506                [
1507                    self.sources.0.view_shape(),
1508                    self.sources.1.view_shape(),
1509                    self.sources.2.view_shape(),
1510                    self.sources.3.view_shape(),
1511                ]
1512                .into_iter(),
1513                self.along,
1514            )
1515            .unwrap();
1516            match source {
1517                0 => self.sources.0.get_reference_unchecked(indexes),
1518                1 => self.sources.1.get_reference_unchecked(indexes),
1519                2 => self.sources.2.get_reference_unchecked(indexes),
1520                3 => self.sources.3.get_reference_unchecked(indexes),
1521                // TODO: Can we use unreachable_unchecked here?
1522                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1523            }
1524        }
1525    }
1526
1527    fn data_layout(&self) -> DataLayout<D> {
1528        // Our chained shapes means the view shape no longer matches up to a single
1529        // line of data in memory in the general case.
1530        DataLayout::NonLinear
1531    }
1532}
1533
1534unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorMut<T, D>
1535    for TensorChain<T, (S1, S2, S3, S4), D>
1536where
1537    S1: TensorMut<T, D>,
1538    S2: TensorMut<T, D>,
1539    S3: TensorMut<T, D>,
1540    S4: TensorMut<T, D>,
1541{
1542    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1543        let (source, indexes) = indexing(
1544            indexes,
1545            [
1546                self.sources.0.view_shape(),
1547                self.sources.1.view_shape(),
1548                self.sources.2.view_shape(),
1549                self.sources.3.view_shape(),
1550            ]
1551            .into_iter(),
1552            self.along,
1553        )?;
1554        match source {
1555            0 => self.sources.0.get_reference_mut(indexes),
1556            1 => self.sources.1.get_reference_mut(indexes),
1557            2 => self.sources.2.get_reference_mut(indexes),
1558            3 => self.sources.3.get_reference_mut(indexes),
1559            _ => None,
1560        }
1561    }
1562
1563    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1564        unsafe {
1565            // TODO: Can we use get_unchecked here?
1566            let (source, indexes) = indexing(
1567                indexes,
1568                [
1569                    self.sources.0.view_shape(),
1570                    self.sources.1.view_shape(),
1571                    self.sources.2.view_shape(),
1572                    self.sources.3.view_shape(),
1573                ]
1574                .into_iter(),
1575                self.along,
1576            )
1577            .unwrap();
1578            match source {
1579                0 => self.sources.0.get_reference_unchecked_mut(indexes),
1580                1 => self.sources.1.get_reference_unchecked_mut(indexes),
1581                2 => self.sources.2.get_reference_unchecked_mut(indexes),
1582                3 => self.sources.3.get_reference_unchecked_mut(indexes),
1583                // TODO: Can we use unreachable_unchecked here?
1584                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1585            }
1586        }
1587    }
1588}
1589
#[test]
fn test_chaining() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    #[rustfmt::skip]
    let matrix1 = Tensor::from(
        [("a", 3), ("b", 2)],
        vec![
            9, 5,
            2, 1,
            3, 5
        ]
    );
    #[rustfmt::skip]
    let matrix2 = Tensor::from(
        [("a", 4), ("b", 2)],
        vec![
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ]
    );
    let matrix = TensorView::from(TensorChain::<_, (_, _), 2>::from((&matrix1, &matrix2), "a"));
    #[rustfmt::skip]
    assert_eq!(
        matrix,
        Tensor::from([("a", 7), ("b", 2)], vec![
            9, 5,
            2, 1,
            3, 5,
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ])
    );

    // Spot check the index mapping directly: the last row comes from the second source, and one
    // past the end of the chained dimension is out of bounds for every source.
    let chain = TensorChain::<_, (_, _), 2>::from((&matrix1, &matrix2), "a");
    assert_eq!(chain.get_reference([6, 1]), Some(&3));
    assert_eq!(chain.get_reference([7, 0]), None);

    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix = Tensor::from([("a", 7), ("b", 1)], (0..7).collect());
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(different_matrix);
    let another_matrix = TensorView::from(TensorChain::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        "b",
    ));
    #[rustfmt::skip]
    assert!(
        another_matrix.eq(
            &Tensor::from([("a", 7), ("b", 3)], vec![
                9, 5, 0,
                2, 1, 1,
                3, 5, 2,
                0, 1, 3,
                8, 4, 4,
                1, 7, 5,
                6, 3, 6
            ])
        )
    );

    // Writes through a three source chain land in the source that owns the chained index, and
    // indexing past the end yields nothing to write to.
    let mut vectors = TensorChain::<_, (_, _, _), 1>::from(
        (
            Tensor::from([("x", 1)], vec![1]),
            Tensor::from([("x", 2)], vec![2, 3]),
            Tensor::from([("x", 1)], vec![4]),
        ),
        "x",
    );
    *vectors.get_reference_mut([3]).unwrap() = 40;
    assert!(vectors.get_reference_mut([4]).is_none());
    assert_eq!(
        TensorView::from(vectors),
        Tensor::from([("x", 4)], vec![1, 2, 3, 40])
    );
}