Skip to main content

easy_ml/tensors/
indexing.rs

1/*!
2 * # Indexing
3 *
4 * Many libraries represent tensors as N dimensional arrays, however there is often some semantic
5 * meaning to each dimension. You may have a batch of 2000 images, each 100 pixels wide and high,
6 * with each pixel representing 3 numbers for rgb values. This can be represented as a
7 * 2000 x 100 x 100 x 3 tensor, but a 4 dimensional array does not track the semantic meaning
8 * of each dimension and associated index.
9 *
10 * 6 months later you could come back to the code and forget which order the dimensions were
11 * created in, at best getting the indexes out of bounds and causing a crash in your application,
12 * and at worst silently reading the wrong data without realising. *Was it width then height or
13 * height then width?*...
14 *
15 * Easy ML moves the N dimensional array to an implementation detail, and most of its APIs work
16 * on the names of each dimension in a tensor instead of just the order. Instead of a
17 * 2000 x 100 x 100 x 3 tensor in which the last element is at [1999, 99, 99, 2], Easy ML tracks
18 * the names of the dimensions, so you have a
19 * `[("batch", 2000), ("width", 100), ("height", 100), ("rgb", 3)]` shaped tensor.
20 *
21 * This can't stop you from getting the math wrong, but confusion over which dimension
22 * means what is reduced. Tensors carry around their pairs of dimension name and length
23 * so adding a `[("batch", 2000), ("width", 100), ("height", 100), ("rgb", 3)]` shaped tensor
24 * to a `[("batch", 2000), ("height", 100), ("width", 100), ("rgb", 3)]` will fail unless you
25 * reorder one first, and you could access an element as
 * `tensor.index_by(["batch", "width", "height", "rgb"]).get([1999, 0, 99, 2])` or
 * `tensor.index_by(["batch", "height", "width", "rgb"]).get([1999, 99, 0, 2])` and read the same data,
28 * because you index into dimensions based on their name, not just the order they are stored in
29 * memory.
30 *
31 * Even with a name for each dimension, at some point you still need to say what order you want
32 * to index each dimension with, and this is where [`TensorAccess`] comes in. It
33 * creates a mapping from the dimension name order you want to access elements with to the order
34 * the dimensions are stored as.
35 */
36
37use crate::differentiation::{Index, Primitive, Record, RecordTensor};
38use crate::numeric::Numeric;
39use crate::tensors::dimensions;
40use crate::tensors::dimensions::DimensionMappings;
41use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
42use crate::tensors::{Dimension, Tensor};
43
44use std::error::Error;
45use std::fmt;
46use std::iter::{ExactSizeIterator, FusedIterator};
47use std::marker::PhantomData;
48
49pub use crate::matrices::iterators::WithIndex;
50
/**
 * Access to the data in a Tensor with a particular order of dimension indexing. The order
 * affects the shape of the TensorAccess as well as the order of indexes you supply to read
 * or write values to the tensor.
 *
 * See the [module level documentation](crate::tensors::indexing) for more information.
 */
