easy_ml/tensors/
views.rs

/*!
 * Generic views into a tensor.
 *
 * The concept of a view into a tensor is built from the low level [TensorRef] and
 * [TensorMut] traits which define having read and read/write access to Tensor data
 * respectively, and the high level API implemented on the [TensorView] struct.
 *
 * Since a Tensor is itself a TensorRef, the APIs for the traits are a little verbose to
 * avoid name clashes with methods defined on the Tensor and TensorView types. You should
 * typically use TensorRef and TensorMut implementations via the TensorView struct which provides
 * an API closely resembling Tensor.
 */
13
14use std::marker::PhantomData;
15
16use crate::linear_algebra;
17use crate::numeric::{Numeric, NumericRef};
18use crate::tensors::dimensions;
19use crate::tensors::indexing::{
20    TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
21    TensorReferenceMutIterator, TensorTranspose,
22};
23use crate::tensors::{Dimension, Tensor};
24
25mod indexes;
26mod map;
27mod ranges;
28mod renamed;
29mod reshape;
30mod reverse;
31pub mod traits;
32mod zip;
33
34pub use indexes::*;
35pub(crate) use map::*;
36pub use ranges::*;
37pub use renamed::*;
38pub use reshape::*;
39pub use reverse::*;
40pub use zip::*;
41
42/**
43* A shared/immutable reference to a tensor (or a portion of it) of some type and number of
44* dimensions.
45*
46* # Indexing
47*
48* A TensorRef has a shape of type `[(Dimension, usize); D]`. This defines the valid indexes along
49* each dimension name and length pair from 0 inclusive to the length exclusive. If the shape was
50* `[("r", 2), ("c", 2)]` the indexes used would be `[0,0]`, `[0,1]`, `[1,0]` and `[1,1]`.
51* Although the dimension name in each pair is used for many high level APIs, for TensorRef the
52* order of dimensions is used, and the indexes (`[usize; D]`) these trait methods are called with
53* must be in the same order as the shape. In general some `[("a", a), ("b", b), ("c", c)...]`
54* shape is indexed from `[0,0,0...]` through to `[a - 1, b - 1, c - 1...]`, regardless of how the
55* data is actually laid out in memory.
56*
57* # Safety
58*
59* In order to support returning references without bounds checking in a useful way, the
60* implementing type is required to uphold several invariants that cannot be checked by
61* the compiler.
62*
63* 1 - Any valid index as described in Indexing will yield a safe reference when calling
64* `get_reference_unchecked` and `get_reference_unchecked_mut`.
65*
66* 2 - The view shape that defines which indexes are valid may not be changed by a shared reference
67* to the TensorRef implementation. ie, the tensor may not be resized while a mutable reference is
68* held to it, except by that reference.
69*
70* 3 - All dimension names in the `view_shape` must be unique.
71*
72* 4 - All dimension lengths in the `view_shape` must be non zero.
73*
74* 5 - `data_layout` must return values correctly as documented on [`DataLayout`]
75*
76* Essentially, interior mutability causes problems, since code looping through the range of valid
77* indexes in a TensorRef needs to be able to rely on that range of valid indexes not changing.
78* This is trivially the case by default since a [Tensor] does not have any form of
79* interior mutability, and therefore an iterator holding a shared reference to a Tensor prevents
80* that tensor being resized. However, a type *wrongly* implementing TensorRef could introduce
81* interior mutability by putting the Tensor in an `Arc<Mutex<>>` which would allow another thread
82* to resize a tensor while an iterator was looping through previously valid indexes on a different
83* thread. This is the same contract as
84* [`NoInteriorMutability`](crate::matrices::views::NoInteriorMutability) used in the matrix APIs.
85*
86* Note that it is okay to be able to resize any TensorRef implementation if that always requires
87* an exclusive reference to the TensorRef/Tensor, since the exclusivity prevents the above
88* scenario.
89*/
90pub unsafe trait TensorRef<T, const D: usize> {
91    /**
92     * Gets a reference to the value at the index if the index is in range. Otherwise returns None.
93     */
94    fn get_reference(&self, indexes: [usize; D]) -> Option<&T>;
95
96    /**
97     * The shape this tensor has. See [dimensions] for an overview.
98     * The product of the lengths in the pairs define how many elements are in the tensor
99     * (or the portion of it that is visible).
100     */
101    fn view_shape(&self) -> [(Dimension, usize); D];
102
103    /**
104     * Gets a reference to the value at the index without doing any bounds checking. For a safe
105     * alternative see [get_reference](TensorRef::get_reference).
106     *
107     * # Safety
108     *
109     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
110     * resulting reference is not used. Valid indexes are defined as in [TensorRef].
111     *
112     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
113     * [TensorRef]: TensorRef
114     */
115    #[allow(clippy::missing_safety_doc)] // it's not missing
116    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T;
117
118    /**
119     * The way the data in this tensor is laid out in memory. In particular,
120     * [`Linear`](DataLayout) has several requirements on what is returned that must be upheld
121     * by implementations of this trait.
122     *
123     * For a [Tensor] this would return `DataLayout::Linear` in the same order as the `view_shape`,
124     * since the data is in a single line and the `view_shape` is in most significant dimension
125     * to least. Many views however, create shapes that do not correspond to linear memory, either
126     * by combining non array data with tensors, or hiding dimensions. Similarly, a row major
127     * Tensor might be transposed to a column major TensorView, so the view shape could be reversed
128     * compared to the order of significance of the dimensions in memory.
129     *
130     * In general, an array of dimension names matching the view shape order is big endian, and
131     * will be iterated through efficiently by [TensorIterator], and an array of dimension names
132     * in reverse of the view shape is in little endian order.
133     *
134     * The implementation of this trait must ensure that if it returns `DataLayout::Linear`
135     * the set of dimension names are returned in the order of most significant to least and match
136     * the set of the dimension names returned by `view_shape`.
137     */
138    fn data_layout(&self) -> DataLayout<D>;
139}
140
141/**
142 * How the data in the tensor is laid out in memory.
143 */
144#[derive(Clone, Debug, Eq, PartialEq)]
145pub enum DataLayout<const D: usize> {
146    /**
147     * The data is laid out in linear storage in memory, such that we could take a slice over the
148     * entire data specified by our `view_shape`.
149     *
150     * The `D` length array specifies the dimensions in the `view_shape` in the order of most
151     * significant dimension (in memory) to least.
152     *
153     * In general, an array of dimension names in the same order as the view shape is big endian
154     * order (implying the order of dimensions in the view shape is most significant to least),
155     * and will be iterated through efficiently by [TensorIterator], and an array of
156     * dimension names in reverse of the view shape is in little endian order
157     * (implying the order of dimensions in the view shape is least significant to most).
158     *
159     * In memory, the data will have some order such that if we want repeatedly take 1 step
160     * through memory from the first value to the last there will be a most significant dimension
161     * that always goes up, through to a least significant dimension with the most rapidly varying
162     * index.
163     *
164     * In most of Easy ML's Tensors, the `view_shape` dimensions would already be in the order of
165     * most significant dimension to least (since [Tensor] stores its data in big endian order),
166     * so the list of dimension names will just increment match the order of the view shape.
167     *
168     * For example, a tensor with a view shape of `[("batch", 2), ("row", 2), ("column", 3)]` that
169     * stores its data in most significant to least would be (indexed) like:
170     * ```ignore
171     * [
172     *   (0,0,0), (0,0,1), (0,0,2),
173     *   (0,1,0), (0,1,1), (0,1,2),
174     *   (1,0,0), (1,0,1), (1,0,2),
175     *   (1,1,0), (1,1,1), (1,1,2)
176     * ]
177     * ```
178     *
179     * To take one step in memory, we would increment the right most dimension index ("column"),
180     * counting our way up through to the left most dimension index ("batch"). If we changed
181     * this tensor to `[("column", 3), ("row", 2), ("batch", 2)]` so that the `view_shape` was
182     * swapped to least significant dimension to most but the data remained in the same order,
183     * our tensor would still have a DataLayout of `Linear(["batch", "row", "column"])`, since
184     * the indexes in the transposed `view_shape` correspond to an actual memory layout that's
185     * completely reversed. Alternatively, you could say we reversed the view shape but the
186     * memory layout never changed:
187     * ```ignore
188     * [
189     *   (0,0,0), (1,0,0), (2,0,0),
190     *   (0,1,0), (1,1,0), (2,1,0),
191     *   (0,0,1), (1,0,1), (2,0,1),
192     *   (0,1,1), (1,1,1), (2,1,1)
193     * ]
194     * ```
195     *
196     * To take one step in memory, we now need to increment the left most dimension index (on
197     * the view shape) ("column"), counting our way in reverse to the right most dimension
198     * index ("batch").
199     *
200     * That `["batch", "row", "column"]` is also exactly the order you would need to swap your
201     * dimensions on the `view_shape` to get back to most significant to least. A [TensorAccess]
202     * could reorder the tensor by this array order to get back to most significant to least
203     * ordering on the `view_shape` in order to iterate through the data efficiently.
204     */
205    Linear([Dimension; D]),
206    /**
207     * The data is not laid out in linear storage in memory.
208     */
209    NonLinear,
210    /**
211     * The data is not laid out in a linear or non linear way, or we don't know how it's laid
212     * out.
213     */
214    Other,
215}
216
217/**
218 * A unique/mutable reference to a tensor (or a portion of it) of some type.
219 *
220 * # Safety
221 *
222 * See [TensorRef].
223 */
224pub unsafe trait TensorMut<T, const D: usize>: TensorRef<T, D> {
225    /**
226     * Gets a mutable reference to the value at the index, if the index is in range. Otherwise
227     * returns None.
228     */
229    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>;
230
231    /**
232     * Gets a mutable reference to the value at the index without doing any bounds checking.
233     * For a safe alternative see [get_reference_mut](TensorMut::get_reference_mut).
234     *
235     * # Safety
236     *
237     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
238     * resulting reference is not used. Valid indexes are defined as in [TensorRef].
239     *
240     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
241     * [TensorRef]: TensorRef
242     */
243    #[allow(clippy::missing_safety_doc)] // it's not missing
244    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T;
245}
246
/**
 * A view into some or all of a tensor.
 *
 * A TensorView has a similar relationship to a [`Tensor`] as a `&str` has to a `String`, or an
 * array slice to an array. A TensorView cannot resize its source, and may span only a portion
 * of the source Tensor in each dimension.
 *
 * However a TensorView is generic not only over the type of the data in the Tensor,
 * but also over the way the Tensor is 'sliced' and the two are orthogonal to each other.
 *
 * TensorView closely mirrors the API of Tensor.
 * Methods that create a new tensor do not return a TensorView, they return a Tensor.
 *
 * A TensorView is [Clone] whenever its source is, regardless of whether `T` is.
 */
