easy_ml/tensors/views/
zip.rs

1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
6/**
7 * Combines two or more tensors with the same shape along a new dimension to create a Tensor
8 * with one additional dimension which stacks the sources together along that dimension.
9 *
10 * Note: due to limitations in Rust's const generics support, TensorStack only implements
11 * TensorRef for D from `1` to `6` (from sources of `0` to `5` dimensions respectively), and
12 * only supports tuple combinations for `2` to `4`. If you need to stack more than four tensors
13 * together, you can stack any number with the `[S; N]` implementation, though note this requires
14 * that all the tensors are the same type so you may need to box and erase the types to
15 * `Box<dyn TensorRef<T, D>>`.
16 *
17 * ```
18 * use easy_ml::tensors::Tensor;
19 * use easy_ml::tensors::views::{TensorView, TensorStack, TensorRef};
20 * let vector1 = Tensor::from([("data", 5)], vec![0, 1, 2, 3, 4]);
21 * let vector2 = Tensor::from([("data", 5)], vec![2, 4, 8, 16, 32]);
22 * // Because there are 4 variants of `TensorStack::from` you may need to use the turbofish
23 * // to tell the Rust compiler which variant you're using, but the actual type of `S` can be
24 * // left unspecified by using an underscore.
25 * let matrix = TensorStack::<i32, [_; 2], 1>::from([&vector1, &vector2], (0, "sample"));
26 * let equal_matrix = Tensor::from([("sample", 2), ("data", 5)], vec![
27 *   0, 1, 2, 3, 4,
28 *   2, 4, 8, 16, 32
29 * ]);
30 * assert_eq!(equal_matrix, TensorView::from(matrix));
31 *
32 * let also_matrix = TensorStack::<i32, (_, _), 1>::from((vector1, vector2), (0, "sample"));
33 * assert_eq!(equal_matrix, TensorView::from(&also_matrix));
34 *
35 * // To stack `equal_matrix` and `also_matrix` using the `[S; N]` implementation we have to first
36 * // make them the same type, which we can do by boxing and erasing.
37 * let matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(also_matrix);
38 * let equal_matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(equal_matrix);
39 * let tensor = TensorStack::<i32, [_; 2], 2>::from(
40 *     [matrix_erased, equal_matrix_erased], (0, "experiment")
41 * );
42 * assert!(
43 *     TensorView::from(tensor).eq(
44 *         &Tensor::from([("experiment", 2), ("sample", 2), ("data", 5)], vec![
45 *             0, 1, 2, 3, 4,
46 *             2, 4, 8, 16, 32,
47 *
48 *             0, 1, 2, 3, 4,
49 *             2, 4, 8, 16, 32
50 *         ])
51 *     ),
52 * );
53 * ```
54 */
#[derive(Clone, Debug)]
pub struct TensorStack<T, S, const D: usize> {
    // The sources being stacked. The constructors in this file accept either an array `[S; N]`
    // of one source type or a tuple of two, three or four (possibly different) source types.
    sources: S,
    // Zero sized marker tying the element type `T` to the struct, since `T` only appears
    // through the `TensorRef<T, D>` bounds on the sources and not in any stored field.
    _type: PhantomData<T>,
    // The (position, name) of the extra dimension the sources are stacked along. Every
    // constructor checks that the position is at most `D` and that the name is not already
    // present in the sources' shape before storing it.
    along: (usize, Dimension),
}
61
62fn validate_shapes_equal<const D: usize, I>(mut shapes: I)
63where
64    I: Iterator<Item = [(Dimension, usize); D]>,
65{
66    // We'll reject fewer than one tensors in the constructors before getting here, so first unwrap
67    // is always going to succeed.
68    let first_shape = shapes.next().unwrap();
69    for (i, shape) in shapes.enumerate() {
70        if shape != first_shape {
71            panic!(
72                "The shapes of each tensor in the sources to stack along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
73                i + 1,
74                shape,
75                first_shape
76            );
77        }
78    }
79}
80
81impl<T, S, const D: usize, const N: usize> TensorStack<T, [S; N], D>
82where
83    S: TensorRef<T, D>,
84{
85    /**
86     * Creates a TensorStack from an array of sources of the same type and a tuple of which
87     * dimension and name to stack the sources along in the range of 0 <= `d` <= D. The sources
88     * must all have an identical shape, and the dimension name to add must not be in the sources'
89     * shape already.
90     *
91     * # Panics
92     *
93     * If N == 0, the shapes of the sources are not identical, the dimension for stacking is out
94     * of bounds, or the name is already in the sources' shape.
95     *
96     * While N == 1 arguments may be valid [TensorExpansion](crate::tensors::views::TensorExpansion)
97     * is a more general way to add dimensions with no additional data.
98     */
99    #[track_caller]
100    pub fn from(sources: [S; N], along: (usize, Dimension)) -> Self {
101        if N == 0 {
102            panic!("No sources provided");
103        }
104        if along.0 > D {
105            panic!(
106                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
107                along
108            );
109        }
110        let shape = sources[0].view_shape();
111        if dimensions::contains(&shape, along.1) {
112            panic!(
113                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
114                along, shape
115            );
116        }
117        validate_shapes_equal(sources.iter().map(|tensor| tensor.view_shape()));
118        Self {
119            sources,
120            along,
121            _type: PhantomData,
122        }
123    }
124
125    /**
126     * Consumes the TensorStack, yielding the sources it was created from in the same order.
127     */
128    pub fn sources(self) -> [S; N] {
129        self.sources
130    }
131
132    // # Safety
133    //
134    // Giving out a mutable reference to our sources could allow then to be changed out from under
135    // us and make our shape invalid. However, since the sources implement TensorRef interior
136    // mutability is not allowed, so we can give out shared references without breaking our own
137    // integrity.
138    /**
139     * Gives a reference to all the TensorStack's sources it was created from in the same order
140     */
141    pub fn sources_ref(&self) -> &[S; N] {
142        &self.sources
143    }
144
145    /**
146     * Returns the shape of each of the matching sources the TensorStack was created from.
147     */
148    fn source_view_shape(&self) -> [(Dimension, usize); D] {
149        self.sources[0].view_shape()
150    }
151
152    fn number_of_sources() -> usize {
153        N
154    }
155}
156
157impl<T, S1, S2, const D: usize> TensorStack<T, (S1, S2), D>
158where
159    S1: TensorRef<T, D>,
160    S2: TensorRef<T, D>,
161{
162    /**
163     * Creates a TensorStack from two sources and a tuple of which dimension and name to stack
164     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
165     * shape, and the dimension name to add must not be in the sources' shape already.
166     *
167     * # Panics
168     *
169     * If the shapes of the sources are not identical, the dimension for stacking is out
170     * of bounds, or the name is already in the sources' shape.
171     */
172    #[track_caller]
173    pub fn from(sources: (S1, S2), along: (usize, Dimension)) -> Self {
174        if along.0 > D {
175            panic!(
176                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
177                along
178            );
179        }
180        let shape = sources.0.view_shape();
181        if dimensions::contains(&shape, along.1) {
182            panic!(
183                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
184                along, shape
185            );
186        }
187        validate_shapes_equal([sources.0.view_shape(), sources.1.view_shape()].into_iter());
188        Self {
189            sources,
190            along,
191            _type: PhantomData,
192        }
193    }
194
195    /**
196     * Consumes the TensorStack, yielding the sources it was created from in the same order.
197     */
198    pub fn sources(self) -> (S1, S2) {
199        self.sources
200    }
201
202    // # Safety
203    //
204    // Giving out a mutable reference to our sources could allow then to be changed out from under
205    // us and make our shape invalid. However, since the sources implement TensorRef interior
206    // mutability is not allowed, so we can give out shared references without breaking our own
207    // integrity.
208    /**
209     * Gives a reference to all the TensorStack's sources it was created from in the same order
210     */
211    pub fn sources_ref(&self) -> &(S1, S2) {
212        &self.sources
213    }
214
215    /**
216     * Returns the shape of each of the matching sources the TensorStack was created from.
217     */
218    fn source_view_shape(&self) -> [(Dimension, usize); D] {
219        self.sources.0.view_shape()
220    }
221
222    fn number_of_sources() -> usize {
223        2
224    }
225}
226
227impl<T, S1, S2, S3, const D: usize> TensorStack<T, (S1, S2, S3), D>
228where
229    S1: TensorRef<T, D>,
230    S2: TensorRef<T, D>,
231    S3: TensorRef<T, D>,
232{
233    /**
234     * Creates a TensorStack from three sources and a tuple of which dimension and name to stack
235     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
236     * shape, and the dimension name to add must not be in the sources' shape already.
237     *
238     * # Panics
239     *
240     * If the shapes of the sources are not identical, the dimension for stacking is out
241     * of bounds, or the name is already in the sources' shape.
242     */
243    #[track_caller]
244    pub fn from(sources: (S1, S2, S3), along: (usize, Dimension)) -> Self {
245        if along.0 > D {
246            panic!(
247                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
248                along
249            );
250        }
251        let shape = sources.0.view_shape();
252        if dimensions::contains(&shape, along.1) {
253            panic!(
254                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
255                along, shape
256            );
257        }
258        validate_shapes_equal(
259            [
260                sources.0.view_shape(),
261                sources.1.view_shape(),
262                sources.2.view_shape(),
263            ]
264            .into_iter(),
265        );
266        Self {
267            sources,
268            along,
269            _type: PhantomData,
270        }
271    }
272
273    /**
274     * Consumes the TensorStack, yielding the sources it was created from in the same order.
275     */
276    pub fn sources(self) -> (S1, S2, S3) {
277        self.sources
278    }
279
280    // # Safety
281    //
282    // Giving out a mutable reference to our sources could allow then to be changed out from under
283    // us and make our shape invalid. However, since the sources implement TensorRef interior
284    // mutability is not allowed, so we can give out shared references without breaking our own
285    // integrity.
286    /**
287     * Gives a reference to all the TensorStack's sources it was created from in the same order
288     */
289    pub fn sources_ref(&self) -> &(S1, S2, S3) {
290        &self.sources
291    }
292
293    /**
294     * Returns the shape of each of the matching sources the TensorStack was created from.
295     */
296    fn source_view_shape(&self) -> [(Dimension, usize); D] {
297        self.sources.0.view_shape()
298    }
299
300    fn number_of_sources() -> usize {
301        3
302    }
303}
304
305impl<T, S1, S2, S3, S4, const D: usize> TensorStack<T, (S1, S2, S3, S4), D>
306where
307    S1: TensorRef<T, D>,
308    S2: TensorRef<T, D>,
309    S3: TensorRef<T, D>,
310    S4: TensorRef<T, D>,
311{
312    /**
313     * Creates a TensorStack from four sources and a tuple of which dimension and name to stack
314     * the sources along in the range of 0 <= `d` <= D. The sources must all have an identical
315     * shape, and the dimension name to add must not be in the sources' shape already.
316     *
317     * # Panics
318     *
319     * If the shapes of the sources are not identical, the dimension for stacking is out
320     * of bounds, or the name is already in the sources' shape.
321     */
322    #[track_caller]
323    pub fn from(sources: (S1, S2, S3, S4), along: (usize, Dimension)) -> Self {
324        if along.0 > D {
325            panic!(
326                "The extra dimension the sources are stacked along {:?} must be inserted in the range 0 <= d <= D of the source shapes",
327                along
328            );
329        }
330        let shape = sources.0.view_shape();
331        if dimensions::contains(&shape, along.1) {
332            panic!(
333                "The extra dimension the sources are stacked along {:?} must not be one of the dimensions already in the source shapes: {:?}",
334                along, shape
335            );
336        }
337        validate_shapes_equal(
338            [
339                sources.0.view_shape(),
340                sources.1.view_shape(),
341                sources.2.view_shape(),
342                sources.3.view_shape(),
343            ]
344            .into_iter(),
345        );
346        Self {
347            sources,
348            along,
349            _type: PhantomData,
350        }
351    }
352
353    /**
354     * Consumes the TensorStack, yielding the sources it was created from in the same order.
355     */
356    pub fn sources(self) -> (S1, S2, S3, S4) {
357        self.sources
358    }
359
360    // # Safety
361    //
362    // Giving out a mutable reference to our sources could allow then to be changed out from under
363    // us and make our shape invalid. However, since the sources implement TensorRef interior
364    // mutability is not allowed, so we can give out shared references without breaking our own
365    // integrity.
366    /**
367     * Gives a reference to all the TensorStack's sources it was created from in the same order
368     */
369    pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
370        &self.sources
371    }
372
373    /**
374     * Returns the shape of each of the matching sources the TensorStack was created from.
375     */
376    fn source_view_shape(&self) -> [(Dimension, usize); D] {
377        self.sources.0.view_shape()
378    }
379
380    fn number_of_sources() -> usize {
381        4
382    }
383}
384
// Generates the TensorRef and TensorMut implementations for a TensorStack whose sources have
// `$d` dimensions, for the array form and for each of the 2, 3 and 4 tuple forms. The stacked
// view always has `$d + 1` dimensions.
macro_rules! tensor_stack_ref_impl {
    (unsafe impl TensorRef for TensorStack $d:literal $mod:ident) => {
        // Every invocation expands into its own module so the helper functions defined for
        // one value of $d cannot clash with the helpers for another.
        mod $mod {
            use crate::tensors::views::{TensorRef, TensorMut, DataLayout, TensorStack};
            use crate::tensors::Dimension;

            // Builds the shape of the stacked view: the source shape with the extra dimension
            // (named `along.1`, of length `sources`) inserted at position `along.0`.
            fn view_shape_impl(
                shape: [(Dimension, usize); $d],
                along: (usize, Dimension),
                sources: usize,
            ) -> [(Dimension, usize); $d + 1] {
                let (position, name) = along;
                let mut stacked = [("", 0); $d + 1];
                for (d, slot) in stacked.iter_mut().enumerate() {
                    *slot = if d < position {
                        // Before the insertion point the source dimensions line up directly
                        shape[d]
                    } else if d == position {
                        (name, sources)
                    } else {
                        // After the insertion point everything is shifted along by one
                        shape[d - 1]
                    };
                }
                stacked
            }

            // Splits an index into the stacked view into which source it selects (the index
            // at position `along.0`) and the remaining indexes to look up in that source.
            fn indexing(
                indexes: [usize; $d + 1],
                along: (usize, Dimension)
            ) -> (usize, [usize; $d]) {
                let position = along.0;
                let mut within_source = [0; $d];
                for (d, &index) in indexes.iter().enumerate() {
                    if d < position {
                        within_source[d] = index;
                    } else if d > position {
                        // Skipping the stacked dimension shifts later indexes back by one
                        within_source[d - 1] = index;
                    }
                }
                (indexes[position], within_source)
            }

            unsafe impl<T, S, const N: usize> TensorRef<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorRef<T, $d>
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (which, inner) = indexing(indexes, self.along);
                    self.sources
                        .get(which)
                        .and_then(|source| source.get_reference(inner))
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    let length = Self::number_of_sources();
                    view_shape_impl(self.source_view_shape(), self.along, length)
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: relies on the caller only passing indexes that are valid for
                    // this view's shape (TODO confirm against TensorRef's safety docs), in
                    // which case `which` is less than N and `inner` is valid for the source.
                    unsafe {
                        self.sources.get_unchecked(which).get_reference_unchecked(inner)
                    }
                }

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The view is spread over several sources, so its shape does not
                    // correspond to one single line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S, const N: usize> TensorMut<T, { $d + 1 }> for TensorStack<T, [S; N], $d>
            where
                S: TensorMut<T, $d>
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (which, inner) = indexing(indexes, self.along);
                    self.sources
                        .get_mut(which)
                        .and_then(|source| source.get_reference_mut(inner))
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: relies on the caller only passing indexes that are valid for
                    // this view's shape (TODO confirm against TensorMut's safety docs), in
                    // which case `which` is less than N and `inner` is valid for the source.
                    unsafe {
                        self.sources.get_unchecked_mut(which).get_reference_unchecked_mut(inner)
                    }
                }
            }

            unsafe impl<T, S1, S2> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (which, inner) = indexing(indexes, self.along);
                    match which {
                        0 => self.sources.0.get_reference(inner),
                        1 => self.sources.1.get_reference(inner),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    let length = Self::number_of_sources();
                    view_shape_impl(self.source_view_shape(), self.along, length)
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: the unchecked lookups are only sound for valid indexes, which
                    // is the caller's responsibility; an impossible source number panics
                    // instead of reaching a source.
                    unsafe {
                        match which {
                            0 => self.sources.0.get_reference_unchecked(inner),
                            1 => self.sources.1.get_reference_unchecked(inner),
                            // TODO: Can we use unreachable_unchecked here?
                            _ => panic!(
                                "Invalid index should never be given to get_reference_unchecked"
                            )
                        }
                    }
                }

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The view is spread over several sources, so its shape does not
                    // correspond to one single line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (which, inner) = indexing(indexes, self.along);
                    match which {
                        0 => self.sources.0.get_reference_mut(inner),
                        1 => self.sources.1.get_reference_mut(inner),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: the unchecked lookups are only sound for valid indexes, which
                    // is the caller's responsibility; an impossible source number panics
                    // instead of reaching a source.
                    unsafe {
                        match which {
                            0 => self.sources.0.get_reference_unchecked_mut(inner),
                            1 => self.sources.1.get_reference_unchecked_mut(inner),
                            // TODO: Can we use unreachable_unchecked here?
                            _ => panic!(
                                "Invalid index should never be given to get_reference_unchecked"
                            )
                        }
                    }
                }
            }

            unsafe impl<T, S1, S2, S3> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (which, inner) = indexing(indexes, self.along);
                    match which {
                        0 => self.sources.0.get_reference(inner),
                        1 => self.sources.1.get_reference(inner),
                        2 => self.sources.2.get_reference(inner),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    let length = Self::number_of_sources();
                    view_shape_impl(self.source_view_shape(), self.along, length)
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: the unchecked lookups are only sound for valid indexes, which
                    // is the caller's responsibility; an impossible source number panics
                    // instead of reaching a source.
                    unsafe {
                        match which {
                            0 => self.sources.0.get_reference_unchecked(inner),
                            1 => self.sources.1.get_reference_unchecked(inner),
                            2 => self.sources.2.get_reference_unchecked(inner),
                            // TODO: Can we use unreachable_unchecked here?
                            _ => panic!(
                                "Invalid index should never be given to get_reference_unchecked"
                            )
                        }
                    }
                }

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The view is spread over several sources, so its shape does not
                    // correspond to one single line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (which, inner) = indexing(indexes, self.along);
                    match which {
                        0 => self.sources.0.get_reference_mut(inner),
                        1 => self.sources.1.get_reference_mut(inner),
                        2 => self.sources.2.get_reference_mut(inner),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: the unchecked lookups are only sound for valid indexes, which
                    // is the caller's responsibility; an impossible source number panics
                    // instead of reaching a source.
                    unsafe {
                        match which {
                            0 => self.sources.0.get_reference_unchecked_mut(inner),
                            1 => self.sources.1.get_reference_unchecked_mut(inner),
                            2 => self.sources.2.get_reference_unchecked_mut(inner),
                            // TODO: Can we use unreachable_unchecked here?
                            _ => panic!(
                                "Invalid index should never be given to get_reference_unchecked"
                            )
                        }
                    }
                }
            }

            unsafe impl<T, S1, S2, S3, S4> TensorRef<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorRef<T, $d>,
                S2: TensorRef<T, $d>,
                S3: TensorRef<T, $d>,
                S4: TensorRef<T, $d>,
            {
                fn get_reference(&self, indexes: [usize; $d + 1]) -> Option<&T> {
                    let (which, inner) = indexing(indexes, self.along);
                    match which {
                        0 => self.sources.0.get_reference(inner),
                        1 => self.sources.1.get_reference(inner),
                        2 => self.sources.2.get_reference(inner),
                        3 => self.sources.3.get_reference(inner),
                        _ => None
                    }
                }

                fn view_shape(&self) -> [(Dimension, usize); $d + 1] {
                    let length = Self::number_of_sources();
                    view_shape_impl(self.source_view_shape(), self.along, length)
                }

                unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + 1]) -> &T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: the unchecked lookups are only sound for valid indexes, which
                    // is the caller's responsibility; an impossible source number panics
                    // instead of reaching a source.
                    unsafe {
                        match which {
                            0 => self.sources.0.get_reference_unchecked(inner),
                            1 => self.sources.1.get_reference_unchecked(inner),
                            2 => self.sources.2.get_reference_unchecked(inner),
                            3 => self.sources.3.get_reference_unchecked(inner),
                            // TODO: Can we use unreachable_unchecked here?
                            _ => panic!(
                                "Invalid index should never be given to get_reference_unchecked"
                            )
                        }
                    }
                }

                fn data_layout(&self) -> DataLayout<{ $d + 1 }> {
                    // The view is spread over several sources, so its shape does not
                    // correspond to one single line of data in memory.
                    DataLayout::NonLinear
                }
            }

            unsafe impl<T, S1, S2, S3, S4> TensorMut<T, { $d + 1 }> for TensorStack<T, (S1, S2, S3, S4), $d>
            where
                S1: TensorMut<T, $d>,
                S2: TensorMut<T, $d>,
                S3: TensorMut<T, $d>,
                S4: TensorMut<T, $d>,
            {
                fn get_reference_mut(&mut self, indexes: [usize; $d + 1]) -> Option<&mut T> {
                    let (which, inner) = indexing(indexes, self.along);
                    match which {
                        0 => self.sources.0.get_reference_mut(inner),
                        1 => self.sources.1.get_reference_mut(inner),
                        2 => self.sources.2.get_reference_mut(inner),
                        3 => self.sources.3.get_reference_mut(inner),
                        _ => None
                    }
                }

                unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + 1]) -> &mut T {
                    let (which, inner) = indexing(indexes, self.along);
                    // SAFETY: the unchecked lookups are only sound for valid indexes, which
                    // is the caller's responsibility; an impossible source number panics
                    // instead of reaching a source.
                    unsafe {
                        match which {
                            0 => self.sources.0.get_reference_unchecked_mut(inner),
                            1 => self.sources.1.get_reference_unchecked_mut(inner),
                            2 => self.sources.2.get_reference_unchecked_mut(inner),
                            3 => self.sources.3.get_reference_unchecked_mut(inner),
                            // TODO: Can we use unreachable_unchecked here?
                            _ => panic!(
                                "Invalid index should never be given to get_reference_unchecked"
                            )
                        }
                    }
                }
            }
        }
    }
}
681
// One invocation per supported number of source dimensions (0 to 5), each generating the
// TensorRef/TensorMut implementations for a stacked view with one more dimension (1 to 6).
// The trailing identifier names the module the generated helpers and impls are placed in.
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 0 zero);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 1 one);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 2 two);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 3 three);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 4 four);
tensor_stack_ref_impl!(unsafe impl TensorRef for TensorStack 5 five);
688
// Exercises TensorStack with both the tuple and the array constructors, stacking along every
// possible position of the new dimension and comparing against hand written expected tensors.
#[test]
fn test_stacking() {
    use crate::tensors::Tensor;
    use crate::tensors::views::{TensorMut, TensorView};
    let vector1 = Tensor::from([("a", 3)], vec![9, 5, 2]);
    let vector2 = Tensor::from([("a", 3)], vec![3, 6, 0]);
    let vector3 = Tensor::from([("a", 3)], vec![8, 7, 1]);
    // Stack three borrowed vectors with the new dimension "b" added after "a", so each
    // vector becomes a column of the expected matrix.
    let matrix = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&vector1, &vector2, &vector3),
        (1, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        matrix,
        Tensor::from([("a", 3), ("b", 3)], vec![
            9, 3, 8,
            5, 6, 7,
            2, 0, 1,
        ])
    );
    // Same sources with the new dimension "b" added before "a", so each vector becomes a
    // row of the expected matrix instead.
    let different_matrix = TensorView::from(TensorStack::<_, (_, _, _), 1>::from(
        (&vector1, &vector2, &vector3),
        (0, "b"),
    ));
    #[rustfmt::skip]
    assert_eq!(
        different_matrix,
        Tensor::from([("b", 3), ("a", 3)], vec![
            9, 5, 2,
            3, 6, 0,
            8, 7, 1,
        ])
    );
    // The array constructor needs both sources to be the same type, so the two matrices are
    // mapped with the identity function and boxed as trait objects. The second matrix has its
    // dimensions renamed to ("a", "b") so both sources report an identical shape.
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    // New dimension "c" added last
    let tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (2, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        tensor.eq(
            &Tensor::from([("a", 3), ("b", 3), ("c", 2)], vec![
                9, 9,
                3, 5,
                8, 2,

                5, 3,
                6, 6,
                7, 0,

                2, 8,
                0, 7,
                1, 1
            ])
        ),
    );
    // The previous stack took ownership of the boxes, so they are rebuilt for each case
    let matrix_erased: Box<dyn TensorMut<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorMut<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    // New dimension "c" added in the middle
    let different_tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (1, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        different_tensor.eq(
            &Tensor::from([("a", 3), ("c", 2), ("b", 3)], vec![
                9, 3, 8,
                9, 5, 2,

                5, 6, 7,
                3, 6, 0,

                2, 0, 1,
                8, 7, 1
            ])
        ),
    );
    // This time the sources are erased to TensorRef rather than TensorMut trait objects
    let matrix_erased: Box<dyn TensorRef<_, 2>> = Box::new(matrix.map(|x| x));
    let different_matrix_erased: Box<dyn TensorRef<_, 2>> =
        Box::new(different_matrix.rename_view(["a", "b"]).map(|x| x));
    // New dimension "c" added first
    let another_tensor = TensorView::from(TensorStack::<_, [_; 2], 2>::from(
        [matrix_erased, different_matrix_erased],
        (0, "c"),
    ));
    #[rustfmt::skip]
    assert!(
        another_tensor.eq(
            &Tensor::from([("c", 2), ("a", 3), ("b", 3)], vec![
                9, 3, 8,
                5, 6, 7,
                2, 0, 1,

                9, 5, 2,
                3, 6, 0,
                8, 7, 1,
            ])
        ),
    );
}
791
/**
 * Combines two or more tensors along an existing dimension in their shapes to create a Tensor
 * with a length in that dimension equal to the sum of the sources together along that dimension.
 * All other dimensions in the tensors' shapes must be the same.
 *
 * This can be framed as a D dimensional version of
 * [std::iter::Iterator::chain](https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.chain)
 *
 * Note: TensorChain only supports tuple combinations for `2` to `4`. If you need to chain more
 * than four tensors together, you can chain any number with the `[S; N]` implementation, though
 * note this requires that all the tensors are the same type so you may need to box and erase
 * the types to `Box<dyn TensorRef<T, D>>`.
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorChain, TensorRef};
 * let sample1 = Tensor::from([("sample", 1), ("data", 5)], vec![0, 1, 2, 3, 4]);
 * let sample2 = Tensor::from([("sample", 1), ("data", 5)], vec![2, 4, 8, 16, 32]);
 * // Because there are 4 variants of `TensorChain::from` you may need to use the turbofish
 * // to tell the Rust compiler which variant you're using, but the actual type of `S` can be
 * // left unspecified by using an underscore.
 * let matrix = TensorChain::<i32, [_; 2], 2>::from([&sample1, &sample2], "sample");
 * let equal_matrix = Tensor::from([("sample", 2), ("data", 5)], vec![
 *     0, 1, 2, 3, 4,
 *     2, 4, 8, 16, 32
 *  ]);
 * assert_eq!(equal_matrix, TensorView::from(matrix));
 *
 * let also_matrix = TensorChain::<i32, (_, _), 2>::from((sample1, sample2), "sample");
 * assert_eq!(equal_matrix, TensorView::from(&also_matrix));
 *
 * // To chain `equal_matrix` and `also_matrix` using the `[S; N]` implementation we have to first
 * // make them the same type, which we can do by boxing and erasing.
 * let matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(also_matrix);
 * let equal_matrix_erased: Box<dyn TensorRef<i32, 2>> = Box::new(equal_matrix);
 * let repeated_data = TensorChain::<i32, [_; 2], 2>::from(
 *     [matrix_erased, equal_matrix_erased], "data"
 * );
 * assert!(
 *     TensorView::from(repeated_data).eq(
 *         &Tensor::from([("sample", 2), ("data", 10)], vec![
 *             0, 1, 2,  3,  4, 0, 1, 2,  3,  4,
 *             2, 4, 8, 16, 32, 2, 4, 8, 16, 32
 *         ])
 *     ),
 * );
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorChain<T, S, const D: usize> {
    // The sources being chained, either an array `[S; N]` or a tuple of 2 to 4 sources.
    sources: S,
    // Records the element type `T` of the sources without storing a value of it.
    _type: PhantomData<T>,
    // Position in the sources' shapes of the dimension the sources are chained along.
    along: usize,
}
846
847fn validate_shapes_similar<const D: usize, I>(mut shapes: I, along: usize)
848where
849    I: Iterator<Item = [(Dimension, usize); D]>,
850{
851    // We'll reject fewer than one tensors in the constructors before getting here, so first unwrap
852    // is always going to succeed.
853    let first_shape = shapes.next().unwrap();
854    for (i, shape) in shapes.enumerate() {
855        for d in 0..D {
856            let similar = if d == along {
857                // don't need match for dimension lengths in the `along` dimension
858                shape[d].0 == first_shape[d].0
859            } else {
860                shape[d] == first_shape[d]
861            };
862            if !similar {
863                panic!(
864                    "The shapes of each tensor in the sources to chain along must be the same. Shape {:?} {:?} did not match the first shape {:?}",
865                    i + 1,
866                    shape,
867                    first_shape
868                );
869            }
870        }
871    }
872}
873
874impl<T, S, const D: usize, const N: usize> TensorChain<T, [S; N], D>
875where
876    S: TensorRef<T, D>,
877{
878    /**
879     * Creates a TensorChain from an array of sources of the same type and the dimension name to
880     * chain the sources along. The sources must all have an identical shape, including the
881     * provided dimension, except for the dimension lengths of the provided dimension name which
882     * may be different.
883     *
884     * # Panics
885     *
886     * If N == 0, D == 0, the shapes of the sources are not identical*, or the dimension for
887     * chaining is not in sources' shape.
888     *
889     * *except for the lengths along the provided dimension.
890     */
891    #[track_caller]
892    pub fn from(sources: [S; N], along: Dimension) -> Self {
893        if N == 0 {
894            panic!("No sources provided");
895        }
896        if D == 0 {
897            panic!("Can't chain along 0 dimensional tensors");
898        }
899        let shape = sources[0].view_shape();
900        let along = match dimensions::position_of(&shape, along) {
901            Some(d) => d,
902            None => panic!(
903                "The dimension {:?} is not in the source's shapes: {:?}",
904                along, shape
905            ),
906        };
907        validate_shapes_similar(sources.iter().map(|tensor| tensor.view_shape()), along);
908        Self {
909            sources,
910            along,
911            _type: PhantomData,
912        }
913    }
914
915    /**
916     * Consumes the TensorChain, yielding the sources it was created from in the same order.
917     */
918    pub fn sources(self) -> [S; N] {
919        self.sources
920    }
921
922    // # Safety
923    //
924    // Giving out a mutable reference to our sources could allow then to be changed out from under
925    // us and make our shape invalid. However, since the sources implement TensorRef interior
926    // mutability is not allowed, so we can give out shared references without breaking our own
927    // integrity.
928    /**
929     * Gives a reference to all the TensorChain's sources it was created from in the same order
930     */
931    pub fn sources_ref(&self) -> &[S; N] {
932        &self.sources
933    }
934}
935
936impl<T, S1, S2, const D: usize> TensorChain<T, (S1, S2), D>
937where
938    S1: TensorRef<T, D>,
939    S2: TensorRef<T, D>,
940{
941    /**
942     * Creates a TensorChain from two sources and the dimension name to chain the sources along.
943     * The sources must all have an identical shape, including the provided dimension, except for
944     * the dimension lengths of the provided dimension name which may be different.
945     *
946     * # Panics
947     *
948     * If D == 0, the shapes of the sources are not identical*, or the dimension for
949     * chaining is not in sources' shape.
950     *
951     * *except for the lengths along the provided dimension.
952     */
953    #[track_caller]
954    pub fn from(sources: (S1, S2), along: Dimension) -> Self {
955        if D == 0 {
956            panic!("Can't chain along 0 dimensional tensors");
957        }
958        let shape = sources.0.view_shape();
959        let along = match dimensions::position_of(&shape, along) {
960            Some(d) => d,
961            None => panic!(
962                "The dimension {:?} is not in the source's shapes: {:?}",
963                along, shape
964            ),
965        };
966        validate_shapes_similar(
967            [sources.0.view_shape(), sources.1.view_shape()].into_iter(),
968            along,
969        );
970        Self {
971            sources,
972            along,
973            _type: PhantomData,
974        }
975    }
976
977    /**
978     * Consumes the TensorChain, yielding the sources it was created from in the same order.
979     */
980    pub fn sources(self) -> (S1, S2) {
981        self.sources
982    }
983
984    // # Safety
985    //
986    // Giving out a mutable reference to our sources could allow then to be changed out from under
987    // us and make our shape invalid. However, since the sources implement TensorRef interior
988    // mutability is not allowed, so we can give out shared references without breaking our own
989    // integrity.
990    /**
991     * Gives a reference to all the TensorChain's sources it was created from in the same order
992     */
993    pub fn sources_ref(&self) -> &(S1, S2) {
994        &self.sources
995    }
996}
997
998impl<T, S1, S2, S3, const D: usize> TensorChain<T, (S1, S2, S3), D>
999where
1000    S1: TensorRef<T, D>,
1001    S2: TensorRef<T, D>,
1002    S3: TensorRef<T, D>,
1003{
1004    /**
1005     * Creates a TensorChain from three sources and the dimension name to chain the sources along.
1006     * The sources must all have an identical shape, including the provided dimension, except for
1007     * the dimension lengths of the provided dimension name which may be different.
1008     *
1009     * # Panics
1010     *
1011     * If D == 0, the shapes of the sources are not identical*, or the dimension for
1012     * chaining is not in sources' shape.
1013     *
1014     * *except for the lengths along the provided dimension.
1015     */
1016    #[track_caller]
1017    pub fn from(sources: (S1, S2, S3), along: Dimension) -> Self {
1018        if D == 0 {
1019            panic!("Can't chain along 0 dimensional tensors");
1020        }
1021        let shape = sources.0.view_shape();
1022        let along = match dimensions::position_of(&shape, along) {
1023            Some(d) => d,
1024            None => panic!(
1025                "The dimension {:?} is not in the source's shapes: {:?}",
1026                along, shape
1027            ),
1028        };
1029        validate_shapes_similar(
1030            [
1031                sources.0.view_shape(),
1032                sources.1.view_shape(),
1033                sources.2.view_shape(),
1034            ]
1035            .into_iter(),
1036            along,
1037        );
1038        Self {
1039            sources,
1040            along,
1041            _type: PhantomData,
1042        }
1043    }
1044
1045    /**
1046     * Consumes the TensorChain, yielding the sources it was created from in the same order.
1047     */
1048    pub fn sources(self) -> (S1, S2, S3) {
1049        self.sources
1050    }
1051
1052    // # Safety
1053    //
1054    // Giving out a mutable reference to our sources could allow then to be changed out from under
1055    // us and make our shape invalid. However, since the sources implement TensorRef interior
1056    // mutability is not allowed, so we can give out shared references without breaking our own
1057    // integrity.
1058    /**
1059     * Gives a reference to all the TensorChain's sources it was created from in the same order
1060     */
1061    pub fn sources_ref(&self) -> &(S1, S2, S3) {
1062        &self.sources
1063    }
1064}
1065
1066impl<T, S1, S2, S3, S4, const D: usize> TensorChain<T, (S1, S2, S3, S4), D>
1067where
1068    S1: TensorRef<T, D>,
1069    S2: TensorRef<T, D>,
1070    S3: TensorRef<T, D>,
1071    S4: TensorRef<T, D>,
1072{
1073    /**
1074     * Creates a TensorChain from four sources and the dimension name to chain the sources along.
1075     * The sources must all have an identical shape, including the provided dimension, except for
1076     * the dimension lengths of the provided dimension name which may be different.
1077     *
1078     * # Panics
1079     *
1080     * If D == 0, the shapes of the sources are not identical*, or the dimension for
1081     * chaining is not in sources' shape.
1082     *
1083     * *except for the lengths along the provided dimension.
1084     */
1085    #[track_caller]
1086    pub fn from(sources: (S1, S2, S3, S4), along: Dimension) -> Self {
1087        if D == 0 {
1088            panic!("Can't chain along 0 dimensional tensors");
1089        }
1090        let shape = sources.0.view_shape();
1091        let along = match dimensions::position_of(&shape, along) {
1092            Some(d) => d,
1093            None => panic!(
1094                "The dimension {:?} is not in the source's shapes: {:?}",
1095                along, shape
1096            ),
1097        };
1098        validate_shapes_similar(
1099            [
1100                sources.0.view_shape(),
1101                sources.1.view_shape(),
1102                sources.2.view_shape(),
1103                sources.3.view_shape(),
1104            ]
1105            .into_iter(),
1106            along,
1107        );
1108        Self {
1109            sources,
1110            along,
1111            _type: PhantomData,
1112        }
1113    }
1114
1115    /**
1116     * Consumes the TensorChain, yielding the sources it was created from in the same order.
1117     */
1118    pub fn sources(self) -> (S1, S2, S3, S4) {
1119        self.sources
1120    }
1121
1122    // # Safety
1123    //
1124    // Giving out a mutable reference to our sources could allow then to be changed out from under
1125    // us and make our shape invalid. However, since the sources implement TensorRef interior
1126    // mutability is not allowed, so we can give out shared references without breaking our own
1127    // integrity.
1128    /**
1129     * Gives a reference to all the TensorChain's sources it was created from in the same order
1130     */
1131    pub fn sources_ref(&self) -> &(S1, S2, S3, S4) {
1132        &self.sources
1133    }
1134}
1135
1136fn view_shape_impl<I, const D: usize>(
1137    first_shape: [(Dimension, usize); D],
1138    shapes: I,
1139    along: usize,
1140) -> [(Dimension, usize); D]
1141where
1142    I: Iterator<Item = [(Dimension, usize); D]>,
1143{
1144    let mut shape = first_shape;
1145    shape[along].1 = shapes.into_iter().map(|shape| shape[along].1).sum();
1146    shape
1147}
1148
1149fn indexing<I, const D: usize>(
1150    indexes: [usize; D],
1151    shapes: I,
1152    along: usize,
1153) -> Option<(usize, [usize; D])>
1154where
1155    I: Iterator<Item = [(Dimension, usize); D]>,
1156{
1157    let mut shapes = shapes.enumerate();
1158    // Keep trying to index the next shape in the chain, if i is still greater
1159    // than the available length we know it's for a later shape, and can subtract
1160    // that available length till we find one.
1161    let mut i = indexes[along];
1162    loop {
1163        let (source, next_shape) = shapes.next()?;
1164        let length_along_chained_dimension = next_shape[along].1;
1165        if i < length_along_chained_dimension {
1166            #[allow(clippy::clone_on_copy)]
1167            let mut indexes = indexes.clone();
1168            indexes[along] = i;
1169            return Some((source, indexes));
1170        }
1171        i -= length_along_chained_dimension;
1172    }
1173}
1174
1175unsafe impl<T, S, const D: usize, const N: usize> TensorRef<T, D> for TensorChain<T, [S; N], D>
1176where
1177    S: TensorRef<T, D>,
1178{
1179    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1180        let (source, indexes) = indexing(
1181            indexes,
1182            self.sources.iter().map(|s| s.view_shape()),
1183            self.along,
1184        )?;
1185        self.sources.get(source)?.get_reference(indexes)
1186    }
1187
1188    fn view_shape(&self) -> [(Dimension, usize); D] {
1189        view_shape_impl(
1190            self.sources[0].view_shape(),
1191            self.sources.iter().map(|s| s.view_shape()),
1192            self.along,
1193        )
1194    }
1195
1196    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1197        unsafe {
1198            let (source, indexes) = indexing(
1199                indexes,
1200                self.sources.iter().map(|s| s.view_shape()),
1201                self.along,
1202            )
1203            // The caller is already responsible for providing valid indexes to us
1204            // so `indexing` will always return Some
1205            .unwrap_unchecked();
1206            self.sources
1207                .get(source)
1208                .unwrap()
1209                .get_reference_unchecked(indexes)
1210        }
1211    }
1212
1213    fn data_layout(&self) -> DataLayout<D> {
1214        // Our chained shapes means the view shape no longer matches up to a single
1215        // line of data in memory in the general case.
1216        DataLayout::NonLinear
1217    }
1218}
1219
1220unsafe impl<T, S, const D: usize, const N: usize> TensorMut<T, D> for TensorChain<T, [S; N], D>
1221where
1222    S: TensorMut<T, D>,
1223{
1224    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1225        let (source, indexes) = indexing(
1226            indexes,
1227            self.sources.iter().map(|s| s.view_shape()),
1228            self.along,
1229        )?;
1230        self.sources.get_mut(source)?.get_reference_mut(indexes)
1231    }
1232
1233    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1234        unsafe {
1235            let (source, indexes) = indexing(
1236                indexes,
1237                self.sources.iter().map(|s| s.view_shape()),
1238                self.along,
1239            )
1240            // The caller is already responsible for providing valid indexes to us
1241            // so `indexing` will always return Some
1242            .unwrap_unchecked();
1243            self.sources
1244                .get_mut(source)
1245                .unwrap()
1246                .get_reference_unchecked_mut(indexes)
1247        }
1248    }
1249}
1250
1251unsafe impl<T, S1, S2, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2), D>
1252where
1253    S1: TensorRef<T, D>,
1254    S2: TensorRef<T, D>,
1255{
1256    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1257        let (source, indexes) = indexing(
1258            indexes,
1259            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1260            self.along,
1261        )?;
1262        match source {
1263            0 => self.sources.0.get_reference(indexes),
1264            1 => self.sources.1.get_reference(indexes),
1265            _ => None,
1266        }
1267    }
1268
1269    fn view_shape(&self) -> [(Dimension, usize); D] {
1270        view_shape_impl(
1271            self.sources.0.view_shape(),
1272            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1273            self.along,
1274        )
1275    }
1276
1277    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1278        unsafe {
1279            let (source, indexes) = indexing(
1280                indexes,
1281                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1282                self.along,
1283            )
1284            // The caller is already responsible for providing valid indexes to us
1285            // so `indexing` will always return Some
1286            .unwrap_unchecked();
1287            match source {
1288                0 => self.sources.0.get_reference_unchecked(indexes),
1289                1 => self.sources.1.get_reference_unchecked(indexes),
1290                // TODO: Can we use unreachable_unchecked here?
1291                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1292            }
1293        }
1294    }
1295
1296    fn data_layout(&self) -> DataLayout<D> {
1297        // Our chained shapes means the view shape no longer matches up to a single
1298        // line of data in memory in the general case.
1299        DataLayout::NonLinear
1300    }
1301}
1302
1303unsafe impl<T, S1, S2, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2), D>
1304where
1305    S1: TensorMut<T, D>,
1306    S2: TensorMut<T, D>,
1307{
1308    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1309        let (source, indexes) = indexing(
1310            indexes,
1311            [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1312            self.along,
1313        )?;
1314        match source {
1315            0 => self.sources.0.get_reference_mut(indexes),
1316            1 => self.sources.1.get_reference_mut(indexes),
1317            _ => None,
1318        }
1319    }
1320
1321    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1322        unsafe {
1323            let (source, indexes) = indexing(
1324                indexes,
1325                [self.sources.0.view_shape(), self.sources.1.view_shape()].into_iter(),
1326                self.along,
1327            )
1328            // The caller is already responsible for providing valid indexes to us
1329            // so `indexing` will always return Some
1330            .unwrap_unchecked();
1331            match source {
1332                0 => self.sources.0.get_reference_unchecked_mut(indexes),
1333                1 => self.sources.1.get_reference_unchecked_mut(indexes),
1334                // TODO: Can we use unreachable_unchecked here?
1335                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1336            }
1337        }
1338    }
1339}
1340
1341unsafe impl<T, S1, S2, S3, const D: usize> TensorRef<T, D> for TensorChain<T, (S1, S2, S3), D>
1342where
1343    S1: TensorRef<T, D>,
1344    S2: TensorRef<T, D>,
1345    S3: TensorRef<T, D>,
1346{
1347    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1348        let (source, indexes) = indexing(
1349            indexes,
1350            [
1351                self.sources.0.view_shape(),
1352                self.sources.1.view_shape(),
1353                self.sources.2.view_shape(),
1354            ]
1355            .into_iter(),
1356            self.along,
1357        )?;
1358        match source {
1359            0 => self.sources.0.get_reference(indexes),
1360            1 => self.sources.1.get_reference(indexes),
1361            2 => self.sources.2.get_reference(indexes),
1362            _ => None,
1363        }
1364    }
1365
1366    fn view_shape(&self) -> [(Dimension, usize); D] {
1367        view_shape_impl(
1368            self.sources.0.view_shape(),
1369            [
1370                self.sources.0.view_shape(),
1371                self.sources.1.view_shape(),
1372                self.sources.2.view_shape(),
1373            ]
1374            .into_iter(),
1375            self.along,
1376        )
1377    }
1378
1379    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1380        unsafe {
1381            let (source, indexes) = indexing(
1382                indexes,
1383                [
1384                    self.sources.0.view_shape(),
1385                    self.sources.1.view_shape(),
1386                    self.sources.2.view_shape(),
1387                ]
1388                .into_iter(),
1389                self.along,
1390            )
1391            // The caller is already responsible for providing valid indexes to us
1392            // so `indexing` will always return Some
1393            .unwrap_unchecked();
1394            match source {
1395                0 => self.sources.0.get_reference_unchecked(indexes),
1396                1 => self.sources.1.get_reference_unchecked(indexes),
1397                2 => self.sources.2.get_reference_unchecked(indexes),
1398                // TODO: Can we use unreachable_unchecked here?
1399                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1400            }
1401        }
1402    }
1403
1404    fn data_layout(&self) -> DataLayout<D> {
1405        // Our chained shapes means the view shape no longer matches up to a single
1406        // line of data in memory in the general case.
1407        DataLayout::NonLinear
1408    }
1409}
1410
1411unsafe impl<T, S1, S2, S3, const D: usize> TensorMut<T, D> for TensorChain<T, (S1, S2, S3), D>
1412where
1413    S1: TensorMut<T, D>,
1414    S2: TensorMut<T, D>,
1415    S3: TensorMut<T, D>,
1416{
1417    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1418        let (source, indexes) = indexing(
1419            indexes,
1420            [
1421                self.sources.0.view_shape(),
1422                self.sources.1.view_shape(),
1423                self.sources.2.view_shape(),
1424            ]
1425            .into_iter(),
1426            self.along,
1427        )?;
1428        match source {
1429            0 => self.sources.0.get_reference_mut(indexes),
1430            1 => self.sources.1.get_reference_mut(indexes),
1431            2 => self.sources.2.get_reference_mut(indexes),
1432            _ => None,
1433        }
1434    }
1435
1436    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1437        unsafe {
1438            let (source, indexes) = indexing(
1439                indexes,
1440                [
1441                    self.sources.0.view_shape(),
1442                    self.sources.1.view_shape(),
1443                    self.sources.2.view_shape(),
1444                ]
1445                .into_iter(),
1446                self.along,
1447            )
1448            // The caller is already responsible for providing valid indexes to us
1449            // so `indexing` will always return Some
1450            .unwrap_unchecked();
1451            match source {
1452                0 => self.sources.0.get_reference_unchecked_mut(indexes),
1453                1 => self.sources.1.get_reference_unchecked_mut(indexes),
1454                2 => self.sources.2.get_reference_unchecked_mut(indexes),
1455                // TODO: Can we use unreachable_unchecked here?
1456                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1457            }
1458        }
1459    }
1460}
1461
1462unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorRef<T, D>
1463    for TensorChain<T, (S1, S2, S3, S4), D>
1464where
1465    S1: TensorRef<T, D>,
1466    S2: TensorRef<T, D>,
1467    S3: TensorRef<T, D>,
1468    S4: TensorRef<T, D>,
1469{
1470    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1471        let (source, indexes) = indexing(
1472            indexes,
1473            [
1474                self.sources.0.view_shape(),
1475                self.sources.1.view_shape(),
1476                self.sources.2.view_shape(),
1477                self.sources.3.view_shape(),
1478            ]
1479            .into_iter(),
1480            self.along,
1481        )?;
1482        match source {
1483            0 => self.sources.0.get_reference(indexes),
1484            1 => self.sources.1.get_reference(indexes),
1485            2 => self.sources.2.get_reference(indexes),
1486            3 => self.sources.3.get_reference(indexes),
1487            _ => None,
1488        }
1489    }
1490
1491    fn view_shape(&self) -> [(Dimension, usize); D] {
1492        view_shape_impl(
1493            self.sources.0.view_shape(),
1494            [
1495                self.sources.0.view_shape(),
1496                self.sources.1.view_shape(),
1497                self.sources.2.view_shape(),
1498                self.sources.3.view_shape(),
1499            ]
1500            .into_iter(),
1501            self.along,
1502        )
1503    }
1504
1505    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1506        unsafe {
1507            let (source, indexes) = indexing(
1508                indexes,
1509                [
1510                    self.sources.0.view_shape(),
1511                    self.sources.1.view_shape(),
1512                    self.sources.2.view_shape(),
1513                    self.sources.3.view_shape(),
1514                ]
1515                .into_iter(),
1516                self.along,
1517            )
1518            // The caller is already responsible for providing valid indexes to us
1519            // so `indexing` will always return Some
1520            .unwrap_unchecked();
1521            match source {
1522                0 => self.sources.0.get_reference_unchecked(indexes),
1523                1 => self.sources.1.get_reference_unchecked(indexes),
1524                2 => self.sources.2.get_reference_unchecked(indexes),
1525                3 => self.sources.3.get_reference_unchecked(indexes),
1526                // TODO: Can we use unreachable_unchecked here?
1527                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1528            }
1529        }
1530    }
1531
1532    fn data_layout(&self) -> DataLayout<D> {
1533        // Our chained shapes means the view shape no longer matches up to a single
1534        // line of data in memory in the general case.
1535        DataLayout::NonLinear
1536    }
1537}
1538
1539unsafe impl<T, S1, S2, S3, S4, const D: usize> TensorMut<T, D>
1540    for TensorChain<T, (S1, S2, S3, S4), D>
1541where
1542    S1: TensorMut<T, D>,
1543    S2: TensorMut<T, D>,
1544    S3: TensorMut<T, D>,
1545    S4: TensorMut<T, D>,
1546{
1547    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1548        let (source, indexes) = indexing(
1549            indexes,
1550            [
1551                self.sources.0.view_shape(),
1552                self.sources.1.view_shape(),
1553                self.sources.2.view_shape(),
1554                self.sources.3.view_shape(),
1555            ]
1556            .into_iter(),
1557            self.along,
1558        )?;
1559        match source {
1560            0 => self.sources.0.get_reference_mut(indexes),
1561            1 => self.sources.1.get_reference_mut(indexes),
1562            2 => self.sources.2.get_reference_mut(indexes),
1563            3 => self.sources.3.get_reference_mut(indexes),
1564            _ => None,
1565        }
1566    }
1567
1568    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1569        unsafe {
1570            let (source, indexes) = indexing(
1571                indexes,
1572                [
1573                    self.sources.0.view_shape(),
1574                    self.sources.1.view_shape(),
1575                    self.sources.2.view_shape(),
1576                    self.sources.3.view_shape(),
1577                ]
1578                .into_iter(),
1579                self.along,
1580            )
1581            // The caller is already responsible for providing valid indexes to us
1582            // so `indexing` will always return Some
1583            .unwrap_unchecked();
1584            match source {
1585                0 => self.sources.0.get_reference_unchecked_mut(indexes),
1586                1 => self.sources.1.get_reference_unchecked_mut(indexes),
1587                2 => self.sources.2.get_reference_unchecked_mut(indexes),
1588                3 => self.sources.3.get_reference_unchecked_mut(indexes),
1589                // TODO: Can we use unreachable_unchecked here?
1590                _ => panic!("Invalid index should never be given to get_reference_unchecked"),
1591            }
1592        }
1593    }
1594}
1595
#[test]
fn test_chaining() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    // Two matrices with the same length along "b" but different lengths along "a".
    #[rustfmt::skip]
    let upper = Tensor::from(
        [("a", 3), ("b", 2)],
        vec![
            9, 5,
            2, 1,
            3, 5
        ]
    );
    #[rustfmt::skip]
    let lower = Tensor::from(
        [("a", 4), ("b", 2)],
        vec![
            0, 1,
            8, 4,
            1, 7,
            6, 3
        ]
    );

    // Chaining a tuple of sources along "a" places the rows of `lower` after those of `upper`.
    let rows_chained = TensorView::from(TensorChain::<_, (_, _), 2>::from((&upper, &lower), "a"));
    #[rustfmt::skip]
    let expected_rows = Tensor::from([("a", 7), ("b", 2)], vec![
        9, 5,
        2, 1,
        3, 5,
        0, 1,
        8, 4,
        1, 7,
        6, 3
    ]);
    assert_eq!(rows_chained, expected_rows);

    // Chaining an array of sources requires erasing them to a common type first. This time
    // the sources differ in length along "b", adding a third column to the 7 rows.
    let rows_erased: Box<dyn TensorMut<_, 2>> = Box::new(rows_chained.map(|x| x));
    let extra_column = Tensor::from([("a", 7), ("b", 1)], (0..7).collect());
    let extra_column_erased: Box<dyn TensorMut<_, 2>> = Box::new(extra_column);
    let columns_chained = TensorView::from(TensorChain::<_, [_; 2], 2>::from(
        [rows_erased, extra_column_erased],
        "b",
    ));
    #[rustfmt::skip]
    let expected_columns = Tensor::from([("a", 7), ("b", 3)], vec![
        9, 5, 0,
        2, 1, 1,
        3, 5, 2,
        0, 1, 3,
        8, 4, 4,
        1, 7, 5,
        6, 3, 6
    ]);
    assert!(columns_chained.eq(&expected_columns));
}