#[derive(Clone, Debug)]
pub struct TensorAccess<T, S, const D: usize> {
    // The wrapped tensor or view that every read and write is delegated to.
    source: S,
    // Translates indexes and shapes between the dimension order this TensorAccess was
    // created with and the dimension order of the source's view shape.
    dimension_mapping: DimensionMappings<D>,
    // Ties the element type T to the struct without storing a value of that type.
    _type: PhantomData<T>,
}
64
65impl<T, S, const D: usize> TensorAccess<T, S, D>
66where
67    S: TensorRef<T, D>,
68{
69    /**
70     * Creates a TensorAccess which can be indexed in the order of the supplied dimensions
71     * to read or write values from this tensor.
72     *
73     * # Panics
74     *
75     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
76     */
77    #[track_caller]
78    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
79        match TensorAccess::try_from(source, dimensions) {
80            Err(error) => panic!("{}", error),
81            Ok(success) => success,
82        }
83    }
84
85    /**
86     * Creates a TensorAccess which can be indexed in the order of the supplied dimensions
87     * to read or write values from this tensor.
88     *
89     * Returns Err if the set of dimensions supplied do not match the set of dimensions in this
90     * tensor's shape.
91     */
92    pub fn try_from(
93        source: S,
94        dimensions: [Dimension; D],
95    ) -> Result<TensorAccess<T, S, D>, InvalidDimensionsError<D>> {
96        Ok(TensorAccess {
97            dimension_mapping: DimensionMappings::new(&source.view_shape(), &dimensions)
98                .ok_or_else(|| InvalidDimensionsError {
99                    actual: source.view_shape(),
100                    requested: dimensions,
101                })?,
102            source,
103            _type: PhantomData,
104        })
105    }
106
107    /**
108     * Creates a TensorAccess which is indexed in the same order as the dimensions in the view
109     * shape of the tensor it is created from.
110     *
111     * Hence if you create a TensorAccess directly from a Tensor by `from_source_order`
112     * this uses the order the dimensions were laid out in memory with.
113     *
114     * ```
115     * use easy_ml::tensors::Tensor;
116     * use easy_ml::tensors::indexing::TensorAccess;
117     * let tensor = Tensor::from([("x", 2), ("y", 2), ("z", 2)], vec![
118     *     1, 2,
119     *     3, 4,
120     *
121     *     5, 6,
122     *     7, 8
123     * ]);
124     * let xyz = tensor.index_by(["x", "y", "z"]);
125     * let also_xyz = TensorAccess::from_source_order(&tensor);
126     * let also_xyz = tensor.index();
127     * ```
128     */
129    pub fn from_source_order(source: S) -> TensorAccess<T, S, D> {
130        TensorAccess {
131            dimension_mapping: DimensionMappings::no_op_mapping(),
132            source,
133            _type: PhantomData,
134        }
135    }
136
137    /**
138     * Creates a TensorAccess which is indexed in the same order as the linear data layout
139     * dimensions in the tensor it is created from, or None if the source data layout
140     * is not linear.
141     *
142     * Hence if you use `from_memory_order` on a source that was originally big endian like
143     * [Tensor] this uses the order for efficient iteration through each step in memory
144     * when [iterating](TensorIterator).
145     */
146    pub fn from_memory_order(source: S) -> Option<TensorAccess<T, S, D>> {
147        let data_layout = match source.data_layout() {
148            DataLayout::Linear(order) => order,
149            _ => return None,
150        };
151        let shape = source.view_shape();
152        Some(TensorAccess::try_from(source, data_layout).unwrap_or_else(|_| panic!(
153            "Source implementation contained dimensions {:?} in data_layout that were not the same set as in the view_shape {:?} which breaks the contract of TensorRef",
154             data_layout, shape
155        )))
156    }
157
158    /**
159     * The shape this TensorAccess has with the dimensions mapped to the order the TensorAccess
160     * was created with, not necessarily the same order as in the underlying tensor.
161     */
162    pub fn shape(&self) -> [(Dimension, usize); D] {
163        self.dimension_mapping
164            .map_shape_to_requested(&self.source.view_shape())
165    }
166
167    pub fn source(self) -> S {
168        self.source
169    }
170
171    // # Safety
172    //
173    // Giving out a mutable reference to our source could allow it to be changed out from under us
174    // and make our dimmension mapping invalid. However, since the source implements TensorRef
175    // interior mutability is not allowed, so we can give out shared references without breaking
176    // our own integrity.
177    pub fn source_ref(&self) -> &S {
178        &self.source
179    }
180}
181
182/**
183 * An error indicating failure to create a TensorAccess because the requested dimension order
184 * does not match the shape in the source data.
185 */
186#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
187pub struct InvalidDimensionsError<const D: usize> {
188    pub actual: [(Dimension, usize); D],
189    pub requested: [Dimension; D],
190}
191
192impl<const D: usize> Error for InvalidDimensionsError<D> {}
193
194impl<const D: usize> fmt::Display for InvalidDimensionsError<D> {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        write!(
197            f,
198            "Requested dimension order: {:?} does not match the shape in the source: {:?}",
199            &self.actual, &self.requested
200        )
201    }
202}
203
// Compile time check that the error type can be shared between threads.
#[test]
fn test_sync() {
    fn requires_sync<E>()
    where
        E: Sync,
    {
    }
    requires_sync::<InvalidDimensionsError<3>>();
}
209
// Compile time check that the error type can be moved between threads.
#[test]
fn test_send() {
    fn requires_send<E>()
    where
        E: Send,
    {
    }
    requires_send::<InvalidDimensionsError<3>>();
}
215
216impl<T, S, const D: usize> TensorAccess<T, S, D>
217where
218    S: TensorRef<T, D>,
219{
220    /**
221     * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
222     * index if the index is in range. Otherwise returns None.
223     */
224    pub fn try_get_reference(&self, indexes: [usize; D]) -> Option<&T> {
225        self.source
226            .get_reference(self.dimension_mapping.map_dimensions_to_source(&indexes))
227    }
228
229    /**
230     * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
231     * index if the index is in range, panicking if the index is out of range.
232     */
233    // NOTE: Ideally `get_reference` would be used here for consistency, but that opens the
234    // minefield of TensorRef::get_reference and TensorAccess::get_ref being different signatures
235    // but the same name.
236    #[track_caller]
237    pub fn get_ref(&self, indexes: [usize; D]) -> &T {
238        match self.try_get_reference(indexes) {
239            Some(reference) => reference,
240            None => panic!(
241                "Unable to index with {:?}, Tensor dimensions are {:?}.",
242                indexes,
243                self.shape()
244            ),
245        }
246    }
247
248    /**
249     * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
250     * index wihout any bounds checking.
251     *
252     * # Safety
253     *
254     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
255     * resulting reference is not used. Valid indexes are defined as in [TensorRef]. Note that
256     * the order of the indexes needed here must match with
257     * [`TensorAccess::shape`](TensorAccess::shape) which may not neccessarily be the same
258     * as the `view_shape` of the `TensorRef` implementation this TensorAccess was created from).
259     *
260     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
261     * [TensorRef]: TensorRef
262     */
263    // NOTE: This aliases with TensorRef::get_reference_unchecked but the TensorRef impl
264    // just calls this and the signatures match anyway, so there are no potential issues.
265    #[allow(clippy::missing_safety_doc)] // it's not missing
266    pub unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
267        unsafe {
268            self.source
269                .get_reference_unchecked(self.dimension_mapping.map_dimensions_to_source(&indexes))
270        }
271    }
272
273    /**
274     * Returns an iterator over references to the data in this TensorAccess, in the order of
275     * the TensorAccess shape.
276     */
277    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, TensorAccess<T, S, D>, D> {
278        TensorReferenceIterator::from(self)
279    }
280}
281
282impl<T, S, const D: usize> TensorAccess<T, S, D>
283where
284    S: TensorRef<T, D>,
285    T: Clone,
286{
287    /**
288     * Using the dimension ordering of the TensorAccess, gets a copy of the value at the
289     * index if the index is in range, panicking if the index is out of range.
290     *
291     * For a non panicking API see [`try_get_reference`](TensorAccess::try_get_reference)
292     */
293    #[track_caller]
294    pub fn get(&self, indexes: [usize; D]) -> T {
295        match self.try_get_reference(indexes) {
296            Some(reference) => reference.clone(),
297            None => panic!(
298                "Unable to index with {:?}, Tensor dimensions are {:?}.",
299                indexes,
300                self.shape()
301            ),
302        }
303    }
304
305    /**
306     * Gets a copy of the first value in this tensor.
307     * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
308     * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
309     */
310    pub fn first(&self) -> T {
311        self.iter()
312            .next()
313            .expect("Tensors always have at least 1 element")
314    }
315
316    /**
317     * Creates and returns a new tensor with all values from the original with the
318     * function applied to each.
319     *
320     * Note: mapping methods are defined on [Tensor] and
321     * [TensorView](crate::tensors::views::TensorView) directly so you don't need to create a
322     * TensorAccess unless you want to do the mapping with a different dimension order.
323     */
324    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
325        let mapped = self.iter().map(mapping_function).collect();
326        Tensor::from(self.shape(), mapped)
327    }
328
329    /**
330     * Creates and returns a new tensor with all values from the original and
331     * the index of each value mapped by a function. The indexes passed to the mapping
332     * function always increment the rightmost index, starting at all 0s, using the dimension
333     * order that the TensorAccess is indexed by, not neccessarily the index order the
334     * original source uses.
335     *
336     * Note: mapping methods are defined on [Tensor] and
337     * [TensorView](crate::tensors::views::TensorView) directly so you don't need to create a
338     * TensorAccess unless you want to do the mapping with a different dimension order.
339     */
340    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
341        let mapped = self
342            .iter()
343            .with_index()
344            .map(|(i, x)| mapping_function(i, x))
345            .collect();
346        Tensor::from(self.shape(), mapped)
347    }
348
349    /**
350     * Returns an iterator over copies of the data in this TensorAccess, in the order of
351     * the TensorAccess shape.
352     */
353    pub fn iter(&self) -> TensorIterator<'_, T, TensorAccess<T, S, D>, D> {
354        TensorIterator::from(self)
355    }
356}
357
358impl<T, S, const D: usize> TensorAccess<T, S, D>
359where
360    S: TensorMut<T, D>,
361{
362    /**
363     * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
364     * the index if the index is in range. Otherwise returns None.
365     */
366    pub fn try_get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
367        self.source
368            .get_reference_mut(self.dimension_mapping.map_dimensions_to_source(&indexes))
369    }
370
371    /**
372     * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
373     * the index if the index is in range, panicking if the index is out of range.
374     */
375    // NOTE: Ideally `get_reference_mut` would be used here for consistency, but that opens the
376    // minefield of TensorMut::get_reference_mut and TensorAccess::get_ref_mut being different
377    // signatures but the same name.
378    #[track_caller]
379    pub fn get_ref_mut(&mut self, indexes: [usize; D]) -> &mut T {
380        match self.try_get_reference_mut(indexes) {
381            Some(reference) => reference,
382            // can't provide a better error because the borrow checker insists that returning
383            // a reference in the Some branch means our mutable borrow prevents us calling
384            // self.shape() and a bad error is better than cloning self.shape() on every call
385            None => panic!("Unable to index with {:?}", indexes),
386        }
387    }
388
389    /**
390     * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
391     * the index wihout any bounds checking.
392     *
393     * # Safety
394     *
395     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
396     * resulting reference is not used. Valid indexes are defined as in [TensorRef]. Note that
397     * the order of the indexes needed here must match with
398     * [`TensorAccess::shape`](TensorAccess::shape) which may not neccessarily be the same
399     * as the `view_shape` of the `TensorRef` implementation this TensorAccess was created from).
400     *
401     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
402     * [TensorRef]: TensorRef
403     */
404    // NOTE: This aliases with TensorRef::get_reference_unchecked_mut but the TensorMut impl
405    // just calls this and the signatures match anyway, so there are no potential issues.
406    #[allow(clippy::missing_safety_doc)] // it's not missing
407    pub unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
408        unsafe {
409            self.source.get_reference_unchecked_mut(
410                self.dimension_mapping.map_dimensions_to_source(&indexes),
411            )
412        }
413    }
414
415    /**
416     * Returns an iterator over mutable references to the data in this TensorAccess, in the order
417     * of the TensorAccess shape.
418     */
419    pub fn iter_reference_mut(
420        &mut self,
421    ) -> TensorReferenceMutIterator<'_, T, TensorAccess<T, S, D>, D> {
422        TensorReferenceMutIterator::from(self)
423    }
424}
425
426impl<T, S, const D: usize> TensorAccess<T, S, D>
427where
428    S: TensorMut<T, D>,
429    T: Clone,
430{
431    /**
432     * Applies a function to all values in the tensor, modifying
433     * the tensor in place.
434     */
435    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
436        self.iter_reference_mut()
437            .for_each(|x| *x = mapping_function(x.clone()));
438    }
439
440    /**
441     * Applies a function to all values and each value's index in the tensor, modifying
442     * the tensor in place. The indexes passed to the mapping function always increment
443     * the rightmost index, starting at all 0s, using the dimension order that the
444     * TensorAccess is indexed by, not neccessarily the index order the original source uses.
445     */
446    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
447        self.iter_reference_mut()
448            .with_index()
449            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
450    }
451}
452
453impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D>
454where
455    T: Numeric + Primitive,
456    S: TensorRef<(T, Index), D>,
457{
458    /**
459     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
460     * as a Record if the index is in range, panicking if the index is out of range.
461     *
462     * If you need to access all the data as records instead of just a specific index you should
463     * probably use one of the iterator APIs instead.
464     *
465     * See also: [iter_as_records](RecordTensor::iter_as_records)
466     *
467     * # Panics
468     *
469     * If the index is out of range.
470     *
471     * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
472     *
473     * ```
474     * use easy_ml::differentiation::RecordTensor;
475     * use easy_ml::differentiation::WengertList;
476     * use easy_ml::tensors::Tensor;
477     *
478     * let list = WengertList::new();
479     * let X = RecordTensor::variables(
480     *     &list,
481     *     Tensor::from(
482     *         [("r", 2), ("c", 3)],
483     *         vec![
484     *             3.0, 4.0, 5.0,
485     *             1.0, 4.0, 9.0,
486     *         ]
487     *     )
488     * );
489     * let x = X.index_by(["c", "r"]).get_as_record([2, 0]);
490     * assert_eq!(x.number, 5.0);
491     * ```
492     */
493    #[track_caller]
494    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
495        Record::from_existing(self.get(indexes), self.source.history())
496    }
497
498    /**
499     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
500     * as a Record if the index is in range. Otherwise returns None.
501     *
502     * If you need to access all the data as records instead of just a specific index you should
503     * probably use one of the iterator APIs instead.
504     *
505     * See also: [iter_as_records](RecordTensor::iter_as_records)
506     */
507    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
508        self.try_get_reference(indexes)
509            .map(|r| Record::from_existing(r.clone(), self.source.history()))
510    }
511}
512
513impl<'a, T, S, const D: usize> TensorAccess<(T, Index), RecordTensor<'a, T, S, D>, D>
514where
515    T: Numeric + Primitive,
516    S: TensorRef<(T, Index), D>,
517{
518    /**
519     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
520     * as a Record if the index is in range, panicking if the index is out of range.
521     *
522     * If you need to access all the data as records instead of just a specific index you should
523     * probably use one of the iterator APIs instead.
524     *
525     * See also: [iter_as_records](RecordTensor::iter_as_records)
526     *
527     * # Panics
528     *
529     * If the index is out of range.
530     *
531     * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
532     */
533    #[track_caller]
534    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
535        Record::from_existing(self.get(indexes), self.source.history())
536    }
537
538    /**
539     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
540     * as a Record if the index is in range. Otherwise returns None.
541     *
542     * If you need to access all the data as records instead of just a specific index you should
543     * probably use one of the iterator APIs instead.
544     *
545     * See also: [iter_as_records](RecordTensor::iter_as_records)
546     */
547    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
548        self.try_get_reference(indexes)
549            .map(|r| Record::from_existing(r.clone(), self.source.history()))
550    }
551}
552
553impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &mut RecordTensor<'a, T, S, D>, D>
554where
555    T: Numeric + Primitive,
556    S: TensorRef<(T, Index), D>,
557{
558    /**
559     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
560     * as a Record if the index is in range, panicking if the index is out of range.
561     *
562     * If you need to access all the data as records instead of just a specific index you should
563     * probably use one of the iterator APIs instead.
564     *
565     * See also: [iter_as_records](RecordTensor::iter_as_records)
566     *
567     * # Panics
568     *
569     * If the index is out of range.
570     *
571     * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
572     */
573    #[track_caller]
574    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
575        Record::from_existing(self.get(indexes), self.source.history())
576    }
577
578    /**
579     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
580     * as a Record if the index is in range. Otherwise returns None.
581     *
582     * If you need to access all the data as records instead of just a specific index you should
583     * probably use one of the iterator APIs instead.
584     *
585     * See also: [iter_as_records](RecordTensor::iter_as_records)
586     */
587    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
588        self.try_get_reference(indexes)
589            .map(|r| Record::from_existing(r.clone(), self.source.history()))
590    }
591}
592
593// # Safety
594//
595// The type implementing TensorRef inside the TensorAccess must implement it correctly, so by
596// delegating to it without changing anything other than the order we index it, we implement
597// TensorRef correctly as well.
598/**
599 * A TensorAccess implements TensorRef, with the dimension order and indexing matching that of the
600 * TensorAccess shape.
601 */
602unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorAccess<T, S, D>
603where
604    S: TensorRef<T, D>,
605{
606    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
607        self.try_get_reference(indexes)
608    }
609
610    fn view_shape(&self) -> [(Dimension, usize); D] {
611        self.shape()
612    }
613
614    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
615        unsafe { self.get_reference_unchecked(indexes) }
616    }
617
618    fn data_layout(&self) -> DataLayout<D> {
619        match self.source.data_layout() {
620            // We might have reordered the view_shape but we didn't rearrange the memory or change
621            // what each dimension name refers to in memory, so the data layout remains as is.
622            DataLayout::Linear(order) => DataLayout::Linear(order),
623            DataLayout::NonLinear => DataLayout::NonLinear,
624            DataLayout::Other => DataLayout::Other,
625        }
626    }
627}
628
629// # Safety
630//
631// The type implementing TensorMut inside the TensorAccess must implement it correctly, so by
632// delegating to it without changing anything other than the order we index it, we implement
633// TensorMut correctly as well.
634/**
635 * A TensorAccess implements TensorMut, with the dimension order and indexing matching that of the
636 * TensorAccess shape.
637 */
638unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorAccess<T, S, D>
639where
640    S: TensorMut<T, D>,
641{
642    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
643        self.try_get_reference_mut(indexes)
644    }
645
646    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
647        unsafe { self.get_reference_unchecked_mut(indexes) }
648    }
649}
650
651/**
652 * Any tensor access of a Displayable type implements Display
653 *
654 * You can control the precision of the formatting using format arguments, i.e.
655 * `format!("{:.3}", tensor)`
656 */
657impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorAccess<T, S, D>
658where
659    T: std::fmt::Display,
660    S: TensorRef<T, D>,
661{
662    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
663        crate::tensors::display::format_view(&self, f)?;
664        writeln!(f)?;
665        write!(f, "Data Layout = {:?}", self.data_layout())
666    }
667}
668
/**
 * An iterator over all indexes in a shape.
 *
 * First the all 0 index is iterated, then each iteration increments the rightmost index.
 * For a shape of `[("a", 2), ("b", 2), ("c", 2)]` this will yield indexes in order of: `[0,0,0]`,
 * `[0,0,1]`, `[0,1,0]`, `[0,1,1]`, `[1,0,0]`, `[1,0,1]`, `[1,1,0]`, `[1,1,1]`,
 *
 * You don't typically need to use this directly, as tensors have iterators that iterate over
 * them and return values to you (using this under the hood), but `ShapeIterator` can be useful
 * if you need to hold a mutable reference to a tensor while iterating as `ShapeIterator` does
 * not borrow the tensor. NB: if you do index into a tensor you're mutably borrowing using
 * `ShapeIterator` directly, take care to ensure you don't accidentally reshape the tensor and
 * continue to use indexes from `ShapeIterator` as they would then be invalid.
 */