pub struct TensorView<T, S, const D: usize> {
    source: S,
    _type: PhantomData<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone` bound even
// though `T` is only held as a `PhantomData` marker. Cloning a view only needs to clone its
// source, so e.g. a view over `&Tensor<T, D>` is cloneable for any `T`.
impl<T, S, const D: usize> Clone for TensorView<T, S, D>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        TensorView {
            source: self.source.clone(),
            _type: PhantomData,
        }
    }
}
265
266/**
267 * TensorView methods which require only read access via a [TensorRef] source.
268 */
269impl<T, S, const D: usize> TensorView<T, S, D>
270where
271    S: TensorRef<T, D>,
272{
273    /**
274     * Creates a TensorView from a source of some type.
275     *
276     * The lifetime of the source determines the lifetime of the TensorView created. If the
277     * TensorView is created from a reference to a Tensor, then the TensorView cannot live
278     * longer than the Tensor referenced.
279     */
280    pub fn from(source: S) -> TensorView<T, S, D> {
281        TensorView {
282            source,
283            _type: PhantomData,
284        }
285    }
286
287    /**
288     * Consumes the tensor view, yielding the source it was created from.
289     */
290    pub fn source(self) -> S {
291        self.source
292    }
293
294    /**
295     * Gives a reference to the tensor view's source.
296     */
297    pub fn source_ref(&self) -> &S {
298        &self.source
299    }
300
301    /**
302     * Gives a mutable reference to the tensor view's source.
303     */
304    pub fn source_ref_mut(&mut self) -> &mut S {
305        &mut self.source
306    }
307
308    /**
309     * The shape of this tensor view. Since Tensors are named Tensors, their shape is not just a
310     * list of length along each dimension, but instead a list of pairs of names and lengths.
311     *
312     * Note that a TensorView may have a shape which is different than the Tensor it is providing
313     * access to the data of. The TensorView might be [masking dimensions](TensorIndex) or
314     * elements from the shape, or exposing [false ones](TensorExpansion).
315     *
316     * See also
317     * - [dimensions]
318     * - [indexing](crate::tensors::indexing)
319     */
320    pub fn shape(&self) -> [(Dimension, usize); D] {
321        self.source.view_shape()
322    }
323
324    /**
325     * Returns the length of the dimension name provided, if one is present in the tensor view.
326     *
327     * See also
328     * - [dimensions]
329     * - [indexing](crate::tensors::indexing)
330     */
331    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
332        dimensions::length_of(&self.source.view_shape(), dimension)
333    }
334
335    /**
336     * Returns the last index of the dimension name provided, if one is present in the tensor view.
337     *
338     * This is always 1 less than the length, the 'index' in this sense is based on what the
339     * Tensor's shape is, not any implementation index.
340     *
341     * See also
342     * - [dimensions]
343     * - [indexing](crate::tensors::indexing)
344     */
345    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
346        dimensions::last_index_of(&self.source.view_shape(), dimension)
347    }
348
349    /**
350     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
351     * to read values from this tensor view.
352     *
353     * # Panics
354     *
355     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
356     */
357    #[track_caller]
358    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &S, D> {
359        TensorAccess::from(&self.source, dimensions)
360    }
361
362    /**
363     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
364     * to read or write values from this tensor view.
365     *
366     * # Panics
367     *
368     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
369     */
370    #[track_caller]
371    pub fn index_by_mut(&mut self, dimensions: [Dimension; D]) -> TensorAccess<T, &mut S, D> {
372        TensorAccess::from(&mut self.source, dimensions)
373    }
374
375    /**
376     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
377     * to read or write values from this tensor view.
378     *
379     * # Panics
380     *
381     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
382     */
383    #[track_caller]
384    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
385        TensorAccess::from(self.source, dimensions)
386    }
387
388    /**
389     * Creates a TensorAccess which will index into the dimensions of the source this TensorView
390     * was created with in the same order as they were declared.
391     * See [TensorAccess::from_source_order].
392     */
393    pub fn index(&self) -> TensorAccess<T, &S, D> {
394        TensorAccess::from_source_order(&self.source)
395    }
396
397    /**
398     * Creates a TensorAccess which will index into the dimensions of the source this TensorView
399     * was created with in the same order as they were declared. The TensorAccess mutably borrows
400     * the source, and can therefore mutate it if it implements TensorMut.
401     * See [TensorAccess::from_source_order].
402     */
403    pub fn index_mut(&mut self) -> TensorAccess<T, &mut S, D> {
404        TensorAccess::from_source_order(&mut self.source)
405    }
406
407    /**
408     * Creates a TensorAccess which will index into the dimensions this Tensor was
409     * created with in the same order as they were provided. The TensorAccess takes ownership
410     * of the Tensor, and can therefore mutate it. The TensorAccess takes ownership of
411     * the source, and can therefore mutate it if it implements TensorMut.
412     * See [TensorAccess::from_source_order].
413     */
414    pub fn index_owned(self) -> TensorAccess<T, S, D> {
415        TensorAccess::from_source_order(self.source)
416    }
417
418    /**
419     * Returns an iterator over references to the data in this TensorView.
420     */
421    pub fn iter_reference(&self) -> TensorReferenceIterator<T, S, D> {
422        TensorReferenceIterator::from(&self.source)
423    }
424
425    /**
426     * Returns a TensorView with the dimension names of the shape renamed to the provided
427     * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
428     *
429     * This is a shorthand for constructing the TensorView from this TensorView. See
430     * [`Tensor::rename_view`](Tensor::rename_view).
431     *
432     * # Panics
433     *
434     * - If a dimension name is not unique
435     */
436    #[track_caller]
437    pub fn rename_view(
438        &self,
439        dimensions: [Dimension; D],
440    ) -> TensorView<T, TensorRename<T, &S, D>, D> {
441        TensorView::from(TensorRename::from(&self.source, dimensions))
442    }
443
444    /**
445     * Returns a TensorView with the dimensions changed to the provided shape without moving any
446     * data around. The new Tensor may also have a different number of dimensions.
447     *
448     * This is a shorthand for constructing the TensorView from this TensorView. See
449     * [`Tensor::reshape_view`](Tensor::reshape_view).
450     *
451     * # Panics
452     *
453     * - If the number of provided elements in the new shape does not match the product of the
454     * dimension lengths in the existing tensor's shape.
455     * - If a dimension name is not unique
456     */
457    pub fn reshape<const D2: usize>(
458        &self,
459        shape: [(Dimension, usize); D2],
460    ) -> TensorView<T, TensorReshape<T, &S, D, D2>, D2> {
461        TensorView::from(TensorReshape::from(&self.source, shape))
462    }
463
464    /**
465     * Returns a TensorView with the dimensions changed to the provided shape without moving any
466     * data around. The new Tensor may also have a different number of dimensions. The
467     * TensorReshape mutably borrows the source, and can therefore mutate it if it implements
468     * TensorMut.
469     *
470     * This is a shorthand for constructing the TensorView from this TensorView. See
471     * [`Tensor::reshape_view`](Tensor::reshape_view).
472     *
473     * # Panics
474     *
475     * - If the number of provided elements in the new shape does not match the product of the
476     * dimension lengths in the existing tensor's shape.
477     * - If a dimension name is not unique
478     */
479    pub fn reshape_mut<const D2: usize>(
480        &mut self,
481        shape: [(Dimension, usize); D2],
482    ) -> TensorView<T, TensorReshape<T, &mut S, D, D2>, D2> {
483        TensorView::from(TensorReshape::from(&mut self.source, shape))
484    }
485
486    /**
487     * Returns a TensorView with the dimensions changed to the provided shape without moving any
488     * data around. The new Tensor may also have a different number of dimensions. The
489     * TensorReshape takes ownership of the source, and can therefore mutate it if it implements
490     * TensorMut.
491     *
492     * This is a shorthand for constructing the TensorView from this TensorView. See
493     * [`Tensor::reshape_view`](Tensor::reshape_view).
494     *
495     * # Panics
496     *
497     * - If the number of provided elements in the new shape does not match the product of the
498     * dimension lengths in the existing tensor's shape.
499     * - If a dimension name is not unique
500     */
501    pub fn reshape_owned<const D2: usize>(
502        self,
503        shape: [(Dimension, usize); D2],
504    ) -> TensorView<T, TensorReshape<T, S, D, D2>, D2> {
505        TensorView::from(TensorReshape::from(self.source, shape))
506    }
507
508    /**
509     * Given the dimension name, returns a view of this tensor reshaped to one dimension
510     * with a length equal to the number of elements in this tensor.
511     */
512    pub fn flatten(&self, dimension: Dimension) -> TensorView<T, TensorReshape<T, &S, D, 1>, 1> {
513        self.reshape([(dimension, dimensions::elements(&self.shape()))])
514    }
515
516    /**
517     * Given the dimension name, returns a view of this tensor reshaped to one dimension
518     * with a length equal to the number of elements in this tensor.
519     */
520    pub fn flatten_mut(
521        &mut self,
522        dimension: Dimension,
523    ) -> TensorView<T, TensorReshape<T, &mut S, D, 1>, 1> {
524        self.reshape_mut([(dimension, dimensions::elements(&self.shape()))])
525    }
526
527    /**
528     * Given the dimension name, returns a view of this tensor reshaped to one dimension
529     * with a length equal to the number of elements in this tensor.
530     *
531     * If you intend to query the tensor a lot after creating the view, consider
532     * using [flatten_into_tensor](TensorView::flatten_into_tensor) instead as it will have
533     * less overhead to index after creation.
534     */
535    pub fn flatten_owned(
536        self,
537        dimension: Dimension,
538    ) -> TensorView<T, TensorReshape<T, S, D, 1>, 1> {
539        let length = dimensions::elements(&self.shape());
540        self.reshape_owned([(dimension, length)])
541    }
542
543    /**
544     * Given the dimension name, returns a new tensor reshaped to one dimension
545     * with a length equal to the number of elements in this tensor and all elements
546     * copied into the new tensor.
547     */
548    pub fn flatten_into_tensor(self, dimension: Dimension) -> Tensor<T, 1>
549    where
550        T: Clone,
551    {
552        // TODO: Want a specialisation here to use Tensor::reshape_owned when we
553        // know that our source type is Tensor
554        let length = dimensions::elements(&self.shape());
555        Tensor::from([(dimension, length)], self.iter().collect())
556    }
557
558    /**
559     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
560     * range from view. Error cases are documented on [TensorRange].
561     *
562     * This is a shorthand for constructing the TensorView from this TensorView. See
563     * [`Tensor::range`](Tensor::range).
564     */
565    pub fn range<R, const P: usize>(
566        &self,
567        ranges: [(Dimension, R); P],
568    ) -> Result<TensorView<T, TensorRange<T, &S, D>, D>, IndexRangeValidationError<D, P>>
569    where
570        R: Into<IndexRange>,
571    {
572        TensorRange::from(&self.source, ranges).map(|range| TensorView::from(range))
573    }
574
575    /**
576     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
577     * range from view. Error cases are documented on [TensorRange]. The TensorRange
578     * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
579     *
580     * This is a shorthand for constructing the TensorView from this TensorView. See
581     * [`Tensor::range`](Tensor::range).
582     */
583    pub fn range_mut<R, const P: usize>(
584        &mut self,
585        ranges: [(Dimension, R); P],
586    ) -> Result<TensorView<T, TensorRange<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
587    where
588        R: Into<IndexRange>,
589    {
590        TensorRange::from(&mut self.source, ranges).map(|range| TensorView::from(range))
591    }
592
593    /**
594     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
595     * range from view. Error cases are documented on [TensorRange]. The TensorRange
596     * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
597     *
598     * This is a shorthand for constructing the TensorView from this TensorView. See
599     * [`Tensor::range`](Tensor::range).
600     */
601    pub fn range_owned<R, const P: usize>(
602        self,
603        ranges: [(Dimension, R); P],
604    ) -> Result<TensorView<T, TensorRange<T, S, D>, D>, IndexRangeValidationError<D, P>>
605    where
606        R: Into<IndexRange>,
607    {
608        TensorRange::from(self.source, ranges).map(|range| TensorView::from(range))
609    }
610
611    /**
612     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
613     * range from view. Error cases are documented on [TensorMask].
614     *
615     * This is a shorthand for constructing the TensorView from this TensorView. See
616     * [`Tensor::mask`](Tensor::mask).
617     */
618    pub fn mask<R, const P: usize>(
619        &self,
620        masks: [(Dimension, R); P],
621    ) -> Result<TensorView<T, TensorMask<T, &S, D>, D>, IndexRangeValidationError<D, P>>
622    where
623        R: Into<IndexRange>,
624    {
625        TensorMask::from(&self.source, masks).map(|mask| TensorView::from(mask))
626    }
627
628    /**
629     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
630     * range from view. Error cases are documented on [TensorMask]. The TensorMask
631     * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
632     *
633     * This is a shorthand for constructing the TensorView from this TensorView. See
634     * [`Tensor::mask`](Tensor::mask).
635     */
636    pub fn mask_mut<R, const P: usize>(
637        &mut self,
638        masks: [(Dimension, R); P],
639    ) -> Result<TensorView<T, TensorMask<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
640    where
641        R: Into<IndexRange>,
642    {
643        TensorMask::from(&mut self.source, masks).map(|mask| TensorView::from(mask))
644    }
645
646    /**
647     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
648     * range from view. Error cases are documented on [TensorMask]. The TensorMask
649     * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
650     *
651     * This is a shorthand for constructing the TensorView from this TensorView. See
652     * [`Tensor::mask`](Tensor::mask).
653     */
654    pub fn mask_owned<R, const P: usize>(
655        self,
656        masks: [(Dimension, R); P],
657    ) -> Result<TensorView<T, TensorMask<T, S, D>, D>, IndexRangeValidationError<D, P>>
658    where
659        R: Into<IndexRange>,
660    {
661        TensorMask::from(self.source, masks).map(|mask| TensorView::from(mask))
662    }
663
664    /**
665     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
666     * order. The data of this tensor and the dimension lengths remain unchanged.
667     *
668     * This is a shorthand for constructing the TensorView from this TensorView. See
669     * [`Tensor::reverse`](Tensor::reverse).
670     *
671     * # Panics
672     *
673     * - If a dimension name is not in the tensor's shape or is repeated.
674     */
675    #[track_caller]
676    pub fn reverse(&self, dimensions: &[Dimension]) -> TensorView<T, TensorReverse<T, &S, D>, D> {
677        TensorView::from(TensorReverse::from(&self.source, dimensions))
678    }
679
680    /**
681     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
682     * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
683     * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
684     *
685     * This is a shorthand for constructing the TensorView from this TensorView. See
686     * [`Tensor::reverse`](Tensor::reverse).
687     *
688     * # Panics
689     *
690     * - If a dimension name is not in the tensor's shape or is repeated.
691     */
692    #[track_caller]
693    pub fn reverse_mut(
694        &mut self,
695        dimensions: &[Dimension],
696    ) -> TensorView<T, TensorReverse<T, &mut S, D>, D> {
697        TensorView::from(TensorReverse::from(&mut self.source, dimensions))
698    }
699
700    /**
701     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
702     * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
703     * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
704     *
705     * This is a shorthand for constructing the TensorView from this TensorView. See
706     * [`Tensor::reverse`](Tensor::reverse).
707     *
708     * # Panics
709     *
710     * - If a dimension name is not in the tensor's shape or is repeated.
711     */
712    #[track_caller]
713    pub fn reverse_owned(
714        self,
715        dimensions: &[Dimension],
716    ) -> TensorView<T, TensorReverse<T, S, D>, D> {
717        TensorView::from(TensorReverse::from(self.source, dimensions))
718    }
719
720    /**
721     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
722     * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
723     * or need to clone the values anyway, you can use
724     * [`TensorView::elementwise`](TensorView::elementwise) instead.
725     *
726     * # Generics
727     *
728     * This method can be called with any right hand side that can be converted to a TensorView,
729     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
730     *
731     * # Panics
732     *
733     * If the two tensors have different shapes.
734     */
735    #[track_caller]
736    pub fn elementwise_reference<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
737    where
738        I: Into<TensorView<T, S2, D>>,
739        S2: TensorRef<T, D>,
740        M: Fn(&T, &T) -> T,
741    {
742        self.elementwise_reference_less_generic(rhs.into(), mapping_function)
743    }
744
745    /**
746     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
747     * mapped by a function. The mapping function also receives each index corresponding to the
748     * value pairs. The value pairs are not copied for you, if you're using `Copy` types
749     * or need to clone the values anyway, you can use
750     * [`TensorView::elementwise_with_index`](TensorView::elementwise_with_index) instead.
751     *
752     * # Generics
753     *
754     * This method can be called with any right hand side that can be converted to a TensorView,
755     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
756     *
757     * # Panics
758     *
759     * If the two tensors have different shapes.
760     */
761    #[track_caller]
762    pub fn elementwise_reference_with_index<S2, I, M>(
763        &self,
764        rhs: I,
765        mapping_function: M,
766    ) -> Tensor<T, D>
767    where
768        I: Into<TensorView<T, S2, D>>,
769        S2: TensorRef<T, D>,
770        M: Fn([usize; D], &T, &T) -> T,
771    {
772        self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
773    }
774
775    #[track_caller]
776    fn elementwise_reference_less_generic<S2, M>(
777        &self,
778        rhs: TensorView<T, S2, D>,
779        mapping_function: M,
780    ) -> Tensor<T, D>
781    where
782        S2: TensorRef<T, D>,
783        M: Fn(&T, &T) -> T,
784    {
785        let left_shape = self.shape();
786        let right_shape = rhs.shape();
787        if left_shape != right_shape {
788            panic!(
789                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
790                left_shape, right_shape
791            );
792        }
793        let mapped = self
794            .iter_reference()
795            .zip(rhs.iter_reference())
796            .map(|(x, y)| mapping_function(x, y))
797            .collect();
798        Tensor::from(left_shape, mapped)
799    }
800
801    #[track_caller]
802    fn elementwise_reference_less_generic_with_index<S2, M>(
803        &self,
804        rhs: TensorView<T, S2, D>,
805        mapping_function: M,
806    ) -> Tensor<T, D>
807    where
808        S2: TensorRef<T, D>,
809        M: Fn([usize; D], &T, &T) -> T,
810    {
811        let left_shape = self.shape();
812        let right_shape = rhs.shape();
813        if left_shape != right_shape {
814            panic!(
815                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
816                left_shape, right_shape
817            );
818        }
819        // we just checked both shapes were the same, so we don't need to propagate indexes
820        // for both tensors because they'll be identical
821        let mapped = self
822            .iter_reference()
823            .with_index()
824            .zip(rhs.iter_reference())
825            .map(|((i, x), y)| mapping_function(i, x, y))
826            .collect();
827        Tensor::from(left_shape, mapped)
828    }
829
830    /**
831     * Returns a TensorView which makes the order of the data in this tensor appear to be in
832     * a different order. The order of the dimension names is unchanged, although their lengths
833     * may swap.
834     *
835     * This is a shorthand for constructing the TensorView from this TensorView.
836     *
837     * See also: [transpose](TensorView::transpose), [TensorTranspose]
838     *
839     * # Panics
840     *
841     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
842     * order need not match.
843     */
844    pub fn transpose_view(
845        &self,
846        dimensions: [Dimension; D],
847    ) -> TensorView<T, TensorTranspose<T, &S, D>, D> {
848        TensorView::from(TensorTranspose::from(&self.source, dimensions))
849    }
850
851    /// Unverified constructor for interal use when we know the dimensions/data/strides are
852    /// the same as the existing instance and don't need reverification
853    pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> TensorView<T, Tensor<T, D>, D> {
854        let shape = self.shape();
855        let strides = crate::tensors::compute_strides(&shape);
856        TensorView::from(Tensor {
857            data,
858            shape,
859            strides,
860        })
861    }
862}
863
864impl<T, S, const D: usize> TensorView<T, S, D>
865where
866    S: TensorMut<T, D>,
867{
868    /**
869     * Returns an iterator over mutable references to the data in this TensorView.
870     */
871    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<T, S, D> {
872        TensorReferenceMutIterator::from(&mut self.source)
873    }
874
875    /**
876     * Creates an iterator over the values in this TensorView.
877     */
878    pub fn iter_owned(self) -> TensorOwnedIterator<T, S, D>
879    where
880        T: Default,
881    {
882        TensorOwnedIterator::from(self.source())
883    }
884}
885
886/**
887 * TensorView methods which require only read access via a [TensorRef] source
888 * and a clonable type.
889 */
890impl<T, S, const D: usize> TensorView<T, S, D>
891where
892    T: Clone,
893    S: TensorRef<T, D>,
894{
895    /**
896     * Gets a copy of the first value in this tensor.
897     * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
898     * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
899     */
900    pub fn first(&self) -> T {
901        self.iter()
902            .next()
903            .expect("Tensors always have at least 1 element")
904    }
905
906    /**
907     * Returns a new Tensor which has the same data as this tensor, but with the order of data
908     * changed. The order of the dimension names is unchanged, although their lengths may swap.
909     *
910     * For example, with a `[("x", x), ("y", y)]` tensor you could call
911     * `transpose(["y", "x"])` which would return a new tensor with a shape of
912     * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
913     *
914     * This method need not shift *all* the dimensions though, you could also swap the width
915     * and height of images in a tensor with a shape of
916     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
917     * which would return a new tensor where all the images have been swapped over the diagonal.
918     *
919     * See also: [TensorAccess], [reorder](TensorView::reorder)
920     *
921     * # Panics
922     *
923     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
924     * order need not match (and if the order does match, this function is just an expensive
925     * clone).
926     */
927    #[track_caller]
928    pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
929        let shape = self.shape();
930        let mut reordered = self.reorder(dimensions);
931        // Transposition is essentially reordering, but we retain the dimension name ordering
932        // of the original order, this means we may swap dimension lengths, but the dimensions
933        // will not change order.
934        #[allow(clippy::needless_range_loop)]
935        for d in 0..D {
936            reordered.shape[d].0 = shape[d].0;
937        }
938        reordered
939    }
940
941    /**
942     * Returns a new Tensor which has the same data as this tensor, but with the order of the
943     * dimensions and corresponding order of data changed.
944     *
945     * For example, with a `[("x", x), ("y", y)]` tensor you could call
946     * `reorder(["y", "x"])` which would return a new tensor with a shape of
947     * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
948     *
949     * This method need not shift *all* the dimensions though, you could also swap the width
950     * and height of images in a tensor with a shape of
951     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
952     * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
953     * in the original.
954     *
955     * See also: [TensorAccess], [transpose](TensorView::transpose)
956     *
957     * # Panics
958     *
959     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
960     * order need not match (and if the order does match, this function is just an expensive
961     * clone).
962     */
963    #[track_caller]
964    pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
965        let reorderd = match TensorAccess::try_from(&self.source, dimensions) {
966            Ok(reordered) => reordered,
967            Err(_error) => panic!(
968                "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
969                dimensions,
970                self.shape(),
971            ),
972        };
973        let reorderd_shape = reorderd.shape();
974        Tensor::from(reorderd_shape, reorderd.iter().collect())
975    }
976
977    /**
978     * Creates and returns a new tensor with all values from the original with the
979     * function applied to each. This can be used to change the type of the tensor
980     * such as creating a mask:
981     * ```
982     * use easy_ml::tensors::Tensor;
983     * use easy_ml::tensors::views::TensorView;
984     * let x = TensorView::from(Tensor::from([("a", 2), ("b", 2)], vec![
985     *    0.0, 1.2,
986     *    5.8, 6.9
987     * ]));
988     * let y = x.map(|element| element > 2.0);
989     * let result = Tensor::from([("a", 2), ("b", 2)], vec![
990     *    false, false,
991     *    true, true
992     * ]);
993     * assert_eq!(&y, &result);
994     * ```
995     */
996    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
997        let mapped = self.iter().map(mapping_function).collect();
998        Tensor::from(self.shape(), mapped)
999    }
1000
1001    /**
1002     * Creates and returns a new tensor with all values from the original and
1003     * the index of each value mapped by a function.
1004     */
1005    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1006        let mapped = self
1007            .iter()
1008            .with_index()
1009            .map(|(i, x)| mapping_function(i, x))
1010            .collect();
1011        Tensor::from(self.shape(), mapped)
1012    }
1013
1014    /**
1015     * Returns an iterator over copies of the data in this TensorView.
1016     */
1017    pub fn iter(&self) -> TensorIterator<T, S, D> {
1018        TensorIterator::from(&self.source)
1019    }
1020
1021    /**
1022     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1023     * mapped by a function.
1024     *
1025     * ```
1026     * use easy_ml::tensors::Tensor;
1027     * use easy_ml::tensors::views::TensorView;
1028     * let lhs = TensorView::from(Tensor::from([("a", 4)], vec![1, 2, 3, 4]));
1029     * let rhs = TensorView::from(Tensor::from([("a", 4)], vec![0, 1, 2, 3]));
1030     * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1031     * assert_eq!(
1032     *     multiplied,
1033     *     Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1034     * );
1035     * ```
1036     *
1037     * # Generics
1038     *
1039     * This method can be called with any right hand side that can be converted to a TensorView,
1040     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1041     *
1042     * # Panics
1043     *
1044     * If the two tensors have different shapes.
1045     */
1046    #[track_caller]
1047    pub fn elementwise<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1048    where
1049        I: Into<TensorView<T, S2, D>>,
1050        S2: TensorRef<T, D>,
1051        M: Fn(T, T) -> T,
1052    {
1053        self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1054            mapping_function(lhs.clone(), rhs.clone())
1055        })
1056    }
1057
1058    /**
1059     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1060     * mapped by a function. The mapping function also receives each index corresponding to the
1061     * value pairs.
1062     *
1063     * # Generics
1064     *
1065     * This method can be called with any right hand side that can be converted to a TensorView,
1066     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1067     *
1068     * # Panics
1069     *
1070     * If the two tensors have different shapes.
1071     */
1072    #[track_caller]
1073    pub fn elementwise_with_index<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1074    where
1075        I: Into<TensorView<T, S2, D>>,
1076        S2: TensorRef<T, D>,
1077        M: Fn([usize; D], T, T) -> T,
1078    {
1079        self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1080            mapping_function(i, lhs.clone(), rhs.clone())
1081        })
1082    }
1083}
1084
1085/**
1086 * TensorView methods which require only read access via a scalar [TensorRef] source
1087 * and a clonable type.
1088 */
1089impl<T, S> TensorView<T, S, 0>
1090where
1091    T: Clone,
1092    S: TensorRef<T, 0>,
1093{
1094    /**
1095     * Returns a copy of the sole element in the 0 dimensional tensor.
1096     */
1097    pub fn scalar(&self) -> T {
1098        self.source.get_reference([]).unwrap().clone()
1099    }
1100}
1101
1102/**
1103 * TensorView methods which require mutable access via a [TensorMut] source and a [Default].
1104 */
1105impl<T, S> TensorView<T, S, 0>
1106where
1107    T: Default,
1108    S: TensorMut<T, 0>,
1109{
1110    /**
1111     * Returns the sole element in the 0 dimensional tensor.
1112     */
1113    pub fn into_scalar(self) -> T {
1114        TensorOwnedIterator::from(self.source).next().unwrap()
1115    }
1116}
1117
1118/**
1119 * TensorView methods which require mutable access via a [TensorMut] source.
1120 */
1121impl<T, S, const D: usize> TensorView<T, S, D>
1122where
1123    T: Clone,
1124    S: TensorMut<T, D>,
1125{
1126    /**
1127     * Applies a function to all values in the tensor view, modifying
1128     * the tensor in place.
1129     */
1130    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1131        self.iter_reference_mut()
1132            .for_each(|x| *x = mapping_function(x.clone()));
1133    }
1134
1135    /**
1136     * Applies a function to all values and each value's index in the tensor view, modifying
1137     * the tensor view in place.
1138     */
1139    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1140        self.iter_reference_mut()
1141            .with_index()
1142            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1143    }
1144}
1145
1146impl<T, S> TensorView<T, S, 1>
1147where
1148    T: Numeric,
1149    for<'a> &'a T: NumericRef<T>,
1150    S: TensorRef<T, 1>,
1151{
1152    /**
1153     * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1154     * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1155     *
1156     * ```
1157     * use easy_ml::tensors::Tensor;
1158     * use easy_ml::tensors::views::TensorView;
1159     * let tensor_view = TensorView::from(Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]));
1160     * assert_eq!(tensor_view.scalar_product(&tensor_view), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1161     * ```
1162     *
1163     * # Generics
1164     *
1165     * This method can be called with any right hand side that can be converted to a TensorView,
1166     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1167     *
1168     * # Panics
1169     *
1170     * If the two vectors are not of equal length or their dimension names do not match.
1171     */
1172    // Would like this impl block to be in operations.rs too but then it would show first in the
1173    // TensorView docs which isn't ideal
1174    pub fn scalar_product<S2, I>(&self, rhs: I) -> T
1175    where
1176        I: Into<TensorView<T, S2, 1>>,
1177        S2: TensorRef<T, 1>,
1178    {
1179        self.scalar_product_less_generic(rhs.into())
1180    }
1181}
1182
1183impl<T, S> TensorView<T, S, 2>
1184where
1185    T: Numeric,
1186    for<'a> &'a T: NumericRef<T>,
1187    S: TensorRef<T, 2>,
1188{
1189    /**
1190     * Returns the determinant of this square matrix, or None if the matrix
1191     * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
1192     */
1193    pub fn determinant(&self) -> Option<T> {
1194        linear_algebra::determinant_tensor::<T, _, _>(self)
1195    }
1196
1197    /**
1198     * Computes the inverse of a matrix provided that it exists. To have an inverse a
1199     * matrix must be square (same number of rows and columns) and it must also have a
1200     * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
1201     */
1202    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
1203        linear_algebra::inverse_tensor::<T, _, _>(self)
1204    }
1205
1206    /**
1207     * Computes the covariance matrix for this feature matrix along the specified feature
1208     * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
1209     */
1210    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
1211        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
1212    }
1213}
1214
1215macro_rules! tensor_view_select_impl {
1216    (impl TensorView $d:literal 1) => {
1217        impl<T, S> TensorView<T, S, $d>
1218        where
1219            S: TensorRef<T, $d>,
1220        {
1221            /**
1222             * Selects the provided dimension name and index pairs in this TensorView, returning a
1223             * TensorView which has fewer dimensions than this TensorView, with the removed dimensions
1224             * always indexed as the provided values.
1225             *
1226             * This is a shorthand for manually constructing the TensorView and
1227             * [TensorIndex]
1228             *
1229             * Note: due to limitations in Rust's const generics support, this method is only
1230             * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
1231             * back to manual construction to create `TensorIndex`es with multiple provided
1232             * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
1233             */
1234            #[track_caller]
1235            pub fn select(
1236                &self,
1237                provided_indexes: [(Dimension, usize); 1],
1238            ) -> TensorView<T, TensorIndex<T, &S, $d, 1>, { $d - 1 }> {
1239                TensorView::from(TensorIndex::from(&self.source, provided_indexes))
1240            }
1241
1242            /**
1243             * Selects the provided dimension name and index pairs in this TensorView, returning a
1244             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1245             * always indexed as the provided values. The TensorIndex mutably borrows this
1246             * Tensor, and can therefore mutate it
1247             *
1248             * See [select](TensorView::select)
1249             */
1250            #[track_caller]
1251            pub fn select_mut(
1252                &mut self,
1253                provided_indexes: [(Dimension, usize); 1],
1254            ) -> TensorView<T, TensorIndex<T, &mut S, $d, 1>, { $d - 1 }> {
1255                TensorView::from(TensorIndex::from(&mut self.source, provided_indexes))
1256            }
1257
1258            /**
1259             * Selects the provided dimension name and index pairs in this TensorView, returning a
1260             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1261             * always indexed as the provided values. The TensorIndex takes ownership of this
1262             * Tensor, and can therefore mutate it
1263             *
1264             * See [select](TensorView::select)
1265             */
1266            #[track_caller]
1267            pub fn select_owned(
1268                self,
1269                provided_indexes: [(Dimension, usize); 1],
1270            ) -> TensorView<T, TensorIndex<T, S, $d, 1>, { $d - 1 }> {
1271                TensorView::from(TensorIndex::from(self.source, provided_indexes))
1272            }
1273        }
1274    };
1275}
1276
1277tensor_view_select_impl!(impl TensorView 6 1);
1278tensor_view_select_impl!(impl TensorView 5 1);
1279tensor_view_select_impl!(impl TensorView 4 1);
1280tensor_view_select_impl!(impl TensorView 3 1);
1281tensor_view_select_impl!(impl TensorView 2 1);
1282tensor_view_select_impl!(impl TensorView 1 1);
1283
1284macro_rules! tensor_view_expand_impl {
1285    (impl Tensor $d:literal 1) => {
1286        impl<T, S> TensorView<T, S, $d>
1287        where
1288            S: TensorRef<T, $d>,
1289        {
1290            /**
1291             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
1292             * a particular position within the shape, returning a TensorView which has more
1293             * dimensions than this TensorView.
1294             *
1295             * This is a shorthand for manually constructing the TensorView and
1296             * [TensorExpansion]
1297             *
1298             * Note: due to limitations in Rust's const generics support, this method is only
1299             * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
1300             * fall back to manual construction to create `TensorExpansion`s with multiple provided
1301             * indexes if you need to increase dimensionality by more than 1 dimension at a time.
1302             */
1303            #[track_caller]
1304            pub fn expand(
1305                &self,
1306                extra_dimension_names: [(usize, Dimension); 1],
1307            ) -> TensorView<T, TensorExpansion<T, &S, $d, 1>, { $d + 1 }> {
1308                TensorView::from(TensorExpansion::from(&self.source, extra_dimension_names))
1309            }
1310
1311            /**
1312             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
1313             * a particular position within the shape, returning a TensorView which has more
1314             * dimensions than this Tensor. The TensorIndex mutably borrows this
1315             * Tensor, and can therefore mutate it
1316             *
1317             * See [expand](Tensor::expand)
1318             */
1319            #[track_caller]
1320            pub fn expand_mut(
1321                &mut self,
1322                extra_dimension_names: [(usize, Dimension); 1],
1323            ) -> TensorView<T, TensorExpansion<T, &mut S, $d, 1>, { $d + 1 }> {
1324                TensorView::from(TensorExpansion::from(
1325                    &mut self.source,
1326                    extra_dimension_names,
1327                ))
1328            }
1329
1330            /**
1331             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
1332             * a particular position within the shape, returning a TensorView which has more
1333             * dimensions than this Tensor. The TensorIndex takes ownership of this
1334             * Tensor, and can therefore mutate it
1335             *
1336             * See [expand](Tensor::expand)
1337             */
1338            #[track_caller]
1339            pub fn expand_owned(
1340                self,
1341                extra_dimension_names: [(usize, Dimension); 1],
1342            ) -> TensorView<T, TensorExpansion<T, S, $d, 1>, { $d + 1 }> {
1343                TensorView::from(TensorExpansion::from(self.source, extra_dimension_names))
1344            }
1345        }
1346    };
1347}
1348
1349tensor_view_expand_impl!(impl Tensor 0 1);
1350tensor_view_expand_impl!(impl Tensor 1 1);
1351tensor_view_expand_impl!(impl Tensor 2 1);
1352tensor_view_expand_impl!(impl Tensor 3 1);
1353tensor_view_expand_impl!(impl Tensor 4 1);
1354tensor_view_expand_impl!(impl Tensor 5 1);
1355
1356/**
1357 * Debug implementations for TensorView additionally show the visible data and visible dimensions
1358 * reported by the source as fields. This is in addition to recursive debug content of the actual
1359 * source.
1360 */
1361impl<T, S, const D: usize> std::fmt::Debug for TensorView<T, S, D>
1362where
1363    T: std::fmt::Debug,
1364    S: std::fmt::Debug + TensorRef<T, D>,
1365{
1366    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1367        f.debug_struct("TensorView")
1368            .field("visible", &DebugSourceVisible::from(&self.source))
1369            .field("shape", &self.source.view_shape())
1370            .field("source", &self.source)
1371            .finish()
1372    }
1373}
1374
1375struct DebugSourceVisible<T, S, const D: usize> {
1376    source: S,
1377    _type: PhantomData<T>,
1378}
1379
1380impl<T, S, const D: usize> DebugSourceVisible<T, S, D>
1381where
1382    T: std::fmt::Debug,
1383    S: std::fmt::Debug + TensorRef<T, D>,
1384{
1385    fn from(source: S) -> DebugSourceVisible<T, S, D> {
1386        DebugSourceVisible {
1387            source,
1388            _type: PhantomData,
1389        }
1390    }
1391}
1392
1393impl<T, S, const D: usize> std::fmt::Debug for DebugSourceVisible<T, S, D>
1394where
1395    T: std::fmt::Debug,
1396    S: std::fmt::Debug + TensorRef<T, D>,
1397{
1398    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1399        f.debug_list()
1400            .entries(TensorReferenceIterator::from(&self.source))
1401            .finish()
1402    }
1403}
1404
#[test]
fn test_debug() {
    let tensor = Tensor::from([("rows", 3), ("columns", 4)], (0..12).collect());
    let view = TensorView::from(&tensor);
    let output = format!("{:?}\n{:?}", tensor, view);
    let expected = r#"Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] }
TensorView { visible: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], source: Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] } }"#;
    assert_eq!(output, expected)
}
1416
/// Pretty Debug output of a ranged view shows only the visible (clipped) values and the
/// clipped shape, followed by the full wrapped TensorRange and its source.
#[test]
fn test_debug_clipped() {
    let x = Tensor::from([("rows", 2), ("columns", 3)], (0..6).collect());
    let view = TensorView::from(&x)
        .range_owned([("columns", IndexRange::new(1, 2))])
        .unwrap();
    let debugged = format!("{:#?}\n{:#?}", x, view);
    assert_eq!(
        debugged,
        r#"Tensor {
    data: [
        0,
        1,
        2,
        3,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            3,
        ),
    ],
    strides: [
        3,
        1,
    ],
}
TensorView {
    visible: [
        1,
        2,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            2,
        ),
    ],
    source: TensorRange {
        source: Tensor {
            data: [
                0,
                1,
                2,
                3,
                4,
                5,
            ],
            shape: [
                (
                    "rows",
                    2,
                ),
                (
                    "columns",
                    3,
                ),
            ],
            strides: [
                3,
                1,
            ],
        },
        range: [
            IndexRange {
                start: 0,
                length: 2,
            },
            IndexRange {
                start: 1,
                length: 2,
            },
        ],
        _type: PhantomData<i32>,
    },
}"#
    )
}
1508
1509/**
1510 * Any tensor view of a Displayable type implements Display
1511 *
1512 * You can control the precision of the formatting using format arguments, i.e.
1513 * `format!("{:.3}", tensor)`
1514 */
1515impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorView<T, S, D>
1516where
1517    T: std::fmt::Display,
1518    S: TensorRef<T, D>,
1519{
1520    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1521        crate::tensors::display::format_view(&self.source, f)
1522    }
1523}
1524
1525/**
1526 * A Tensor can be converted to a TensorView of that Tensor.
1527 */
1528impl<T, const D: usize> From<Tensor<T, D>> for TensorView<T, Tensor<T, D>, D> {
1529    fn from(tensor: Tensor<T, D>) -> TensorView<T, Tensor<T, D>, D> {
1530        TensorView::from(tensor)
1531    }
1532}
1533
1534/**
1535 * A reference to a Tensor can be converted to a TensorView of that referenced Tensor.
1536 */
1537impl<'a, T, const D: usize> From<&'a Tensor<T, D>> for TensorView<T, &'a Tensor<T, D>, D> {
1538    fn from(tensor: &Tensor<T, D>) -> TensorView<T, &Tensor<T, D>, D> {
1539        TensorView::from(tensor)
1540    }
1541}
1542
1543/**
1544 * A mutable reference to a Tensor can be converted to a TensorView of that mutably referenced
1545 * Tensor.
1546 */
1547impl<'a, T, const D: usize> From<&'a mut Tensor<T, D>> for TensorView<T, &'a mut Tensor<T, D>, D> {
1548    fn from(tensor: &mut Tensor<T, D>) -> TensorView<T, &mut Tensor<T, D>, D> {
1549        TensorView::from(tensor)
1550    }
1551}
1552
1553/**
1554 * A reference to a TensorView can be converted to an owned TensorView with a reference to the
1555 * source type of that first TensorView.
1556 */
1557impl<'a, T, S, const D: usize> From<&'a TensorView<T, S, D>> for TensorView<T, &'a S, D>
1558where
1559    S: TensorRef<T, D>,
1560{
1561    fn from(tensor_view: &TensorView<T, S, D>) -> TensorView<T, &S, D> {
1562        TensorView::from(tensor_view.source_ref())
1563    }
1564}
1565
1566/**
1567 * A mutable reference to a TensorView can be converted to an owned TensorView with a mutable
1568 * reference to the source type of that first TensorView.
1569 */
1570impl<'a, T, S, const D: usize> From<&'a mut TensorView<T, S, D>> for TensorView<T, &'a mut S, D>
1571where
1572    S: TensorRef<T, D>,
1573{
1574    fn from(tensor_view: &mut TensorView<T, S, D>) -> TensorView<T, &mut S, D> {
1575        TensorView::from(tensor_view.source_ref_mut())
1576    }
1577}