easy_ml/tensors/
indexing.rs

1/*!
2 * # Indexing
3 *
4 * Many libraries represent tensors as N dimensional arrays; however, there is often some semantic
5 * meaning to each dimension. You may have a batch of 2000 images, each 100 pixels wide and high,
6 * with each pixel represented by 3 numbers for rgb values. This can be stored as a
7 * 2000 x 100 x 100 x 3 tensor, but a 4 dimensional array does not track the semantic meaning
8 * of each dimension and associated index.
9 *
10 * 6 months later you could come back to the code and forget which order the dimensions were
11 * created in, at best getting the indexes out of bounds and causing a crash in your application,
12 * and at worst silently reading the wrong data without realising. *Was it width then height or
13 * height then width?*...
14 *
15 * Easy ML moves the N dimensional array to an implementation detail, and most of its APIs work
16 * on the names of each dimension in a tensor instead of just the order. Instead of a
17 * 2000 x 100 x 100 x 3 tensor in which the last element is at [1999, 99, 99, 2], Easy ML tracks
18 * the names of the dimensions, so you have a
19 * `[("batch", 2000), ("width", 100), ("height", 100), ("rgb", 3)]` shaped tensor.
20 *
21 * This can't stop you from getting the math wrong, but it reduces confusion over which
22 * dimension means what. Tensors carry around their pairs of dimension name and length
23 * so adding a `[("batch", 2000), ("width", 100), ("height", 100), ("rgb", 3)]` shaped tensor
24 * to a `[("batch", 2000), ("height", 100), ("width", 100), ("rgb", 3)]` will fail unless you
25 * reorder one first, and you could access an element as
26 * `tensor.index_by(["batch", "width", "height", "rgb"]).get([1999, 0, 99, 2])` or
27 * `tensor.index_by(["batch", "height", "width", "rgb"]).get([1999, 99, 0, 2])` and read the same data,
28 * because you index into dimensions based on their name, not just the order they are stored in
29 * memory.
30 *
31 * Even with a name for each dimension, at some point you still need to say what order you want
32 * to index each dimension with, and this is where [`TensorAccess`] comes in. It
33 * creates a mapping from the dimension name order you want to access elements with to the order
34 * the dimensions are stored as.
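 *
 * For example, a small sketch of the idea above (using a tiny 2 x 3 tensor rather than a full
 * image batch), where elements are indexed by dimension name rather than by storage order:
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * let tensor = Tensor::from([("row", 2), ("column", 3)], vec![
 *     1, 2, 3,
 *     4, 5, 6
 * ]);
 * // Both accesses read the same element because the indexes are matched to dimension names.
 * assert_eq!(tensor.index_by(["row", "column"]).get([1, 2]), 6);
 * assert_eq!(tensor.index_by(["column", "row"]).get([2, 1]), 6);
 * ```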
35 */
36
37use crate::differentiation::{Index, Primitive, Record, RecordTensor};
38use crate::numeric::Numeric;
39use crate::tensors::dimensions;
40use crate::tensors::dimensions::DimensionMappings;
41use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
42use crate::tensors::{Dimension, Tensor};
43
44use std::error::Error;
45use std::fmt;
46use std::iter::{ExactSizeIterator, FusedIterator};
47use std::marker::PhantomData;
48
49pub use crate::matrices::iterators::WithIndex;
50
51// TODO: Iterators should use unchecked indexing once fully stress tested.
52
53/**
54 * Access to the data in a Tensor with a particular order of dimension indexing. The order
55 * affects the shape of the TensorAccess as well as the order of indexes you supply to read
56 * or write values to the tensor.
57 *
58 * See the [module level documentation](crate::tensors::indexing) for more information.
59 */
60#[derive(Clone, Debug)]
61pub struct TensorAccess<T, S, const D: usize> {
62    source: S,
63    dimension_mapping: DimensionMappings<D>,
64    _type: PhantomData<T>,
65}
66
67impl<T, S, const D: usize> TensorAccess<T, S, D>
68where
69    S: TensorRef<T, D>,
70{
71    /**
72     * Creates a TensorAccess which can be indexed in the order of the supplied dimensions
73     * to read or write values from this tensor.
74     *
75     * # Panics
76     *
77     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
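     *
     * A brief sketch of creating an access with a reordered set of dimensions:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::indexing::TensorAccess;
     * let tensor = Tensor::from([("a", 2), ("b", 2)], vec![ 1, 2, 3, 4 ]);
     * let by_b_then_a = TensorAccess::from(&tensor, ["b", "a"]);
     * assert_eq!(by_b_then_a.shape(), [("b", 2), ("a", 2)]);
     * // b = 1, a = 0 is the same element as a = 0, b = 1 in the source order
     * assert_eq!(by_b_then_a.get([1, 0]), 2);
     * ```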
78     */
79    #[track_caller]
80    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
81        match TensorAccess::try_from(source, dimensions) {
82            Err(error) => panic!("{}", error),
83            Ok(success) => success,
84        }
85    }
86
87    /**
88     * Creates a TensorAccess which can be indexed in the order of the supplied dimensions
89     * to read or write values from this tensor.
90     *
91     * Returns Err if the set of dimensions supplied do not match the set of dimensions in this
92     * tensor's shape.
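     *
     * A short sketch of the success and failure cases:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::indexing::TensorAccess;
     * let tensor = Tensor::from([("x", 2), ("y", 2)], vec![ 1, 2, 3, 4 ]);
     * // "z" is not a dimension of this tensor, so this is rejected
     * assert!(TensorAccess::try_from(&tensor, ["x", "z"]).is_err());
     * // any ordering of the actual dimensions is accepted
     * assert!(TensorAccess::try_from(&tensor, ["y", "x"]).is_ok());
     * ```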
93     */
94    pub fn try_from(
95        source: S,
96        dimensions: [Dimension; D],
97    ) -> Result<TensorAccess<T, S, D>, InvalidDimensionsError<D>> {
98        Ok(TensorAccess {
99            dimension_mapping: DimensionMappings::new(&source.view_shape(), &dimensions)
100                .ok_or_else(|| InvalidDimensionsError {
101                    actual: source.view_shape(),
102                    requested: dimensions,
103                })?,
104            source,
105            _type: PhantomData,
106        })
107    }
108
109    /**
110     * Creates a TensorAccess which is indexed in the same order as the dimensions in the view
111     * shape of the tensor it is created from.
112     *
113     * Hence if you create a TensorAccess directly from a Tensor by `from_source_order`
114     * this uses the order the dimensions were laid out in memory with.
115     *
116     * ```
117     * use easy_ml::tensors::Tensor;
118     * use easy_ml::tensors::indexing::TensorAccess;
119     * let tensor = Tensor::from([("x", 2), ("y", 2), ("z", 2)], vec![
120     *     1, 2,
121     *     3, 4,
122     *
123     *     5, 6,
124     *     7, 8
125     * ]);
126     * let xyz = tensor.index_by(["x", "y", "z"]);
127     * let also_xyz = TensorAccess::from_source_order(&tensor);
128     * let also_xyz = tensor.index();
129     * ```
130     */
131    pub fn from_source_order(source: S) -> TensorAccess<T, S, D> {
132        TensorAccess {
133            dimension_mapping: DimensionMappings::no_op_mapping(),
134            source,
135            _type: PhantomData,
136        }
137    }
138
139    /**
140     * Creates a TensorAccess which is indexed in the same order as the linear data layout
141     * dimensions in the tensor it is created from, or None if the source data layout
142     * is not linear.
143     *
144     * Hence if you use `from_memory_order` on a source that stores its data linearly, such as
145     * [Tensor], this uses the order that allows efficient iteration through each step in memory
146     * when [iterating](TensorIterator).
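     *
     * A small sketch, assuming the source is a [Tensor] whose data layout matches the order it
     * was constructed with:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::indexing::TensorAccess;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![ 1, 2, 3, 4, 5, 6 ]);
     * let access = TensorAccess::from_memory_order(&tensor).unwrap();
     * // iteration takes a single step in memory on each call to next
     * assert_eq!(access.iter().collect::<Vec<i32>>(), vec![1, 2, 3, 4, 5, 6]);
     * ```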
147     */
148    pub fn from_memory_order(source: S) -> Option<TensorAccess<T, S, D>> {
149        let data_layout = match source.data_layout() {
150            DataLayout::Linear(order) => order,
151            _ => return None,
152        };
153        let shape = source.view_shape();
154        Some(TensorAccess::try_from(source, data_layout).unwrap_or_else(|_| panic!(
155            "Source implementation contained dimensions {:?} in data_layout that were not the same set as in the view_shape {:?} which breaks the contract of TensorRef",
156             data_layout, shape
157        )))
158    }
159
160    /**
161     * The shape this TensorAccess has with the dimensions mapped to the order the TensorAccess
162     * was created with, not necessarily the same order as in the underlying tensor.
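     *
     * A quick sketch of the reported shape following the requested dimension order:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![ 1, 2, 3, 4, 5, 6 ]);
     * assert_eq!(tensor.index_by(["y", "x"]).shape(), [("y", 3), ("x", 2)]);
     * ```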
163     */
164    pub fn shape(&self) -> [(Dimension, usize); D] {
165        self.dimension_mapping
166            .map_shape_to_requested(&self.source.view_shape())
167    }
168
169    pub fn source(self) -> S {
170        self.source
171    }
172
173    // # Safety
174    //
175    // Giving out a mutable reference to our source could allow it to be changed out from under us
176    // and make our dimension mapping invalid. However, since the source implements TensorRef
177    // interior mutability is not allowed, so we can give out shared references without breaking
178    // our own integrity.
179    pub fn source_ref(&self) -> &S {
180        &self.source
181    }
182}
183
184/**
185 * An error indicating failure to create a TensorAccess because the requested dimension order
186 * does not match the shape in the source data.
187 */
188#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
189pub struct InvalidDimensionsError<const D: usize> {
190    pub actual: [(Dimension, usize); D],
191    pub requested: [Dimension; D],
192}
193
194impl<const D: usize> Error for InvalidDimensionsError<D> {}
195
196impl<const D: usize> fmt::Display for InvalidDimensionsError<D> {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        write!(
199            f,
200            "Requested dimension order: {:?} does not match the shape in the source: {:?}",
201            &self.requested, &self.actual
202        )
203    }
204}
205
206#[test]
207fn test_sync() {
208    fn assert_sync<T: Sync>() {}
209    assert_sync::<InvalidDimensionsError<3>>();
210}
211
212#[test]
213fn test_send() {
214    fn assert_send<T: Send>() {}
215    assert_send::<InvalidDimensionsError<3>>();
216}
217
218impl<T, S, const D: usize> TensorAccess<T, S, D>
219where
220    S: TensorRef<T, D>,
221{
222    /**
223     * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
224     * index if the index is in range. Otherwise returns None.
225     */
226    pub fn try_get_reference(&self, indexes: [usize; D]) -> Option<&T> {
227        self.source
228            .get_reference(self.dimension_mapping.map_dimensions_to_source(&indexes))
229    }
230
231    /**
232     * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
233     * index if the index is in range, panicking if the index is out of range.
234     */
235    // NOTE: Ideally `get_reference` would be used here for consistency, but that opens the
236    // minefield of TensorRef::get_reference and TensorAccess::get_ref being different signatures
237    // but the same name.
238    #[track_caller]
239    pub fn get_ref(&self, indexes: [usize; D]) -> &T {
240        match self.try_get_reference(indexes) {
241            Some(reference) => reference,
242            None => panic!(
243                "Unable to index with {:?}, Tensor dimensions are {:?}.",
244                indexes,
245                self.shape()
246            ),
247        }
248    }
249
250    /**
251     * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
252     * index without any bounds checking.
253     *
254     * # Safety
255     *
256     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
257     * resulting reference is not used. Valid indexes are defined as in [TensorRef]. Note that
258     * the order of the indexes needed here must match with
259     * [`TensorAccess::shape`](TensorAccess::shape) which may not necessarily be the same
260     * as the `view_shape` of the `TensorRef` implementation this TensorAccess was created from.
261     *
262     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
263     * [TensorRef]: TensorRef
264     */
265    // NOTE: This aliases with TensorRef::get_reference_unchecked but the TensorRef impl
266    // just calls this and the signatures match anyway, so there are no potential issues.
267    #[allow(clippy::missing_safety_doc)] // it's not missing
268    pub unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
269        unsafe {
270            self.source
271                .get_reference_unchecked(self.dimension_mapping.map_dimensions_to_source(&indexes))
272        }
273    }
274
275    /**
276     * Returns an iterator over references to the data in this TensorAccess, in the order of
277     * the TensorAccess shape.
278     */
279    pub fn iter_reference(&self) -> TensorReferenceIterator<T, TensorAccess<T, S, D>, D> {
280        TensorReferenceIterator::from(self)
281    }
282}
283
284impl<T, S, const D: usize> TensorAccess<T, S, D>
285where
286    S: TensorRef<T, D>,
287    T: Clone,
288{
289    /**
290     * Using the dimension ordering of the TensorAccess, gets a copy of the value at the
291     * index if the index is in range, panicking if the index is out of range.
292     *
293     * For a non panicking API see [`try_get_reference`](TensorAccess::try_get_reference)
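     *
     * A minimal sketch of indexing by a reordered set of dimensions:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("row", 2), ("column", 2)], vec![ 1, 2, 3, 4 ]);
     * // column = 0, row = 1 is the same element as row = 1, column = 0
     * assert_eq!(tensor.index_by(["column", "row"]).get([0, 1]), 3);
     * ```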
294     */
295    #[track_caller]
296    pub fn get(&self, indexes: [usize; D]) -> T {
297        match self.try_get_reference(indexes) {
298            Some(reference) => reference.clone(),
299            None => panic!(
300                "Unable to index with {:?}, Tensor dimensions are {:?}.",
301                indexes,
302                self.shape()
303            ),
304        }
305    }
306
307    /**
308     * Gets a copy of the first value in this tensor.
309     * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
310     * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
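     *
     * A tiny sketch: the first value is always the one at the all zeroes index, regardless of
     * the dimension order of the access.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("a", 2), ("b", 2)], vec![ 1, 2, 3, 4 ]);
     * assert_eq!(tensor.index_by(["b", "a"]).first(), 1);
     * ```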
311     */
312    pub fn first(&self) -> T {
313        self.iter()
314            .next()
315            .expect("Tensors always have at least 1 element")
316    }
317
318    /**
319     * Creates and returns a new tensor with all values from the original with the
320     * function applied to each.
321     *
322     * Note: mapping methods are defined on [Tensor] and
323     * [TensorView](crate::tensors::views::TensorView) directly so you don't need to create a
324     * TensorAccess unless you want to do the mapping with a different dimension order.
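     *
     * A brief sketch: mapping through a reordered access produces a tensor with that access's
     * shape.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![ 1, 2, 3, 4, 5, 6 ]);
     * let doubled = tensor.index_by(["y", "x"]).map(|x| x * 2);
     * assert_eq!(doubled.shape(), [("y", 3), ("x", 2)]);
     * // x = 0, y = 2 was 3 in the original, so it is 6 after doubling
     * assert_eq!(doubled.index_by(["x", "y"]).get([0, 2]), 6);
     * ```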
325     */
326    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
327        let mapped = self.iter().map(mapping_function).collect();
328        Tensor::from(self.shape(), mapped)
329    }
330
331    /**
332     * Creates and returns a new tensor with all values from the original and
333     * the index of each value mapped by a function. The indexes passed to the mapping
334     * function always increment the rightmost index, starting at all 0s, using the dimension
335     * order that the TensorAccess is indexed by, not necessarily the index order the
336     * original source uses.
337     *
338     * Note: mapping methods are defined on [Tensor] and
339     * [TensorView](crate::tensors::views::TensorView) directly so you don't need to create a
340     * TensorAccess unless you want to do the mapping with a different dimension order.
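     *
     * A small sketch pairing each value with its index in the access order:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("a", 2), ("b", 2)], vec![ 1, 2, 3, 4 ]);
     * let indexed = tensor.index_by(["b", "a"]).map_with_index(|[b, a], x| (b, a, x));
     * // b = 1, a = 0 is the value 2 in the original data
     * assert_eq!(indexed.index_by(["b", "a"]).get([1, 0]), (1, 0, 2));
     * ```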
341     */
342    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
343        let mapped = self
344            .iter()
345            .with_index()
346            .map(|(i, x)| mapping_function(i, x))
347            .collect();
348        Tensor::from(self.shape(), mapped)
349    }
350
351    /**
352     * Returns an iterator over copies of the data in this TensorAccess, in the order of
353     * the TensorAccess shape.
354     */
355    pub fn iter(&self) -> TensorIterator<T, TensorAccess<T, S, D>, D> {
356        TensorIterator::from(self)
357    }
358}
359
360impl<T, S, const D: usize> TensorAccess<T, S, D>
361where
362    S: TensorMut<T, D>,
363{
364    /**
365     * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
366     * the index if the index is in range. Otherwise returns None.
367     */
368    pub fn try_get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
369        self.source
370            .get_reference_mut(self.dimension_mapping.map_dimensions_to_source(&indexes))
371    }
372
373    /**
374     * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
375     * the index if the index is in range, panicking if the index is out of range.
376     */
377    // NOTE: Ideally `get_reference_mut` would be used here for consistency, but that opens the
378    // minefield of TensorMut::get_reference_mut and TensorAccess::get_ref_mut being different
379    // signatures but the same name.
380    #[track_caller]
381    pub fn get_ref_mut(&mut self, indexes: [usize; D]) -> &mut T {
382        match self.try_get_reference_mut(indexes) {
383            Some(reference) => reference,
384            // can't provide a better error because the borrow checker insists that returning
385            // a reference in the Some branch means our mutable borrow prevents us calling
386            // self.shape() and a bad error is better than cloning self.shape() on every call
387            None => panic!("Unable to index with {:?}", indexes),
388        }
389    }
390
391    /**
392     * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
393     * the index without any bounds checking.
394     *
395     * # Safety
396     *
397     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
398     * resulting reference is not used. Valid indexes are defined as in [TensorRef]. Note that
399     * the order of the indexes needed here must match with
400     * [`TensorAccess::shape`](TensorAccess::shape) which may not necessarily be the same
401     * as the `view_shape` of the `TensorRef` implementation this TensorAccess was created from.
402     *
403     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
404     * [TensorRef]: TensorRef
405     */
406    // NOTE: This aliases with TensorRef::get_reference_unchecked_mut but the TensorMut impl
407    // just calls this and the signatures match anyway, so there are no potential issues.
408    #[allow(clippy::missing_safety_doc)] // it's not missing
409    pub unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
410        unsafe {
411            self.source.get_reference_unchecked_mut(
412                self.dimension_mapping.map_dimensions_to_source(&indexes),
413            )
414        }
415    }
416
417    /**
418     * Returns an iterator over mutable references to the data in this TensorAccess, in the order
419     * of the TensorAccess shape.
420     */
421    pub fn iter_reference_mut(
422        &mut self,
423    ) -> TensorReferenceMutIterator<T, TensorAccess<T, S, D>, D> {
424        TensorReferenceMutIterator::from(self)
425    }
426}
427
428impl<T, S, const D: usize> TensorAccess<T, S, D>
429where
430    S: TensorMut<T, D>,
431    T: Clone,
432{
433    /**
434     * Applies a function to all values in the tensor, modifying
435     * the tensor in place.
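     *
     * A minimal sketch, assuming a mutable reference to a [Tensor] is used as the source:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::indexing::TensorAccess;
     * let mut tensor = Tensor::from([("a", 2), ("b", 2)], vec![ 1, 2, 3, 4 ]);
     * TensorAccess::from(&mut tensor, ["b", "a"]).map_mut(|x| x + 10);
     * assert_eq!(tensor, Tensor::from([("a", 2), ("b", 2)], vec![ 11, 12, 13, 14 ]));
     * ```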
436     */
437    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
438        self.iter_reference_mut()
439            .for_each(|x| *x = mapping_function(x.clone()));
440    }
441
442    /**
443     * Applies a function to all values and each value's index in the tensor, modifying
444     * the tensor in place. The indexes passed to the mapping function always increment
445     * the rightmost index, starting at all 0s, using the dimension order that the
446     * TensorAccess is indexed by, not necessarily the index order the original source uses.
447     */
448    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
449        self.iter_reference_mut()
450            .with_index()
451            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
452    }
453}
454
455impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D>
456where
457    T: Numeric + Primitive,
458    S: TensorRef<(T, Index), D>,
459{
460    /**
461     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
462     * as a Record if the index is in range, panicking if the index is out of range.
463     *
464     * If you need to access all the data as records instead of just a specific index you should
465     * probably use one of the iterator APIs instead.
466     *
467     * See also: [iter_as_records](RecordTensor::iter_as_records)
468     *
469     * # Panics
470     *
471     * If the index is out of range.
472     *
473     * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
474     *
475     * ```
476     * use easy_ml::differentiation::RecordTensor;
477     * use easy_ml::differentiation::WengertList;
478     * use easy_ml::tensors::Tensor;
479     *
480     * let list = WengertList::new();
481     * let X = RecordTensor::variables(
482     *     &list,
483     *     Tensor::from(
484     *         [("r", 2), ("c", 3)],
485     *         vec![
486     *             3.0, 4.0, 5.0,
487     *             1.0, 4.0, 9.0,
488     *         ]
489     *     )
490     * );
491     * let x = X.index_by(["c", "r"]).get_as_record([2, 0]);
492     * assert_eq!(x.number, 5.0);
493     * ```
494     */
495    #[track_caller]
496    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
497        Record::from_existing(self.get(indexes), self.source.history())
498    }
499
500    /**
501     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
502     * as a Record if the index is in range. Otherwise returns None.
503     *
504     * If you need to access all the data as records instead of just a specific index you should
505     * probably use one of the iterator APIs instead.
506     *
507     * See also: [iter_as_records](RecordTensor::iter_as_records)
508     */
509    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
510        self.try_get_reference(indexes)
511            .map(|r| Record::from_existing(r.clone(), self.source.history()))
512    }
513}
514
515impl<'a, T, S, const D: usize> TensorAccess<(T, Index), RecordTensor<'a, T, S, D>, D>
516where
517    T: Numeric + Primitive,
518    S: TensorRef<(T, Index), D>,
519{
520    /**
521     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
522     * as a Record if the index is in range, panicking if the index is out of range.
523     *
524     * If you need to access all the data as records instead of just a specific index you should
525     * probably use one of the iterator APIs instead.
526     *
527     * See also: [iter_as_records](RecordTensor::iter_as_records)
528     *
529     * # Panics
530     *
531     * If the index is out of range.
532     *
533     * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
534     */
535    #[track_caller]
536    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
537        Record::from_existing(self.get(indexes), self.source.history())
538    }
539
540    /**
541     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
542     * as a Record if the index is in range. Otherwise returns None.
543     *
544     * If you need to access all the data as records instead of just a specific index you should
545     * probably use one of the iterator APIs instead.
546     *
547     * See also: [iter_as_records](RecordTensor::iter_as_records)
548     */
549    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
550        self.try_get_reference(indexes)
551            .map(|r| Record::from_existing(r.clone(), self.source.history()))
552    }
553}
554
555impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &mut RecordTensor<'a, T, S, D>, D>
556where
557    T: Numeric + Primitive,
558    S: TensorRef<(T, Index), D>,
559{
560    /**
561     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
562     * as a Record if the index is in range, panicking if the index is out of range.
563     *
564     * If you need to access all the data as records instead of just a specific index you should
565     * probably use one of the iterator APIs instead.
566     *
567     * See also: [iter_as_records](RecordTensor::iter_as_records)
568     *
569     * # Panics
570     *
571     * If the index is out of range.
572     *
573     * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
574     */
575    #[track_caller]
576    pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
577        Record::from_existing(self.get(indexes), self.source.history())
578    }
579
580    /**
581     * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
582     * as a Record if the index is in range. Otherwise returns None.
583     *
584     * If you need to access all the data as records instead of just a specific index you should
585     * probably use one of the iterator APIs instead.
586     *
587     * See also: [iter_as_records](RecordTensor::iter_as_records)
588     */
589    pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
590        self.try_get_reference(indexes)
591            .map(|r| Record::from_existing(r.clone(), self.source.history()))
592    }
593}
594
595// # Safety
596//
597// The type implementing TensorRef inside the TensorAccess must implement it correctly, so by
598// delegating to it without changing anything other than the order we index it, we implement
599// TensorRef correctly as well.
600/**
601 * A TensorAccess implements TensorRef, with the dimension order and indexing matching that of the
602 * TensorAccess shape.
603 */
604unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorAccess<T, S, D>
605where
606    S: TensorRef<T, D>,
607{
608    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
609        self.try_get_reference(indexes)
610    }
611
612    fn view_shape(&self) -> [(Dimension, usize); D] {
613        self.shape()
614    }
615
616    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
617        unsafe { self.get_reference_unchecked(indexes) }
618    }
619
620    fn data_layout(&self) -> DataLayout<D> {
621        match self.source.data_layout() {
622            // We might have reordered the view_shape but we didn't rearrange the memory or change
623            // what each dimension name refers to in memory, so the data layout remains as is.
624            DataLayout::Linear(order) => DataLayout::Linear(order),
625            DataLayout::NonLinear => DataLayout::NonLinear,
626            DataLayout::Other => DataLayout::Other,
627        }
628    }
629}
630
631// # Safety
632//
633// The type implementing TensorMut inside the TensorAccess must implement it correctly, so by
634// delegating to it without changing anything other than the order we index it, we implement
635// TensorMut correctly as well.
636/**
637 * A TensorAccess implements TensorMut, with the dimension order and indexing matching that of the
638 * TensorAccess shape.
639 */
640unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorAccess<T, S, D>
641where
642    S: TensorMut<T, D>,
643{
644    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
645        self.try_get_reference_mut(indexes)
646    }
647
648    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
649        unsafe { self.get_reference_unchecked_mut(indexes) }
650    }
651}
652
653/**
654 * Any tensor access of a Displayable type implements Display
655 *
656 * You can control the precision of the formatting using format arguments, i.e.
657 * `format!("{:.3}", tensor)`
658 */
659impl<T, S, const D: usize> std::fmt::Display for TensorAccess<T, S, D>
660where
661    T: std::fmt::Display,
662    S: TensorRef<T, D>,
663{
664    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
665        crate::tensors::display::format_view(&self, f)?;
666        writeln!(f)?;
667        write!(f, "Data Layout = {:?}", self.data_layout())
668    }
669}
670
671/**
672 * An iterator over all indexes in a shape.
673 *
674 * First the all 0 index is iterated, then each iteration increments the rightmost index.
675 * For a shape of `[("a", 2), ("b", 2), ("c", 2)]` this will yield indexes in order of: `[0,0,0]`,
676 * `[0,0,1]`, `[0,1,0]`, `[0,1,1]`, `[1,0,0]`, `[1,0,1]`, `[1,1,0]`, `[1,1,1]`.
677 *
678 * You don't typically need to use this directly, as tensors have iterators that iterate over
679 * them and return values to you (using this under the hood), but `ShapeIterator` can be useful
680 * if you need to hold a mutable reference to a tensor while iterating as `ShapeIterator` does
681 * not borrow the tensor. NB: if you use `ShapeIterator` directly to index into a tensor you're
682 * mutably borrowing, take care not to accidentally reshape the tensor and then keep using
683 * indexes from the `ShapeIterator`, as they would no longer be valid.
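 *
 * A short sketch of the iteration order described above:
 *
 * ```
 * use easy_ml::tensors::indexing::ShapeIterator;
 * let indexes: Vec<[usize; 2]> = ShapeIterator::from([("a", 2), ("b", 2)]).collect();
 * assert_eq!(indexes, vec![[0, 0], [0, 1], [1, 0], [1, 1]]);
 * ```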
684 */
685#[derive(Clone, Debug)]
686pub struct ShapeIterator<const D: usize> {
687    shape: [(Dimension, usize); D],
688    indexes: [usize; D],
689    finished: bool,
690}
691
692impl<const D: usize> ShapeIterator<D> {
693    /**
694     * Constructs a ShapeIterator for a shape.
695     *
696     * If the shape has any dimensions with a length of zero, the iterator will immediately
697     * return None on [`next()`](Iterator::next).
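     *
     * A tiny sketch of the zero length case:
     *
     * ```
     * use easy_ml::tensors::indexing::ShapeIterator;
     * assert_eq!(ShapeIterator::from([("a", 0), ("b", 3)]).count(), 0);
     * ```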
698     */
699    pub fn from(shape: [(Dimension, usize); D]) -> ShapeIterator<D> {
700        // If we're given an invalid shape (shape input is not necessarily going to meet the no
701        // 0 lengths contract of TensorRef because that's not actually required here), return
702        // a finished iterator
703        // Since this is an iterator over an owned shape, it's not going to become invalid later
704        // when we start iterating so this is the only check we need.
705        let starting_index_valid = shape.iter().all(|(_, l)| *l > 0);
706        ShapeIterator {
707            shape,
708            indexes: [0; D],
709            finished: !starting_index_valid,
710        }
711    }
712}
713
714impl<const D: usize> Iterator for ShapeIterator<D> {
715    type Item = [usize; D];
716
717    fn next(&mut self) -> Option<Self::Item> {
718        iter(&mut self.finished, &mut self.indexes, &self.shape)
719    }
720
721    fn size_hint(&self) -> (usize, Option<usize>) {
722        size_hint(self.finished, &self.indexes, &self.shape)
723    }
724}
725
726// Once we hit the end we mark ourselves as finished so we're always Fused.
727impl<const D: usize> FusedIterator for ShapeIterator<D> {}
728// We can always calculate the exact number of steps remaining because the shape and indexes are
729// private fields that are only mutated by `next` to count up.
730impl<const D: usize> ExactSizeIterator for ShapeIterator<D> {}
731
732/// Common index order iterator logic
733fn iter<const D: usize>(
734    finished: &mut bool,
735    indexes: &mut [usize; D],
736    shape: &[(Dimension, usize); D],
737) -> Option<[usize; D]> {
738    if *finished {
739        return None;
740    }
741
742    let value = Some(*indexes);
743
744    if D > 0 {
745        // Increment index of final dimension. In the 2D case, we iterate through a row by
746        // incrementing through every column index.
747        indexes[D - 1] += 1;
748        for d in (1..D).rev() {
749            if indexes[d] == shape[d].1 {
750                // ran to end of this dimension with our index
751                // In the 2D case, we finished indexing through every column in the row,
752                // and it's now time to move onto the next row.
753                indexes[d] = 0;
754                indexes[d - 1] += 1;
755            }
756        }
757        // Check if we ran past the final index
758        if indexes[0] == shape[0].1 {
759            *finished = true;
760        }
761    } else {
762        *finished = true;
763    }
764
765    value
766}
767
768/// Common size hint logic
769fn size_hint<const D: usize>(
770    finished: bool,
771    indexes: &[usize; D],
772    shape: &[(Dimension, usize); D],
773) -> (usize, Option<usize>) {
774    if finished {
775        return (0, Some(0));
776    }
777
778    let remaining = if D > 0 {
779        let total = dimensions::elements(shape);
780        let strides = crate::tensors::compute_strides(shape);
781        let seen = crate::tensors::get_index_direct_unchecked(indexes, &strides);
782        total - seen
783    } else {
784        // If D == 0 and we're not finished we've not returned the sole index yet so there's
785        // exactly 1 left
786        1
787    };
788
789    (remaining, Some(remaining))
790}
791
792/**
793 * An iterator over copies of all values in a tensor.
794 *
795 * First the all 0 index is iterated, then each iteration increments the rightmost index.
796 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
797 * this will take a single step in memory on each iteration, akin to iterating through the
798 * flattened data of the tensor.
799 *
800 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
801 * will still iterate the rightmost index allowing iteration through dimensions in a different
802 * order to how they are stored, but no longer taking a single step in memory on each
803 * iteration (which may be less cache friendly for the CPU).
804 *
805 * ```
806 * use easy_ml::tensors::Tensor;
807 * let tensor_0 = Tensor::from_scalar(1);
808 * let tensor_1 = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
809 * let tensor_2 = Tensor::from([("a", 2), ("b", 3)], vec![
810 *    // two rows, three columns
811 *    1, 2, 3,
812 *    4, 5, 6
813 * ]);
814 * let tensor_3 = Tensor::from([("a", 2), ("b", 1), ("c", 2)], vec![
815 *     // two rows each a single column, stacked on top of each other
816 *     1,
817 *     2,
818 *
819 *     3,
820 *     4
821 * ]);
822 * let tensor_access_0 = tensor_0.index_by([]);
823 * let tensor_access_1 = tensor_1.index_by(["a"]);
824 * let tensor_access_2 = tensor_2.index_by(["a", "b"]);
825 * let tensor_access_2_rev = tensor_2.index_by(["b", "a"]);
826 * let tensor_access_3 = tensor_3.index_by(["a", "b", "c"]);
827 * let tensor_access_3_rev = tensor_3.index_by(["c", "b", "a"]);
828 * assert_eq!(
829 *     tensor_0.iter().collect::<Vec<i32>>(),
830 *     vec![1]
831 * );
832 * assert_eq!(
833 *     tensor_access_0.iter().collect::<Vec<i32>>(),
834 *     vec![1]
835 * );
836 * assert_eq!(
837 *     tensor_1.iter().collect::<Vec<i32>>(),
838 *     vec![1, 2, 3, 4, 5, 6, 7]
839 * );
840 * assert_eq!(
841 *     tensor_access_1.iter().collect::<Vec<i32>>(),
842 *     vec![1, 2, 3, 4, 5, 6, 7]
843 * );
844 * assert_eq!(
845 *     tensor_2.iter().collect::<Vec<i32>>(),
846 *     vec![1, 2, 3, 4, 5, 6]
847 * );
848 * assert_eq!(
849 *     tensor_access_2.iter().collect::<Vec<i32>>(),
850 *     vec![1, 2, 3, 4, 5, 6]
851 * );
852 * assert_eq!(
853 *     tensor_access_2_rev.iter().collect::<Vec<i32>>(),
854 *     vec![1, 4, 2, 5, 3, 6]
855 * );
856 * assert_eq!(
857 *     tensor_3.iter().collect::<Vec<i32>>(),
858 *     vec![1, 2, 3, 4]
859 * );
860 * assert_eq!(
861 *     tensor_access_3.iter().collect::<Vec<i32>>(),
862 *     vec![1, 2, 3, 4]
863 * );
864 * assert_eq!(
865 *     tensor_access_3_rev.iter().collect::<Vec<i32>>(),
866 *     vec![1, 3, 2, 4]
867 * );
868 * ```
869 */
870#[derive(Debug)]
871pub struct TensorIterator<'a, T, S, const D: usize> {
872    shape_iterator: ShapeIterator<D>,
873    source: &'a S,
874    _type: PhantomData<T>,
875}
876
877impl<'a, T, S, const D: usize> TensorIterator<'a, T, S, D>
878where
879    T: Clone,
880    S: TensorRef<T, D>,
881{
882    pub fn from(source: &S) -> TensorIterator<T, S, D> {
883        TensorIterator {
884            shape_iterator: ShapeIterator::from(source.view_shape()),
885            source,
886            _type: PhantomData,
887        }
888    }
889
890    /**
891     * Constructs an iterator which also yields the indexes of each element in
892     * this iterator.
893     */
894    pub fn with_index(self) -> WithIndex<Self> {
895        WithIndex { iterator: self }
896    }
897}
898
899impl<'a, T, S, const D: usize> From<TensorIterator<'a, T, S, D>>
900    for WithIndex<TensorIterator<'a, T, S, D>>
901where
902    T: Clone,
903    S: TensorRef<T, D>,
904{
905    fn from(iterator: TensorIterator<'a, T, S, D>) -> Self {
906        iterator.with_index()
907    }
908}
909
910impl<'a, T, S, const D: usize> Iterator for TensorIterator<'a, T, S, D>
911where
912    T: Clone,
913    S: TensorRef<T, D>,
914{
915    type Item = T;
916
917    fn next(&mut self) -> Option<Self::Item> {
918        // Safety: ShapeIterator only iterates over the correct indexes into our tensor's shape as
919        // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
920        // immutable reference to our tensor source, it can't be resized which ensures
921        // ShapeIterator can always yield valid indexes for our iteration.
922        self.shape_iterator
923            .next()
924            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
925    }
926
927    fn size_hint(&self) -> (usize, Option<usize>) {
928        self.shape_iterator.size_hint()
929    }
930}
931
932impl<'a, T, S, const D: usize> FusedIterator for TensorIterator<'a, T, S, D>
933where
934    T: Clone,
935    S: TensorRef<T, D>,
936{
937}
938
939impl<'a, T, S, const D: usize> ExactSizeIterator for TensorIterator<'a, T, S, D>
940where
941    T: Clone,
942    S: TensorRef<T, D>,
943{
944}
945
946impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorIterator<'a, T, S, D>>
947where
948    T: Clone,
949    S: TensorRef<T, D>,
950{
951    type Item = ([usize; D], T);
952
953    fn next(&mut self) -> Option<Self::Item> {
954        let index = self.iterator.shape_iterator.indexes;
955        self.iterator.next().map(|x| (index, x))
956    }
957
958    fn size_hint(&self) -> (usize, Option<usize>) {
959        self.iterator.size_hint()
960    }
961}
962
963impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorIterator<'a, T, S, D>>
964where
965    T: Clone,
966    S: TensorRef<T, D>,
967{
968}
969
970impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorIterator<'a, T, S, D>>
971where
972    T: Clone,
973    S: TensorRef<T, D>,
974{
975}
976
977/**
978 * An iterator over references to all values in a tensor.
979 *
980 * First the all 0 index is iterated, then each iteration increments the rightmost index.
981 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
982 * this will take a single step in memory on each iteration, akin to iterating through the
983 * flattened data of the tensor.
984 *
985 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
986 * will still iterate the rightmost index allowing iteration through dimensions in a different
987 * order to how they are stored, but no longer taking a single step in memory on each
988 * iteration (which may be less cache friendly for the CPU).
989 *
990 * ```
991 * use easy_ml::tensors::Tensor;
992 * let tensor_0 = Tensor::from_scalar(1);
993 * let tensor_1 = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
994 * let tensor_2 = Tensor::from([("a", 2), ("b", 3)], vec![
995 *    // two rows, three columns
996 *    1, 2, 3,
997 *    4, 5, 6
998 * ]);
999 * let tensor_3 = Tensor::from([("a", 2), ("b", 1), ("c", 2)], vec![
1000 *     // two rows each a single column, stacked on top of each other
1001 *     1,
1002 *     2,
1003 *
1004 *     3,
1005 *     4
1006 * ]);
1007 * let tensor_access_0 = tensor_0.index_by([]);
1008 * let tensor_access_1 = tensor_1.index_by(["a"]);
1009 * let tensor_access_2 = tensor_2.index_by(["a", "b"]);
1010 * let tensor_access_2_rev = tensor_2.index_by(["b", "a"]);
1011 * let tensor_access_3 = tensor_3.index_by(["a", "b", "c"]);
1012 * let tensor_access_3_rev = tensor_3.index_by(["c", "b", "a"]);
1013 * assert_eq!(
1014 *     tensor_0.iter_reference().cloned().collect::<Vec<i32>>(),
1015 *     vec![1]
1016 * );
1017 * assert_eq!(
1018 *     tensor_access_0.iter_reference().cloned().collect::<Vec<i32>>(),
1019 *     vec![1]
1020 * );
1021 * assert_eq!(
1022 *     tensor_1.iter_reference().cloned().collect::<Vec<i32>>(),
1023 *     vec![1, 2, 3, 4, 5, 6, 7]
1024 * );
1025 * assert_eq!(
1026 *     tensor_access_1.iter_reference().cloned().collect::<Vec<i32>>(),
1027 *     vec![1, 2, 3, 4, 5, 6, 7]
1028 * );
1029 * assert_eq!(
1030 *     tensor_2.iter_reference().cloned().collect::<Vec<i32>>(),
1031 *     vec![1, 2, 3, 4, 5, 6]
1032 * );
1033 * assert_eq!(
1034 *     tensor_access_2.iter_reference().cloned().collect::<Vec<i32>>(),
1035 *     vec![1, 2, 3, 4, 5, 6]
1036 * );
1037 * assert_eq!(
1038 *     tensor_access_2_rev.iter_reference().cloned().collect::<Vec<i32>>(),
1039 *     vec![1, 4, 2, 5, 3, 6]
1040 * );
1041 * assert_eq!(
1042 *     tensor_3.iter_reference().cloned().collect::<Vec<i32>>(),
1043 *     vec![1, 2, 3, 4]
1044 * );
1045 * assert_eq!(
1046 *     tensor_access_3.iter_reference().cloned().collect::<Vec<i32>>(),
1047 *     vec![1, 2, 3, 4]
1048 * );
1049 * assert_eq!(
1050 *     tensor_access_3_rev.iter_reference().cloned().collect::<Vec<i32>>(),
1051 *     vec![1, 3, 2, 4]
1052 * );
1053 * ```
1054 */
1055#[derive(Debug)]
1056pub struct TensorReferenceIterator<'a, T, S, const D: usize> {
1057    shape_iterator: ShapeIterator<D>,
1058    source: &'a S,
1059    _type: PhantomData<&'a T>,
1060}
1061
1062impl<'a, T, S, const D: usize> TensorReferenceIterator<'a, T, S, D>
1063where
1064    S: TensorRef<T, D>,
1065{
1066    pub fn from(source: &S) -> TensorReferenceIterator<T, S, D> {
1067        TensorReferenceIterator {
1068            shape_iterator: ShapeIterator::from(source.view_shape()),
1069            source,
1070            _type: PhantomData,
1071        }
1072    }
1073
1074    /**
1075     * Constructs an iterator which also yields the indexes of each element in
1076     * this iterator.
1077     */
1078    pub fn with_index(self) -> WithIndex<Self> {
1079        WithIndex { iterator: self }
1080    }
1081}
1082
1083impl<'a, T, S, const D: usize> From<TensorReferenceIterator<'a, T, S, D>>
1084    for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1085where
1086    S: TensorRef<T, D>,
1087{
1088    fn from(iterator: TensorReferenceIterator<'a, T, S, D>) -> Self {
1089        iterator.with_index()
1090    }
1091}
1092
1093impl<'a, T, S, const D: usize> Iterator for TensorReferenceIterator<'a, T, S, D>
1094where
1095    S: TensorRef<T, D>,
1096{
1097    type Item = &'a T;
1098
1099    fn next(&mut self) -> Option<Self::Item> {
1100        // Safety: ShapeIterator only iterates over the correct indexes into our tensor's shape as
1101        // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1102        // immutable reference to our tensor source, it can't be resized which ensures
1103        // ShapeIterator can always yield valid indexes for our iteration.
1104        self.shape_iterator
1105            .next()
1106            .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
1107    }
1108
1109    fn size_hint(&self) -> (usize, Option<usize>) {
1110        self.shape_iterator.size_hint()
1111    }
1112}
1113
1114impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceIterator<'a, T, S, D> where
1115    S: TensorRef<T, D>
1116{
1117}
1118
1119impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceIterator<'a, T, S, D> where
1120    S: TensorRef<T, D>
1121{
1122}
1123
1124impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1125where
1126    S: TensorRef<T, D>,
1127{
1128    type Item = ([usize; D], &'a T);
1129
1130    fn next(&mut self) -> Option<Self::Item> {
1131        let index = self.iterator.shape_iterator.indexes;
1132        self.iterator.next().map(|x| (index, x))
1133    }
1134
1135    fn size_hint(&self) -> (usize, Option<usize>) {
1136        self.iterator.size_hint()
1137    }
1138}
1139
1140impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
1141    S: TensorRef<T, D>
1142{
1143}
1144
1145impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
1146    S: TensorRef<T, D>
1147{
1148}
1149
1150/**
1151 * An iterator over mutable references to all values in a tensor.
1152 *
1153 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1154 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1155 * this will take a single step in memory on each iteration, akin to iterating through the
1156 * flattened data of the tensor.
1157 *
1158 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1159 * will still iterate the rightmost index allowing iteration through dimensions in a different
1160 * order to how they are stored, but no longer taking a single step in memory on each
1161 * iteration (which may be less cache friendly for the CPU).
1162 *
1163 * ```
1164 * use easy_ml::tensors::Tensor;
1165 * let mut tensor = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
1166 * let doubled = tensor.map(|x| 2 * x);
1167 * // mutating a tensor in place can also be done with Tensor::map_mut and
1168 * // Tensor::map_mut_with_index
1169 * for elem in tensor.iter_reference_mut() {
1170 *    *elem = 2 * *elem;
1171 * }
1172 * assert_eq!(
1173 *     tensor,
1174 *     doubled,
1175 * );
1176 * ```
1177 */
1178#[derive(Debug)]
1179pub struct TensorReferenceMutIterator<'a, T, S, const D: usize> {
1180    shape_iterator: ShapeIterator<D>,
1181    source: &'a mut S,
1182    _type: PhantomData<&'a mut T>,
1183}
1184
1185impl<'a, T, S, const D: usize> TensorReferenceMutIterator<'a, T, S, D>
1186where
1187    S: TensorMut<T, D>,
1188{
1189    pub fn from(source: &mut S) -> TensorReferenceMutIterator<T, S, D> {
1190        TensorReferenceMutIterator {
1191            shape_iterator: ShapeIterator::from(source.view_shape()),
1192            source,
1193            _type: PhantomData,
1194        }
1195    }
1196
1197    /**
1198     * Constructs an iterator which also yields the indexes of each element in
1199     * this iterator.
1200     */
1201    pub fn with_index(self) -> WithIndex<Self> {
1202        WithIndex { iterator: self }
1203    }
1204}
1205
1206impl<'a, T, S, const D: usize> From<TensorReferenceMutIterator<'a, T, S, D>>
1207    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1208where
1209    S: TensorMut<T, D>,
1210{
1211    fn from(iterator: TensorReferenceMutIterator<'a, T, S, D>) -> Self {
1212        iterator.with_index()
1213    }
1214}
1215
1216impl<'a, T, S, const D: usize> Iterator for TensorReferenceMutIterator<'a, T, S, D>
1217where
1218    S: TensorMut<T, D>,
1219{
1220    type Item = &'a mut T;
1221
1222    fn next(&mut self) -> Option<Self::Item> {
1223        self.shape_iterator.next().map(|indexes| {
1224            unsafe {
1225                // Safety: We are not allowed to give out overlapping mutable references,
1226                // but since we will always increment the counter on every call to next()
1227                // and stop when we reach the end no references will overlap.
1228                // The compiler doesn't know this, so transmute the lifetime for it.
1229                // Safety: ShapeIterator only iterates over the correct indexes into our
1230                // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1231                // mutability and we hold an exclusive reference to our tensor source, it can't
1232                // be resized (except by us - and we don't) which ensures ShapeIterator can always
1233                // yield valid indexes for our iteration.
1234                std::mem::transmute(self.source.get_reference_unchecked_mut(indexes))
1235            }
1236        })
1237    }
1238
1239    fn size_hint(&self) -> (usize, Option<usize>) {
1240        self.shape_iterator.size_hint()
1241    }
1242}
1243
1244impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceMutIterator<'a, T, S, D> where
1245    S: TensorMut<T, D>
1246{
1247}
1248
1249impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceMutIterator<'a, T, S, D> where
1250    S: TensorMut<T, D>
1251{
1252}
1253
1254impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1255where
1256    S: TensorMut<T, D>,
1257{
1258    type Item = ([usize; D], &'a mut T);
1259
1260    fn next(&mut self) -> Option<Self::Item> {
1261        let index = self.iterator.shape_iterator.indexes;
1262        self.iterator.next().map(|x| (index, x))
1263    }
1264
1265    fn size_hint(&self) -> (usize, Option<usize>) {
1266        self.iterator.size_hint()
1267    }
1268}
1269
1270impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>> where
1271    S: TensorMut<T, D>
1272{
1273}
1274
1275impl<'a, T, S, const D: usize> ExactSizeIterator
1276    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1277where
1278    S: TensorMut<T, D>,
1279{
1280}
1281
1282/**
1283 * An iterator over all values in an owned tensor.
1284 *
1285 * This iterator does not clone the values, it returns the actual values stored in the tensor.
1286 * There is no method to return `T` by value from a [TensorRef]/[TensorMut], so to do
1287 * this it [replaces](std::mem::replace) the values with dummy values. Hence it can only be
1288 * created for types that implement [Default] or [ZeroOne](crate::numeric::ZeroOne)
1289 * from [Numeric](crate::numeric) which provide a means to create dummy values.
1290 *
1291 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1292 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1293 * this will take a single step in memory on each iteration, akin to iterating through the
1294 * flattened data of the tensor.
1295 *
1296 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1297 * will still iterate the rightmost index allowing iteration through dimensions in a different
1298 * order to how they are stored, but no longer taking a single step in memory on each
1299 * iteration (which may be less cache friendly for the CPU).
1300 *
1301 * ```
1302 * use easy_ml::tensors::Tensor;
1303 *
1304 * #[derive(Debug, Default, Eq, PartialEq)]
1305 * struct NoClone(i32);
1306 *
1307 * let tensor = Tensor::from([("a", 3)], vec![ NoClone(1), NoClone(2), NoClone(3) ]);
1308 * let values = tensor.iter_owned(); // will use T::default() for dummy values
1309 * assert_eq!(vec![ NoClone(1), NoClone(2), NoClone(3) ], values.collect::<Vec<NoClone>>());
1310 * ```
1311 */
1312#[derive(Debug)]
1313pub struct TensorOwnedIterator<T, S, const D: usize> {
1314    shape_iterator: ShapeIterator<D>,
1315    source: S,
1316    producer: fn() -> T,
1317}
1318
1319impl<T, S, const D: usize> TensorOwnedIterator<T, S, D>
1320where
1321    S: TensorMut<T, D>,
1322{
1323    /**
1324     * Creates the TensorOwnedIterator from a source where the default values will be provided
1325     * by [Default::default]. This constructor is also used by the convenience
1326     * methods on [Tensor::iter_owned](Tensor::iter_owned) and
1327     * [TensorView::iter_owned](crate::tensors::views::TensorView::iter_owned).
1328     */
1329    pub fn from(source: S) -> TensorOwnedIterator<T, S, D>
1330    where
1331        T: Default,
1332    {
1333        TensorOwnedIterator {
1334            shape_iterator: ShapeIterator::from(source.view_shape()),
1335            source,
1336            producer: || T::default(),
1337        }
1338    }
1339
1340    /**
1341     * Creates the TensorOwnedIterator from a source where the default values will be provided
1342     * by [ZeroOne::zero](crate::numeric::ZeroOne::zero).
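     *
     * A small sketch using a numeric type, where the dummy values left behind come from zero:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::indexing::TensorOwnedIterator;
     * let tensor = Tensor::from([("a", 3)], vec![ 1.0, 2.0, 3.0 ]);
     * let values: Vec<f64> = TensorOwnedIterator::from_numeric(tensor).collect();
     * assert_eq!(values, vec![1.0, 2.0, 3.0]);
     * ```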
1343     */
1344    pub fn from_numeric(source: S) -> TensorOwnedIterator<T, S, D>
1345    where
1346        T: crate::numeric::ZeroOne,
1347    {
1348        TensorOwnedIterator {
1349            shape_iterator: ShapeIterator::from(source.view_shape()),
1350            source,
1351            producer: || T::zero(),
1352        }
1353    }
1354
1355    /**
1356     * Constructs an iterator which also yields the indexes of each element in
1357     * this iterator.
1358     */
1359    pub fn with_index(self) -> WithIndex<Self> {
1360        WithIndex { iterator: self }
1361    }
1362}
1363
1364impl<T, S, const D: usize> From<TensorOwnedIterator<T, S, D>>
1365    for WithIndex<TensorOwnedIterator<T, S, D>>
1366where
1367    S: TensorMut<T, D>,
1368{
1369    fn from(iterator: TensorOwnedIterator<T, S, D>) -> Self {
1370        iterator.with_index()
1371    }
1372}
1373
1374impl<T, S, const D: usize> Iterator for TensorOwnedIterator<T, S, D>
1375where
1376    S: TensorMut<T, D>,
1377{
1378    type Item = T;
1379
1380    fn next(&mut self) -> Option<Self::Item> {
1381        self.shape_iterator.next().map(|indexes| {
1382            let producer = self.producer;
1383            let dummy = producer();
1384            // Safety: ShapeIterator only iterates over the correct indexes into our
1385            // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1386            // mutability and we hold our tensor source by value, it can't be resized (except by
1387            // us - and we don't) which ensures ShapeIterator can always yield valid indexes for
1388            // our iteration.
1389            std::mem::replace(
1390                unsafe { self.source.get_reference_unchecked_mut(indexes) },
1391                dummy,
1392            )
1393        })
1394    }
1395
1396    fn size_hint(&self) -> (usize, Option<usize>) {
1397        self.shape_iterator.size_hint()
1398    }
1399}
1400
1401impl<T, S, const D: usize> FusedIterator for TensorOwnedIterator<T, S, D> where S: TensorMut<T, D> {}
1402
1403impl<T, S, const D: usize> ExactSizeIterator for TensorOwnedIterator<T, S, D> where
1404    S: TensorMut<T, D>
1405{
1406}
1407
1408impl<T, S, const D: usize> Iterator for WithIndex<TensorOwnedIterator<T, S, D>>
1409where
1410    S: TensorMut<T, D>,
1411{
1412    type Item = ([usize; D], T);
1413
1414    fn next(&mut self) -> Option<Self::Item> {
1415        let index = self.iterator.shape_iterator.indexes;
1416        self.iterator.next().map(|x| (index, x))
1417    }
1418
1419    fn size_hint(&self) -> (usize, Option<usize>) {
1420        self.iterator.size_hint()
1421    }
1422}
1423
1424impl<T, S, const D: usize> FusedIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
1425    S: TensorMut<T, D>
1426{
1427}
1428
1429impl<T, S, const D: usize> ExactSizeIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
1430    S: TensorMut<T, D>
1431{
1432}
1433
1434/**
1435 * A TensorTranspose makes the data in the tensor it is created from appear to be in a different
1436 * order, swapping the lengths of each named dimension to match the new order but leaving the
1437 * dimension name order unchanged.
1438 *
1439 * ```
1440 * use easy_ml::tensors::Tensor;
1441 * use easy_ml::tensors::indexing::TensorTranspose;
1442 * use easy_ml::tensors::views::TensorView;
1443 * let tensor = Tensor::from([("batch", 2), ("rows", 3), ("columns", 2)], vec![
1444 *     1, 2,
1445 *     3, 4,
1446 *     5, 6,
1447 *
1448 *     7, 8,
1449 *     9, 0,
1450 *     1, 2
1451 * ]);
1452 * let transposed = TensorView::from(TensorTranspose::from(&tensor, ["batch", "columns", "rows"]));
1453 * assert_eq!(
1454 *     transposed,
1455 *     Tensor::from([("batch", 2), ("rows", 2), ("columns", 3)], vec![
1456 *         1, 3, 5,
1457 *         2, 4, 6,
1458 *
1459 *         7, 9, 1,
1460 *         8, 0, 2
1461 *     ])
1462 * );
1463 * let also_transposed = tensor.transpose_view(["batch", "columns", "rows"]);
1464 * ```
1465 */
1466#[derive(Clone)]
1467pub struct TensorTranspose<T, S, const D: usize> {
1468    access: TensorAccess<T, S, D>,
1469}
1470
1471impl<T: fmt::Debug, S: fmt::Debug, const D: usize> fmt::Debug for TensorTranspose<T, S, D> {
1472    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1473        f.debug_struct("TensorTranspose")
1474            .field("source", &self.access.source)
1475            .field("dimension_mapping", &self.access.dimension_mapping)
1476            .field("_type", &self.access._type)
1477            .finish()
1478    }
1479}
1480
1481impl<T, S, const D: usize> TensorTranspose<T, S, D>
1482where
1483    S: TensorRef<T, D>,
1484{
1485    /**
1486     * Creates a TensorTranspose which makes the data appear in the order of the
1487     * supplied dimensions. The order of the dimension names is unchanged, although their lengths
1488     * may swap.
1489     *
1490     * # Panics
1491     *
1492     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1493     * order need not match.
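     *
     * A brief sketch: the dimension names keep their order while their lengths are swapped to
     * match the requested data order.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::indexing::TensorTranspose;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![
     *     1, 2, 3,
     *     4, 5, 6
     * ]);
     * let transpose = TensorTranspose::from(&tensor, ["y", "x"]);
     * assert_eq!(transpose.shape(), [("x", 3), ("y", 2)]);
     * ```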
1494     */
1495    #[track_caller]
1496    pub fn from(source: S, dimensions: [Dimension; D]) -> TensorTranspose<T, S, D> {
1497        TensorTranspose {
1498            access: match TensorAccess::try_from(source, dimensions) {
1499                Err(error) => panic!("{}", error),
1500                Ok(success) => success,
1501            },
1502        }
1503    }
1504
1505    /**
1506     * Creates a TensorTranspose which makes the data appear in the order of the
1507     * supplied dimensions. The order of the dimension names is unchanged, although their lengths
1508     * may swap.
1509     *
1510     * Returns Err if the set of dimensions supplied do not match the set of dimensions in this
1511     * tensor's shape.
1512     */
1513    pub fn try_from(
1514        source: S,
1515        dimensions: [Dimension; D],
1516    ) -> Result<TensorTranspose<T, S, D>, InvalidDimensionsError<D>> {
1517        TensorAccess::try_from(source, dimensions).map(|access| TensorTranspose { access })
1518    }
1519
1520    /**
1521     * The shape of this TensorTranspose appears to rearrange the data to the order of the
1522     * supplied dimensions. The actual data in the underlying tensor and the order of the dimension
1523     * names on this TensorTranspose remain unchanged, although the lengths of the dimensions in
1524     * this shape may swap compared to the source's shape.
1525     */
1526    pub fn shape(&self) -> [(Dimension, usize); D] {
1527        let names = self.access.source.view_shape();
1528        let order = self.access.shape();
1529        std::array::from_fn(|d| (names[d].0, order[d].1))
1530    }
1531
1532    pub fn source(self) -> S {
1533        self.access.source
1534    }
1535
1536    // # Safety
1537    //
1538    // Giving out a mutable reference to our source could allow it to be changed out from under us
1539    // and make our dimension mapping invalid. However, since the source implements TensorRef
1540    // interior mutability is not allowed, so we can give out shared references without breaking
1541    // our own integrity.
1542    pub fn source_ref(&self) -> &S {
1543        &self.access.source
1544    }
1545}
1546
1547// # Safety
1548//
1549// The TensorAccess must implement TensorRef correctly, so by delegating to it without changing
1550// anything other than the order of the dimension names we expose, we implement
1551// TensorRef correctly as well.
1552/**
1553 * A TensorTranspose implements TensorRef, with the dimension order and indexing matching that
1554 * of the TensorTranspose shape.
1555 */
1556unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorTranspose<T, S, D>
1557where
1558    S: TensorRef<T, D>,
1559{
1560    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1561        // we didn't change the lengths of any dimension in our shape from the TensorAccess so we
1562        // can delegate to the tensor access for non named indexing here
1563        self.access.try_get_reference(indexes)
1564    }
1565
1566    fn view_shape(&self) -> [(Dimension, usize); D] {
1567        self.shape()
1568    }
1569
1570    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1571        unsafe { self.access.get_reference_unchecked(indexes) }
1572    }
1573
1574    fn data_layout(&self) -> DataLayout<D> {
1575        let data_layout = self.access.data_layout();
1576        match data_layout {
1577            DataLayout::Linear(order) => DataLayout::Linear(
1578                self.access
1579                    .dimension_mapping
1580                    .map_linear_data_layout_to_transposed(&order),
1581            ),
1582            _ => data_layout,
1583        }
1584    }
1585}
1586
1587// # Safety
1588//
1589// The TensorAccess must implement TensorMut correctly, so by delegating to it without changing
1590// anything other than the order of the dimension names we expose, we implement
1591// TensorMut correctly as well.
1592/**
1593 * A TensorTranspose implements TensorMut, with the dimension order and indexing matching that of
1594 * the TensorTranspose shape.
1595 */
1596unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorTranspose<T, S, D>
1597where
1598    S: TensorMut<T, D>,
1599{
1600    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1601        self.access.try_get_reference_mut(indexes)
1602    }
1603
1604    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1605        unsafe { self.access.get_reference_unchecked_mut(indexes) }
1606    }
1607}
1608
1609/**
1610 * Any tensor transpose of a Displayable type implements Display
1611 *
1612 * You can control the precision of the formatting using format arguments, i.e.
1613 * `format!("{:.3}", tensor)`
1614 */
1615impl<T, S, const D: usize> std::fmt::Display for TensorTranspose<T, S, D>
1616where
1617    T: std::fmt::Display,
1618    S: TensorRef<T, D>,
1619{
1620    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1621        crate::tensors::display::format_view(&self, f)?;
1622        writeln!(f)?;
1623        write!(f, "Data Layout = {:?}", self.data_layout())
1624    }
1625}