#[derive(Clone, Debug)]
pub struct ShapeIterator<const D: usize> {
    // The owned shape being iterated over; the lengths bound each index.
    shape: [(Dimension, usize); D],
    // The index that the next call to `next` will yield, if not finished.
    indexes: [usize; D],
    // Set once there are no more indexes to yield, or from construction if the shape
    // contains a zero length dimension.
    finished: bool,
}
689
690/// If we're given an invalid shape (shape input is not neccessarily going to meet the no
691/// 0 lengths contract of TensorRef because that's not actually required here), we
692/// should return a finished iterator immediately and not iterate at all.
693/// Since this is an iterator over an owned shape, it's not going to become invalid later
694/// when we start iterating so this is the only constructor check we need.
695fn is_starting_index_valid(shape: &[(Dimension, usize)]) -> bool {
696    shape.iter().all(|(_, l)| *l > 0)
697}
698
699impl<const D: usize> ShapeIterator<D> {
700    /**
701     * Constructs a ShapeIterator for a shape.
702     *
703     * If the shape has any dimensions with a length of zero, the iterator will immediately
704     * return None on [`next()`](Iterator::next).
705     */
706    pub fn from(shape: [(Dimension, usize); D]) -> ShapeIterator<D> {
707        let starting_index_valid = is_starting_index_valid(&shape);
708        ShapeIterator {
709            shape,
710            indexes: [0; D],
711            finished: !starting_index_valid,
712        }
713    }
714}
715
716impl<const D: usize> Iterator for ShapeIterator<D> {
717    type Item = [usize; D];
718
719    fn next(&mut self) -> Option<Self::Item> {
720        iter(&mut self.finished, &mut self.indexes, &self.shape)
721    }
722
723    fn size_hint(&self) -> (usize, Option<usize>) {
724        size_hint(self.finished, &self.indexes, &self.shape)
725    }
726}
727
// Once we hit the end we mark ourselves as finished so we're always Fused: every later call
// to `next` sees the finished flag and returns None again.
impl<const D: usize> FusedIterator for ShapeIterator<D> {}
// We can always calculate the exact number of steps remaining because the shape and indexes are
// private fields that are only mutated by `next` to count up, so `size_hint` reports the same
// value for its lower and upper bound.
impl<const D: usize> ExactSizeIterator for ShapeIterator<D> {}
733
734/// Common index order iterator logic
735fn iter<const D: usize>(
736    finished: &mut bool,
737    indexes: &mut [usize; D],
738    shape: &[(Dimension, usize); D],
739) -> Option<[usize; D]> {
740    if *finished {
741        return None;
742    }
743
744    let value = Some(*indexes);
745
746    if D > 0 {
747        // Increment index of final dimension. In the 2D case, we iterate through a row by
748        // incrementing through every column index.
749        indexes[D - 1] += 1;
750        for d in (1..D).rev() {
751            if indexes[d] == shape[d].1 {
752                // ran to end of this dimension with our index
753                // In the 2D case, we finished indexing through every column in the row,
754                // and it's now time to move onto the next row.
755                indexes[d] = 0;
756                indexes[d - 1] += 1;
757            }
758        }
759        // Check if we ran past the final index
760        if indexes[0] == shape[0].1 {
761            *finished = true;
762        }
763    } else {
764        *finished = true;
765    }
766
767    value
768}
769
770/// Common index order iterator logic
771fn iter_back<const D: usize>(
772    finished: &mut bool,
773    indexes: &mut [usize; D],
774    shape: &[(Dimension, usize); D],
775) -> Option<[usize; D]> {
776    if *finished {
777        return None;
778    }
779
780    let value = Some(*indexes);
781
782    if D > 0 {
783        let mut bounds = [false; D];
784
785        // Decrement index of final dimension. In the 2D case, we iterate through a row by
786        // decrementing through every column index.
787        if indexes[D - 1] == 0 {
788            bounds[D - 1] = true;
789        } else {
790            indexes[D - 1] -= 1;
791        }
792        for d in (1..D).rev() {
793            if bounds[d] {
794                // ran to start of this dimension with our index
795                // In the 2D case, we finished indexing through every column in the row,
796                // and it's now time to move onto the next row.
797                indexes[d] = shape[d].1 - 1;
798                if indexes[d - 1] == 0 {
799                    bounds[d - 1] = true;
800                } else {
801                    indexes[d - 1] -= 1;
802                }
803            }
804        }
805        // Check if we reached the first index
806        if bounds[0] {
807            *finished = true;
808        }
809    } else {
810        *finished = true;
811    }
812
813    value
814}
815
816/// Common size hint logic
817fn size_hint<const D: usize>(
818    finished: bool,
819    indexes: &[usize; D],
820    shape: &[(Dimension, usize); D],
821) -> (usize, Option<usize>) {
822    if finished {
823        return (0, Some(0));
824    }
825
826    let remaining = if D > 0 {
827        let total = dimensions::elements(shape);
828        let strides = crate::tensors::compute_strides(shape);
829        let seen = crate::tensors::get_index_direct_unchecked(indexes, &strides);
830        total - seen
831    } else {
832        1
833        // If D == 0 and we're not finished we've not returned the sole index yet so there's
834        // exactly 1 left
835    };
836
837    (remaining, Some(remaining))
838}
839
840/// Common size hint logic
841fn double_ended_size_hint<const D: usize>(
842    finished: bool,
843    forward_indexes: &[usize; D],
844    back_indexes: &[usize; D],
845    shape: &[(Dimension, usize); D],
846) -> (usize, Option<usize>) {
847    if finished {
848        return (0, Some(0));
849    }
850
851    let remaining = if D > 0 {
852        //let total = dimensions::elements(shape);
853        let strides = crate::tensors::compute_strides(shape);
854        let progress_forward =
855            crate::tensors::get_index_direct_unchecked(forward_indexes, &strides);
856        let progress_backward = crate::tensors::get_index_direct_unchecked(back_indexes, &strides);
857        // progress_forward will range from 0 if we've not iterated forward at all yet
858        // through to the total-1 if we are on the final index at the end.
859        // likewise progress_backward starts at total-1 and finishes at 0 when on the first
860        // index.
861        // To calculate total left going forward (as in forward only case) and then
862        // subtract the total already seen backward we'd have:
863        // (total - progress_forward) - ((total - 1) - progress_backward)
864        // This cancels to
865        1 + progress_backward - progress_forward
866    } else {
867        1
868        // If D == 0 and we're not finished we've not returned the sole index yet so there's
869        // exactly 1 left
870    };
871
872    (remaining, Some(remaining))
873}
874
// A forward only iterator over all indexes in a shape whose number of dimensions is only known
// at runtime. It exposes an inherent `next` method rather than implementing Iterator.
#[derive(Clone, Debug)]
pub(crate) struct DynamicShapeIterator {
    // The dimension names and lengths being iterated over.
    shape: Vec<(Dimension, usize)>,
    // The index that the following call to `next` will yield.
    indexes: Vec<usize>,
    // Reusable buffer that `next` copies the yielded index into and returns a borrow of.
    next: Vec<usize>,
    // True once every index has been yielded, or from the start if the starting index is
    // not valid for the shape.
    finished: bool,
}
882
883impl DynamicShapeIterator {
884    pub(crate) fn from(shape: &Vec<(Dimension, usize)>) -> DynamicShapeIterator {
885        let starting_index_valid = is_starting_index_valid(&shape);
886        let number_of_dimensions = shape.len();
887        DynamicShapeIterator {
888            shape: shape.clone(),
889            indexes: vec![0; number_of_dimensions],
890            next: vec![0; number_of_dimensions],
891            finished: !starting_index_valid,
892        }
893    }
894
895    pub(crate) fn next(&mut self) -> Option<&Vec<usize>> {
896        if self.finished {
897            return None;
898        }
899
900        let dimensions = self.shape.len();
901        // We return borrows of self.next, and assign to it the
902        // contents of self.indexes so we can avoid allocating
903        // a vec on each iteration, this keeps the vecs at a constant 2
904        // for the entire iteration. Unfortunately returning a self
905        // borrow also makes implementing Iterator very tricky, so this
906        // is just a method with a similar API.
907        self.next.clone_from(&self.indexes);
908        let value = Some(&self.next);
909
910        if dimensions > 0 {
911            // Increment index of final dimension. In the 2D case, we iterate through a row by
912            // incrementing through every column index.
913            self.indexes[dimensions - 1] += 1;
914            for d in (1..dimensions).rev() {
915                if self.indexes[d] == self.shape[d].1 {
916                    // ran to end of this dimension with our index
917                    // In the 2D case, we finished indexing through every column in the row,
918                    // and it's now time to move onto the next row.
919                    self.indexes[d] = 0;
920                    self.indexes[d - 1] += 1;
921                }
922            }
923            // Check if we ran past the final index
924            if self.indexes[0] == self.shape[0].1 {
925                self.finished = true;
926            }
927        } else {
928            self.finished = true;
929        }
930
931        value
932    }
933}
934
/**
 * An iterator over all indexes in a shape which can iterate in both directions.
 *
 * Going forwards, first the all 0 index is iterated, then each iteration increments the rightmost
 * index.
 * For a shape of `[("a", 2), ("b", 2), ("c", 2)]` this will yield indexes in order of: `[0,0,0]`,
 * `[0,0,1]`, `[0,1,0]`, `[0,1,1]`, `[1,0,0]`, `[1,0,1]`, `[1,1,0]`, `[1,1,1]`.
 * When iterating backwards, the indexes are yielded in reverse. Indexes do not cross;
 * iteration is over when the forward and backward indexes meet in the middle.
 */
#[derive(Clone, Debug)]
pub struct DoubleEndedShapeIterator<const D: usize> {
    // The dimension names and lengths being iterated over.
    shape: [(Dimension, usize); D],
    // The index that the following call to `next` will yield.
    forward_indexes: [usize; D],
    // The index that the following call to `next_back` will yield.
    back_indexes: [usize; D],
    // True once the two ends have met, or from the start if the starting index is not valid
    // for the shape.
    finished: bool,
}
952
953impl<const D: usize> DoubleEndedShapeIterator<D> {
954    /**
955     * Constructs a DoubleEndedShapeIterator for a shape.
956     *
957     * If the shape has any dimensions with a length of zero, the iterator will immediately
958     * return None on [`next()`](Iterator::next) or
959     * [`next_back()`](DoubleEndedIterator::next_back()).
960     */
961    pub fn from(shape: [(Dimension, usize); D]) -> DoubleEndedShapeIterator<D> {
962        let starting_index_valid = is_starting_index_valid(&shape);
963        DoubleEndedShapeIterator {
964            shape,
965            forward_indexes: [0; D],
966            back_indexes: shape.map(|(_, l)| l - 1),
967            finished: !starting_index_valid,
968        }
969    }
970}
971
/// Returns true if the forward and backward positions are on exactly the same index, in which
/// case that index is the last one either end has left to yield.
fn overlapping_iterators<const D: usize>(
    forward_indexes: &[usize; D],
    back_indexes: &[usize; D],
) -> bool {
    forward_indexes
        .iter()
        .zip(back_indexes.iter())
        .all(|(forward, back)| forward == back)
}
978
979impl<const D: usize> Iterator for DoubleEndedShapeIterator<D> {
980    type Item = [usize; D];
981
982    fn next(&mut self) -> Option<Self::Item> {
983        let will_finish = overlapping_iterators(&self.forward_indexes, &self.back_indexes);
984        let item = iter(&mut self.finished, &mut self.forward_indexes, &self.shape);
985        if will_finish {
986            self.finished = true;
987        }
988        item
989    }
990
991    fn size_hint(&self) -> (usize, Option<usize>) {
992        double_ended_size_hint(
993            self.finished,
994            &self.forward_indexes,
995            &self.back_indexes,
996            &self.shape,
997        )
998    }
999}
1000
1001impl<const D: usize> DoubleEndedIterator for DoubleEndedShapeIterator<D> {
1002    fn next_back(&mut self) -> Option<Self::Item> {
1003        let will_finish = overlapping_iterators(&self.forward_indexes, &self.back_indexes);
1004        let item = iter_back(&mut self.finished, &mut self.back_indexes, &self.shape);
1005        if will_finish {
1006            self.finished = true;
1007        }
1008        item
1009    }
1010}
1011
// Once the ends meet (or either end runs out) we mark ourselves as finished, and both `next`
// and `next_back` return `None` whenever that flag is set, so we're always Fused.
impl<const D: usize> FusedIterator for DoubleEndedShapeIterator<D> {}
// We can always calculate the exact number of steps remaining because the shape and indexes are
// private fields that are only mutated by `next` to count up and by `next_back` to count down.
impl<const D: usize> ExactSizeIterator for DoubleEndedShapeIterator<D> {}
1017
1018/**
1019 * An iterator over copies of all values in a tensor.
1020 *
1021 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1022 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1023 * this will take a single step in memory on each iteration, akin to iterating through the
1024 * flattened data of the tensor.
1025 *
1026 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1027 * will still iterate the rightmost index allowing iteration through dimensions in a different
1028 * order to how they are stored, but no longer taking a single step in memory on each
1029 * iteration (which may be less cache friendly for the CPU).
1030 *
1031 * ```
1032 * use easy_ml::tensors::Tensor;
1033 * let tensor_0 = Tensor::from_scalar(1);
1034 * let tensor_1 = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
1035 * let tensor_2 = Tensor::from([("a", 2), ("b", 3)], vec![
1036 *    // two rows, three columns
1037 *    1, 2, 3,
1038 *    4, 5, 6
1039 * ]);
1040 * let tensor_3 = Tensor::from([("a", 2), ("b", 1), ("c", 2)], vec![
1041 *     // two rows each a single column, stacked on top of each other
1042 *     1,
1043 *     2,
1044 *
1045 *     3,
1046 *     4
1047 * ]);
1048 * let tensor_access_0 = tensor_0.index_by([]);
1049 * let tensor_access_1 = tensor_1.index_by(["a"]);
1050 * let tensor_access_2 = tensor_2.index_by(["a", "b"]);
1051 * let tensor_access_2_rev = tensor_2.index_by(["b", "a"]);
1052 * let tensor_access_3 = tensor_3.index_by(["a", "b", "c"]);
1053 * let tensor_access_3_rev = tensor_3.index_by(["c", "b", "a"]);
1054 * assert_eq!(
1055 *     tensor_0.iter().collect::<Vec<i32>>(),
1056 *     vec![1]
1057 * );
1058 * assert_eq!(
1059 *     tensor_access_0.iter().collect::<Vec<i32>>(),
1060 *     vec![1]
1061 * );
1062 * assert_eq!(
1063 *     tensor_1.iter().collect::<Vec<i32>>(),
1064 *     vec![1, 2, 3, 4, 5, 6, 7]
1065 * );
1066 * assert_eq!(
1067 *     tensor_access_1.iter().collect::<Vec<i32>>(),
1068 *     vec![1, 2, 3, 4, 5, 6, 7]
1069 * );
1070 * assert_eq!(
1071 *     tensor_2.iter().collect::<Vec<i32>>(),
1072 *     vec![1, 2, 3, 4, 5, 6]
1073 * );
1074 * assert_eq!(
1075 *     tensor_access_2.iter().collect::<Vec<i32>>(),
1076 *     vec![1, 2, 3, 4, 5, 6]
1077 * );
1078 * assert_eq!(
1079 *     tensor_access_2.iter().rev().collect::<Vec<i32>>(),
1080 *     vec![6, 5, 4, 3, 2, 1]
1081 * );
1082 * assert_eq!(
1083 *     tensor_access_2_rev.iter().collect::<Vec<i32>>(),
1084 *     vec![1, 4, 2, 5, 3, 6]
1085 * );
1086 * assert_eq!(
1087 *     tensor_3.iter().collect::<Vec<i32>>(),
1088 *     vec![1, 2, 3, 4]
1089 * );
1090 * assert_eq!(
1091 *     tensor_3.iter().rev().collect::<Vec<i32>>(),
1092 *     vec![4, 3, 2, 1]
1093 * );
1094 * assert_eq!(
1095 *     tensor_access_3.iter().collect::<Vec<i32>>(),
1096 *     vec![1, 2, 3, 4]
1097 * );
1098 * assert_eq!(
1099 *     tensor_access_3_rev.iter().collect::<Vec<i32>>(),
1100 *     vec![1, 3, 2, 4]
1101 * );
1102 * ```
1103 */
#[derive(Debug)]
pub struct TensorIterator<'a, T, S, const D: usize> {
    // Yields the indexes to read from the source, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The tensor source the values are cloned out of.
    source: &'a S,
    // Ties the element type T to the struct, since no field stores a T directly.
    _type: PhantomData<T>,
}
1110
1111impl<'a, T, S, const D: usize> TensorIterator<'a, T, S, D>
1112where
1113    T: Clone,
1114    S: TensorRef<T, D>,
1115{
1116    pub fn from(source: &S) -> TensorIterator<'_, T, S, D> {
1117        TensorIterator {
1118            shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1119            source,
1120            _type: PhantomData,
1121        }
1122    }
1123
1124    /**
1125     * Constructs an iterator which also yields the indexes of each element in
1126     * this iterator.
1127     */
1128    pub fn with_index(self) -> WithIndex<Self> {
1129        WithIndex { iterator: self }
1130    }
1131}
1132
1133impl<'a, T, S, const D: usize> From<TensorIterator<'a, T, S, D>>
1134    for WithIndex<TensorIterator<'a, T, S, D>>
1135where
1136    T: Clone,
1137    S: TensorRef<T, D>,
1138{
1139    fn from(iterator: TensorIterator<'a, T, S, D>) -> Self {
1140        iterator.with_index()
1141    }
1142}
1143
1144impl<'a, T, S, const D: usize> Iterator for TensorIterator<'a, T, S, D>
1145where
1146    T: Clone,
1147    S: TensorRef<T, D>,
1148{
1149    type Item = T;
1150
1151    fn next(&mut self) -> Option<Self::Item> {
1152        // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1153        // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1154        // immutable reference to our tensor source, it can't be resized which ensures
1155        // DoubleEndedShapeIterator can always yield valid indexes for our iteration.
1156        self.shape_iterator
1157            .next()
1158            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
1159    }
1160
1161    fn size_hint(&self) -> (usize, Option<usize>) {
1162        self.shape_iterator.size_hint()
1163    }
1164}
1165
1166impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorIterator<'a, T, S, D>
1167where
1168    T: Clone,
1169    S: TensorRef<T, D>,
1170{
1171    fn next_back(&mut self) -> Option<Self::Item> {
1172        // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1173        // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1174        // immutable reference to our tensor source, it can't be resized which ensures
1175        // DoubleEndedShapeIterator can always yield valid indexes for our iteration.
1176        self.shape_iterator
1177            .next_back()
1178            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
1179    }
1180}
1181
// Every element comes from a step of the inner DoubleEndedShapeIterator, which is itself
// fused, so once this iterator returns `None` it keeps returning `None`.
impl<'a, T, S, const D: usize> FusedIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
1188
// `size_hint` is forwarded from the inner DoubleEndedShapeIterator, which reports an exact
// length.
impl<'a, T, S, const D: usize> ExactSizeIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
1195
1196impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorIterator<'a, T, S, D>>
1197where
1198    T: Clone,
1199    S: TensorRef<T, D>,
1200{
1201    type Item = ([usize; D], T);
1202
1203    fn next(&mut self) -> Option<Self::Item> {
1204        let index = self.iterator.shape_iterator.forward_indexes;
1205        self.iterator.next().map(|x| (index, x))
1206    }
1207
1208    fn size_hint(&self) -> (usize, Option<usize>) {
1209        self.iterator.size_hint()
1210    }
1211}
1212
1213impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorIterator<'a, T, S, D>>
1214where
1215    T: Clone,
1216    S: TensorRef<T, D>,
1217{
1218    fn next_back(&mut self) -> Option<Self::Item> {
1219        let index = self.iterator.shape_iterator.back_indexes;
1220        self.iterator.next_back().map(|x| (index, x))
1221    }
1222}
1223
// Every element comes from a step of the wrapped TensorIterator, which is itself fused.
impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
1230
// `size_hint` is forwarded from the wrapped TensorIterator, which reports an exact length.
impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
1237
1238/**
1239 * An iterator over references to all values in a tensor.
1240 *
1241 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1242 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1243 * this will take a single step in memory on each iteration, akin to iterating through the
1244 * flattened data of the tensor.
1245 *
1246 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1247 * will still iterate the rightmost index allowing iteration through dimensions in a different
1248 * order to how they are stored, but no longer taking a single step in memory on each
1249 * iteration (which may be less cache friendly for the CPU).
1250 *
1251 * ```
1252 * use easy_ml::tensors::Tensor;
1253 * let tensor_0 = Tensor::from_scalar(1);
1254 * let tensor_1 = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
1255 * let tensor_2 = Tensor::from([("a", 2), ("b", 3)], vec![
1256 *    // two rows, three columns
1257 *    1, 2, 3,
1258 *    4, 5, 6
1259 * ]);
1260 * let tensor_3 = Tensor::from([("a", 2), ("b", 1), ("c", 2)], vec![
1261 *     // two rows each a single column, stacked on top of each other
1262 *     1,
1263 *     2,
1264 *
1265 *     3,
1266 *     4
1267 * ]);
1268 * let tensor_access_0 = tensor_0.index_by([]);
1269 * let tensor_access_1 = tensor_1.index_by(["a"]);
1270 * let tensor_access_2 = tensor_2.index_by(["a", "b"]);
1271 * let tensor_access_2_rev = tensor_2.index_by(["b", "a"]);
1272 * let tensor_access_3 = tensor_3.index_by(["a", "b", "c"]);
1273 * let tensor_access_3_rev = tensor_3.index_by(["c", "b", "a"]);
1274 * assert_eq!(
1275 *     tensor_0.iter_reference().cloned().collect::<Vec<i32>>(),
1276 *     vec![1]
1277 * );
1278 * assert_eq!(
1279 *     tensor_access_0.iter_reference().cloned().collect::<Vec<i32>>(),
1280 *     vec![1]
1281 * );
1282 * assert_eq!(
1283 *     tensor_1.iter_reference().cloned().collect::<Vec<i32>>(),
1284 *     vec![1, 2, 3, 4, 5, 6, 7]
1285 * );
1286 * assert_eq!(
1287 *     tensor_access_1.iter_reference().cloned().collect::<Vec<i32>>(),
1288 *     vec![1, 2, 3, 4, 5, 6, 7]
1289 * );
1290 * assert_eq!(
1291 *     tensor_2.iter_reference().cloned().collect::<Vec<i32>>(),
1292 *     vec![1, 2, 3, 4, 5, 6]
1293 * );
1294 * assert_eq!(
1295 *     tensor_2.iter_reference().rev().cloned().collect::<Vec<i32>>(),
1296 *     vec![6, 5, 4, 3, 2, 1]
1297 * );
1298 * assert_eq!(
1299 *     tensor_access_2.iter_reference().cloned().collect::<Vec<i32>>(),
1300 *     vec![1, 2, 3, 4, 5, 6]
1301 * );
1302 * assert_eq!(
1303 *     tensor_access_2_rev.iter_reference().cloned().collect::<Vec<i32>>(),
1304 *     vec![1, 4, 2, 5, 3, 6]
1305 * );
1306 * assert_eq!(
1307 *     tensor_3.iter_reference().cloned().collect::<Vec<i32>>(),
1308 *     vec![1, 2, 3, 4]
1309 * );
1310 * assert_eq!(
1311 *     tensor_3.iter_reference().rev().cloned().collect::<Vec<i32>>(),
1312 *     vec![4, 3, 2, 1]
1313 * );
1314 * assert_eq!(
1315 *     tensor_access_3.iter_reference().cloned().collect::<Vec<i32>>(),
1316 *     vec![1, 2, 3, 4]
1317 * );
1318 * assert_eq!(
1319 *     tensor_access_3_rev.iter_reference().cloned().collect::<Vec<i32>>(),
1320 *     vec![1, 3, 2, 4]
1321 * );
1322 * ```
1323 */
#[derive(Debug)]
pub struct TensorReferenceIterator<'a, T, S, const D: usize> {
    // Yields the indexes to read from the source, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The tensor source the references are borrowed from.
    source: &'a S,
    // Ties the yielded `&'a T` item type to the struct, since no field stores one directly.
    _type: PhantomData<&'a T>,
}
1330
1331impl<'a, T, S, const D: usize> TensorReferenceIterator<'a, T, S, D>
1332where
1333    S: TensorRef<T, D>,
1334{
1335    pub fn from(source: &S) -> TensorReferenceIterator<'_, T, S, D> {
1336        TensorReferenceIterator {
1337            shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1338            source,
1339            _type: PhantomData,
1340        }
1341    }
1342
1343    /**
1344     * Constructs an iterator which also yields the indexes of each element in
1345     * this iterator.
1346     */
1347    pub fn with_index(self) -> WithIndex<Self> {
1348        WithIndex { iterator: self }
1349    }
1350}
1351
1352impl<'a, T, S, const D: usize> From<TensorReferenceIterator<'a, T, S, D>>
1353    for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1354where
1355    S: TensorRef<T, D>,
1356{
1357    fn from(iterator: TensorReferenceIterator<'a, T, S, D>) -> Self {
1358        iterator.with_index()
1359    }
1360}
1361
1362impl<'a, T, S, const D: usize> Iterator for TensorReferenceIterator<'a, T, S, D>
1363where
1364    S: TensorRef<T, D>,
1365{
1366    type Item = &'a T;
1367
1368    fn next(&mut self) -> Option<Self::Item> {
1369        // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1370        // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1371        // immutable reference to our tensor source, it can't be resized which ensures
1372        // DoubleEndedIterator can always yield valid indexes for our iteration.
1373        self.shape_iterator
1374            .next()
1375            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
1376    }
1377
1378    fn size_hint(&self) -> (usize, Option<usize>) {
1379        self.shape_iterator.size_hint()
1380    }
1381}
1382
1383impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorReferenceIterator<'a, T, S, D>
1384where
1385    S: TensorRef<T, D>,
1386{
1387    fn next_back(&mut self) -> Option<Self::Item> {
1388        // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1389        // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1390        // immutable reference to our tensor source, it can't be resized which ensures
1391        // DoubleEndedIterator can always yield valid indexes for our iteration.
1392        self.shape_iterator
1393            .next_back()
1394            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
1395    }
1396}
1397
// Every element comes from a step of the inner DoubleEndedShapeIterator, which is itself
// fused, so once this iterator returns `None` it keeps returning `None`.
impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceIterator<'a, T, S, D> where
    S: TensorRef<T, D>
{
}
1402
// `size_hint` is forwarded from the inner DoubleEndedShapeIterator, which reports an exact
// length.
impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceIterator<'a, T, S, D> where
    S: TensorRef<T, D>
{
}
1407
1408impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1409where
1410    S: TensorRef<T, D>,
1411{
1412    type Item = ([usize; D], &'a T);
1413
1414    fn next(&mut self) -> Option<Self::Item> {
1415        let index = self.iterator.shape_iterator.forward_indexes;
1416        self.iterator.next().map(|x| (index, x))
1417    }
1418
1419    fn size_hint(&self) -> (usize, Option<usize>) {
1420        self.iterator.size_hint()
1421    }
1422}
1423
1424impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1425where
1426    S: TensorRef<T, D>,
1427{
1428    fn next_back(&mut self) -> Option<Self::Item> {
1429        let index = self.iterator.shape_iterator.back_indexes;
1430        self.iterator.next_back().map(|x| (index, x))
1431    }
1432}
1433
// Every element comes from a step of the wrapped TensorReferenceIterator, which is itself
// fused.
impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
    S: TensorRef<T, D>
{
}
1438
// `size_hint` is forwarded from the wrapped TensorReferenceIterator, which reports an exact
// length.
impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
    S: TensorRef<T, D>
{
}
1443
1444/**
1445 * An iterator over mutable references to all values in a tensor.
1446 *
1447 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1448 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1449 * this will take a single step in memory on each iteration, akin to iterating through the
1450 * flattened data of the tensor.
1451 *
1452 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1453 * will still iterate the rightmost index allowing iteration through dimensions in a different
1454 * order to how they are stored, but no longer taking a single step in memory on each
1455 * iteration (which may be less cache friendly for the CPU).
1456 *
1457 * ```
1458 * use easy_ml::tensors::Tensor;
1459 * let mut tensor = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
1460 * let doubled = tensor.map(|x| 2 * x);
1461 * // mutating a tensor in place can also be done with Tensor::map_mut and
1462 * // Tensor::map_mut_with_index
1463 * for elem in tensor.iter_reference_mut() {
1464 *    *elem = 2 * *elem;
1465 * }
1466 * assert_eq!(
1467 *     tensor,
1468 *     doubled,
1469 * );
1470 * ```
1471 */
#[derive(Debug)]
pub struct TensorReferenceMutIterator<'a, T, S, const D: usize> {
    // Yields the indexes to access in the source, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The exclusively borrowed tensor source the mutable references point into.
    source: &'a mut S,
    // Ties the yielded `&'a mut T` item type to the struct, since no field stores one directly.
    _type: PhantomData<&'a mut T>,
}
1478
1479impl<'a, T, S, const D: usize> TensorReferenceMutIterator<'a, T, S, D>
1480where
1481    S: TensorMut<T, D>,
1482{
1483    pub fn from(source: &mut S) -> TensorReferenceMutIterator<'_, T, S, D> {
1484        TensorReferenceMutIterator {
1485            shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1486            source,
1487            _type: PhantomData,
1488        }
1489    }
1490
1491    /**
1492     * Constructs an iterator which also yields the indexes of each element in
1493     * this iterator.
1494     */
1495    pub fn with_index(self) -> WithIndex<Self> {
1496        WithIndex { iterator: self }
1497    }
1498}
1499
1500impl<'a, T, S, const D: usize> From<TensorReferenceMutIterator<'a, T, S, D>>
1501    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1502where
1503    S: TensorMut<T, D>,
1504{
1505    fn from(iterator: TensorReferenceMutIterator<'a, T, S, D>) -> Self {
1506        iterator.with_index()
1507    }
1508}
1509
1510impl<'a, T, S, const D: usize> Iterator for TensorReferenceMutIterator<'a, T, S, D>
1511where
1512    S: TensorMut<T, D>,
1513{
1514    type Item = &'a mut T;
1515
1516    fn next(&mut self) -> Option<Self::Item> {
1517        self.shape_iterator.next().map(|indexes| {
1518            unsafe {
1519                // Safety: We are not allowed to give out overlapping mutable references,
1520                // but since we will always increment the counter on every call to next()
1521                // and stop when we reach the end no references will overlap.
1522                // The compiler doesn't know this, so transmute the lifetime for it.
1523                // Safety: DoubleEndedShapeIterator only iterates over the correct indexes into our
1524                // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1525                // mutability and we hold an exclusive reference to our tensor source, it can't
1526                // be resized (except by us - and we don't) which ensures DoubleEndedShapeIterator
1527                // can always yield valid indexes for our iteration.
1528                std::mem::transmute::<&mut T, &mut T>(
1529                    self.source.get_reference_unchecked_mut(indexes)
1530                )
1531            }
1532        })
1533    }
1534
1535    fn size_hint(&self) -> (usize, Option<usize>) {
1536        self.shape_iterator.size_hint()
1537    }
1538}
1539
1540impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorReferenceMutIterator<'a, T, S, D>
1541where
1542    S: TensorMut<T, D>,
1543{
1544    fn next_back(&mut self) -> Option<Self::Item> {
1545        self.shape_iterator.next_back().map(|indexes| {
1546            unsafe {
1547                // Safety: We are not allowed to give out overlapping mutable references,
1548                // but since we will always increment the counter on every call to next()
1549                // and stop when we reach the end no references will overlap.
1550                // The compiler doesn't know this, so transmute the lifetime for it.
1551                // Safety: Our iterator only iterates over the correct indexes into our
1552                // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1553                // mutability and we hold an exclusive reference to our tensor source, it can't
1554                // be resized (except by us - and we don't) which ensures DoubleEndedShapeIterator
1555                // can always yield valid indexes for our iteration.
1556                std::mem::transmute::<&mut T, &mut T>(
1557                    self.source.get_reference_unchecked_mut(indexes)
1558                )
1559            }
1560        })
1561    }
1562}
1563
// Every element comes from a step of the inner DoubleEndedShapeIterator, which is itself
// fused, so once this iterator returns `None` it keeps returning `None`.
impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceMutIterator<'a, T, S, D> where
    S: TensorMut<T, D>
{
}
1568
// `size_hint` is forwarded from the inner DoubleEndedShapeIterator, which reports an exact
// length.
impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceMutIterator<'a, T, S, D> where
    S: TensorMut<T, D>
{
}
1573
1574impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1575where
1576    S: TensorMut<T, D>,
1577{
1578    type Item = ([usize; D], &'a mut T);
1579
1580    fn next(&mut self) -> Option<Self::Item> {
1581        let index = self.iterator.shape_iterator.forward_indexes;
1582        self.iterator.next().map(|x| (index, x))
1583    }
1584
1585    fn size_hint(&self) -> (usize, Option<usize>) {
1586        self.iterator.size_hint()
1587    }
1588}
1589
1590impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1591where
1592    S: TensorMut<T, D>,
1593{
1594    fn next_back(&mut self) -> Option<Self::Item> {
1595        let index = self.iterator.shape_iterator.back_indexes;
1596        self.iterator.next_back().map(|x| (index, x))
1597    }
1598}
1599
1600
// Every element comes from a step of the wrapped TensorReferenceMutIterator, which is itself
// fused.
impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>> where
    S: TensorMut<T, D>
{
}
1605
// `size_hint` is forwarded from the wrapped TensorReferenceMutIterator, which reports an exact
// length.
impl<'a, T, S, const D: usize> ExactSizeIterator
    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
where
    S: TensorMut<T, D>,
{
}
1612
1613/**
1614 * An iterator over all values in an owned tensor.
1615 *
 * This iterator does not clone the values; it returns the actual values stored in the tensor.
 * Because neither [TensorRef] nor [TensorMut] has a method that returns `T` by value, the
 * iterator instead [replaces](std::mem::replace) each value with a dummy value. Hence it can
 * only be created for types that implement [Default] or [ZeroOne](crate::numeric::ZeroOne)
 * from [Numeric](crate::numeric), either of which provides a means to create dummy values.
1621 *
1622 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1623 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1624 * this will take a single step in memory on each iteration, akin to iterating through the
1625 * flattened data of the tensor.
1626 *
1627 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1628 * will still iterate the rightmost index allowing iteration through dimensions in a different
1629 * order to how they are stored, but no longer taking a single step in memory on each
1630 * iteration (which may be less cache friendly for the CPU).
1631 *
1632 * ```
1633 * use easy_ml::tensors::Tensor;
1634 *
1635 * #[derive(Debug, Default, Eq, PartialEq)]
1636 * struct NoClone(i32);
1637 *
1638 * let tensor = Tensor::from([("a", 3)], vec![ NoClone(1), NoClone(2), NoClone(3) ]);
1639 * let values = tensor.iter_owned(); // will use T::default() for dummy values
1640 * assert_eq!(vec![ NoClone(1), NoClone(2), NoClone(3) ], values.collect::<Vec<NoClone>>());
1641 * ```
1642 */
#[derive(Debug)]
pub struct TensorOwnedIterator<T, S, const D: usize> {
    // Yields the indexes to access in the source, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The tensor source, owned by this iterator.
    source: S,
    // Creates the dummy values: `T::default()` when constructed with `from` and `T::zero()`
    // when constructed with `from_numeric`.
    producer: fn() -> T,
}
1649
1650impl<T, S, const D: usize> TensorOwnedIterator<T, S, D>
1651where
1652    S: TensorMut<T, D>,
1653{
1654    /**
1655     * Creates the TensorOwnedIterator from a source where the default values will be provided
1656     * by [Default::default]. This constructor is also used by the convenience
1657     * methods on [Tensor::iter_owned](Tensor::iter_owned) and
1658     * [TensorView::iter_owned](crate::tensors::views::TensorView::iter_owned).
1659     */
1660    pub fn from(source: S) -> TensorOwnedIterator<T, S, D>
1661    where
1662        T: Default,
1663    {
1664        TensorOwnedIterator {
1665            shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1666            source,
1667            producer: || T::default(),
1668        }
1669    }
1670
1671    /**
1672     * Creates the TensorOwnedIterator from a source where the default values will be provided
1673     * by [ZeroOne::zero](crate::numeric::ZeroOne::zero).
1674     */
1675    pub fn from_numeric(source: S) -> TensorOwnedIterator<T, S, D>
1676    where
1677        T: crate::numeric::ZeroOne,
1678    {
1679        TensorOwnedIterator {
1680            shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1681            source,
1682            producer: || T::zero(),
1683        }
1684    }
1685
1686    /**
1687     * Constructs an iterator which also yields the indexes of each element in
1688     * this iterator.
1689     */
1690    pub fn with_index(self) -> WithIndex<Self> {
1691        WithIndex { iterator: self }
1692    }
1693}
1694
1695impl<T, S, const D: usize> From<TensorOwnedIterator<T, S, D>>
1696    for WithIndex<TensorOwnedIterator<T, S, D>>
1697where
1698    S: TensorMut<T, D>,
1699{
1700    fn from(iterator: TensorOwnedIterator<T, S, D>) -> Self {
1701        iterator.with_index()
1702    }
1703}
1704
1705impl<T, S, const D: usize> Iterator for TensorOwnedIterator<T, S, D>
1706where
1707    S: TensorMut<T, D>,
1708{
1709    type Item = T;
1710
1711    fn next(&mut self) -> Option<Self::Item> {
1712        self.shape_iterator.next().map(|indexes| {
1713            let producer = self.producer;
1714            let dummy = producer();
1715            // Safety: DoubleEndedShapeIterator only iterates over the correct indexes into our
1716            // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1717            // mutability and we hold our tensor source by value, it can't be resized (except by
1718            // us - and we don't) which ensures it can always yield valid indexes for
1719            // our iteration.
1720            std::mem::replace(
1721                unsafe { self.source.get_reference_unchecked_mut(indexes) },
1722                dummy,
1723            )
1724        })
1725    }
1726
1727    fn size_hint(&self) -> (usize, Option<usize>) {
1728        self.shape_iterator.size_hint()
1729    }
1730}
1731
1732impl<T, S, const D: usize> DoubleEndedIterator for TensorOwnedIterator<T, S, D>
1733where
1734    S: TensorMut<T, D>,
1735{
1736    fn next_back(&mut self) -> Option<Self::Item> {
1737        self.shape_iterator.next_back().map(|indexes| {
1738            let producer = self.producer;
1739            let dummy = producer();
1740            // Safety: DoubleEndedShapeIterator only iterates over the correct indexes into our
1741            // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1742            // mutability and we hold our tensor source by value, it can't be resized (except by
1743            // us - and we don't) which ensures it can always yield valid indexes for
1744            // our iteration.
1745            std::mem::replace(
1746                unsafe { self.source.get_reference_unchecked_mut(indexes) },
1747                dummy,
1748            )
1749        })
1750    }
1751}
1752
1753impl<T, S, const D: usize> FusedIterator for TensorOwnedIterator<T, S, D> where S: TensorMut<T, D> {}
1754
1755impl<T, S, const D: usize> ExactSizeIterator for TensorOwnedIterator<T, S, D> where
1756    S: TensorMut<T, D>
1757{
1758}
1759
1760impl<T, S, const D: usize> Iterator for WithIndex<TensorOwnedIterator<T, S, D>>
1761where
1762    S: TensorMut<T, D>,
1763{
1764    type Item = ([usize; D], T);
1765
1766    fn next(&mut self) -> Option<Self::Item> {
1767        let index = self.iterator.shape_iterator.forward_indexes;
1768        self.iterator.next().map(|x| (index, x))
1769    }
1770
1771    fn size_hint(&self) -> (usize, Option<usize>) {
1772        self.iterator.size_hint()
1773    }
1774}
1775
1776impl<T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorOwnedIterator<T, S, D>>
1777where
1778    S: TensorMut<T, D>,
1779{
1780    fn next_back(&mut self) -> Option<Self::Item> {
1781        let index = self.iterator.shape_iterator.back_indexes;
1782        self.iterator.next_back().map(|x| (index, x))
1783    }
1784}
1785
1786impl<T, S, const D: usize> FusedIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
1787    S: TensorMut<T, D>
1788{
1789}
1790
1791impl<T, S, const D: usize> ExactSizeIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
1792    S: TensorMut<T, D>
1793{
1794}
1795
/**
 * A TensorTranspose makes the data in the tensor it is created from appear to be in a different
 * order, swapping the lengths of each named dimension to match the new order but leaving the
 * dimension name order unchanged.
 *
 * If you need to swap not just the order of the data but also the order of the dimension
 * names, use [TensorAccess] instead.
 *
 * In the example below the dimension names stay in the order batch, rows, columns, but
 * because the data is requested in the order batch, columns, rows the lengths of the last two
 * dimensions swap from 3 and 2 to 2 and 3.
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::indexing::TensorTranspose;
 * use easy_ml::tensors::views::TensorView;
 * let tensor = Tensor::from([("batch", 2), ("rows", 3), ("columns", 2)], vec![
 *     1, 2,
 *     3, 4,
 *     5, 6,
 *
 *     7, 8,
 *     9, 0,
 *     1, 2
 * ]);
 * let transposed = TensorView::from(TensorTranspose::from(&tensor, ["batch", "columns", "rows"]));
 * assert_eq!(
 *     transposed,
 *     Tensor::from([("batch", 2), ("rows", 2), ("columns", 3)], vec![
 *         1, 3, 5,
 *         2, 4, 6,
 *
 *         7, 9, 1,
 *         8, 0, 2
 *     ])
 * );
 * let also_transposed = tensor.transpose_view(["batch", "columns", "rows"]);
 * ```
 */
#[derive(Clone)]
pub struct TensorTranspose<T, S, const D: usize> {
    // Indexing is delegated to this TensorAccess; the transpose differs from it only in the
    // dimension names reported in its shape, which are taken from the source.
    access: TensorAccess<T, S, D>,
}
1835
1836impl<T: fmt::Debug, S: fmt::Debug, const D: usize> fmt::Debug for TensorTranspose<T, S, D> {
1837    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1838        f.debug_struct("TensorTranspose")
1839            .field("source", &self.access.source)
1840            .field("dimension_mapping", &self.access.dimension_mapping)
1841            .field("_type", &self.access._type)
1842            .finish()
1843    }
1844}
1845
1846impl<T, S, const D: usize> TensorTranspose<T, S, D>
1847where
1848    S: TensorRef<T, D>,
1849{
1850    /**
1851     * Creates a TensorTranspose which makes the data appear in the order of the
1852     * supplied dimensions. The order of the dimension names is unchanged, although their lengths
1853     * may swap.
1854     *
1855     * # Panics
1856     *
1857     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1858     * order need not match.
1859     */
1860    #[track_caller]
1861    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorTranspose<T, S, D> {
1862        TensorTranspose {
1863            access: match TensorAccess::try_from(source, dimensions) {
1864                Err(error) => panic!("{}", error),
1865                Ok(success) => success,
1866            },
1867        }
1868    }
1869
1870    /**
1871     * Creates a TensorTranspose which makes the data to appear in the order of the
1872     * supplied dimensions. The order of the dimension names is unchanged, although their lengths
1873     * may swap.
1874     *
1875     * Returns Err if the set of dimensions supplied do not match the set of dimensions in this
1876     * tensor's shape.
1877     */
1878    pub fn try_from(
1879        source: S,
1880        dimensions: [Dimension; D],
1881    ) -> Result<TensorTranspose<T, S, D>, InvalidDimensionsError<D>> {
1882        TensorAccess::try_from(source, dimensions).map(|access| TensorTranspose { access })
1883    }
1884
1885    /**
1886     * The shape of this TensorTranspose appears to rearrange the data to the order of supplied
1887     * dimensions. The actual data in the underlying tensor and the order of the dimension names
1888     * on this TensorTranspose remains unchanged, although the lengths of the dimensions in this
1889     * shape of may swap compared to the source's shape.
1890     */
1891    pub fn shape(&self) -> [(Dimension, usize); D] {
1892        let names = self.access.source.view_shape();
1893        let order = self.access.shape();
1894        std::array::from_fn(|d| (names[d].0, order[d].1))
1895    }
1896
1897    pub fn source(self) -> S {
1898        self.access.source
1899    }
1900
1901    // # Safety
1902    //
1903    // Giving out a mutable reference to our source could allow it to be changed out from under us
1904    // and make our dimmension mapping invalid. However, since the source implements TensorRef
1905    // interior mutability is not allowed, so we can give out shared references without breaking
1906    // our own integrity.
1907    pub fn source_ref(&self) -> &S {
1908        &self.access.source
1909    }
1910}
1911
1912// # Safety
1913//
1914// The TensorAccess must implement TensorRef correctly, so by delegating to it without changing
1915// anything other than the order of the dimension names we expose, we implement
1916// TensoTensorRefrMut correctly as well.
1917/**
1918 * A TensorTranspose implements TensorRef, with the dimension order and indexing matching that
1919 * of the TensorTranspose shape.
1920 */
1921unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorTranspose<T, S, D>
1922where
1923    S: TensorRef<T, D>,
1924{
1925    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1926        // we didn't change the lengths of any dimension in our shape from the TensorAccess so we
1927        // can delegate to the tensor access for non named indexing here
1928        self.access.try_get_reference(indexes)
1929    }
1930
1931    fn view_shape(&self) -> [(Dimension, usize); D] {
1932        self.shape()
1933    }
1934
1935    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1936        unsafe { self.access.get_reference_unchecked(indexes) }
1937    }
1938
1939    fn data_layout(&self) -> DataLayout<D> {
1940        let data_layout = self.access.data_layout();
1941        match data_layout {
1942            DataLayout::Linear(order) => DataLayout::Linear(
1943                self.access
1944                    .dimension_mapping
1945                    .map_linear_data_layout_to_transposed(&order),
1946            ),
1947            _ => data_layout,
1948        }
1949    }
1950}
1951
1952// # Safety
1953//
1954// The TensorAccess must implement TensorMut correctly, so so by delegating to it without changing
1955// anything other than the order of the dimension names we expose, we implement, we implement
1956// TensorMut correctly as well.
1957/**
1958 * A TensorTranspose implements TensorMut, with the dimension order and indexing matching that of
1959 * the TensorTranspose shape.
1960 */
1961unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorTranspose<T, S, D>
1962where
1963    S: TensorMut<T, D>,
1964{
1965    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1966        self.access.try_get_reference_mut(indexes)
1967    }
1968
1969    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1970        unsafe { self.access.get_reference_unchecked_mut(indexes) }
1971    }
1972}
1973
1974/**
1975 * Any tensor transpose of a Displayable type implements Display
1976 *
1977 * You can control the precision of the formatting using format arguments, i.e.
1978 * `format!("{:.3}", tensor)`
1979 */
1980impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorTranspose<T, S, D>
1981where
1982    T: std::fmt::Display,
1983    S: TensorRef<T, D>,
1984{
1985    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1986        crate::tensors::display::format_view(&self, f)?;
1987        writeln!(f)?;
1988        write!(f, "Data Layout = {:?}", self.data_layout())
1989    }
1990}
1991
// The main test suite lives in tests/, but DynamicShapeIterator isn't public API so it
// can't be imported there and is tested here instead.
#[test]
fn test_dynamic_shape_iterator_exact_size() {
    let mut iterator = DynamicShapeIterator::from(&vec![("x", 3), ("y", 2)]);

    // The rightmost index ("y") increments fastest, and every index is visited exactly once.
    let expected = [
        vec![0, 0],
        vec![0, 1],
        vec![1, 0],
        vec![1, 1],
        vec![2, 0],
        vec![2, 1],
    ];
    for indexes in expected.iter() {
        assert_eq!(iterator.next().cloned(), Some(indexes.clone()));
    }

    // After the 3 x 2 = 6 indexes the iterator is exhausted.
    assert_eq!(iterator.next().cloned(), None);
}