easy_ml/tensors/views.rs

/*!
 * Generic views into a tensor.
 *
 * The concept of a view into a tensor is built from the low level [TensorRef] and
 * [TensorMut] traits which define having read and read/write access to Tensor data
 * respectively, and the high level API implemented on the [TensorView] struct.
 *
 * Since a Tensor is itself a TensorRef, the APIs for the traits are a little verbose to
 * avoid name clashes with methods defined on the Tensor and TensorView types. You should
 * typically use TensorRef and TensorMut implementations via the TensorView struct which provides
 * an API closely resembling Tensor.
 */

use std::marker::PhantomData;

use crate::linear_algebra;
use crate::numeric::{Numeric, NumericRef};
use crate::tensors::dimensions;
use crate::tensors::indexing::{
    TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
    TensorReferenceMutIterator, TensorTranspose,
};
use crate::tensors::{Dimension, Tensor};

mod indexes;
mod map;
mod ranges;
mod renamed;
mod reshape;
mod reverse;
pub mod traits;
mod zip;

pub use indexes::*;
pub(crate) use map::*;
pub use ranges::*;
pub use renamed::*;
pub use reshape::*;
pub use reverse::*;
pub use zip::*;

/**
* A shared/immutable reference to a tensor (or a portion of it) of some type and number of
* dimensions.
*
* # Indexing
*
* A TensorRef has a shape of type `[(Dimension, usize); D]`. This defines the valid indexes along
* each dimension name and length pair from 0 inclusive to the length exclusive. If the shape was
* `[("r", 2), ("c", 2)]` the indexes used would be `[0,0]`, `[0,1]`, `[1,0]` and `[1,1]`.
* Although the dimension name in each pair is used for many high level APIs, for TensorRef the
* order of dimensions is used, and the indexes (`[usize; D]`) these trait methods are called with
* must be in the same order as the shape. In general some `[("a", a), ("b", b), ("c", c)...]`
* shape is indexed from `[0,0,0...]` through to `[a - 1, b - 1, c - 1...]`, regardless of how the
* data is actually laid out in memory.
*
* # Safety
*
* In order to support returning references without bounds checking in a useful way, the
* implementing type is required to uphold several invariants that cannot be checked by
* the compiler.
*
* 1 - Any valid index as described in Indexing will yield a safe reference when calling
* `get_reference_unchecked` and `get_reference_unchecked_mut`.
*
* 2 - The view shape that defines which indexes are valid may not be changed by a shared reference
* to the TensorRef implementation. ie, the tensor may not be resized while a mutable reference is
* held to it, except by that reference.
*
* 3 - All dimension names in the `view_shape` must be unique.
*
* 4 - All dimension lengths in the `view_shape` must be non zero.
*
* 5 - `data_layout` must return values correctly as documented on [`DataLayout`]
*
* Essentially, interior mutability causes problems, since code looping through the range of valid
* indexes in a TensorRef needs to be able to rely on that range of valid indexes not changing.
* This is trivially the case by default since a [Tensor] does not have any form of
* interior mutability, and therefore an iterator holding a shared reference to a Tensor prevents
* that tensor being resized. However, a type *wrongly* implementing TensorRef could introduce
* interior mutability by putting the Tensor in an `Arc<Mutex<>>` which would allow another thread
* to resize a tensor while an iterator was looping through previously valid indexes on a different
* thread. This is the same contract as
* [`NoInteriorMutability`](crate::matrices::views::NoInteriorMutability) used in the matrix APIs.
*
* Note that it is okay to be able to resize any TensorRef implementation if that always requires
* an exclusive reference to the TensorRef/Tensor, since the exclusivity prevents the above
* scenario.
*/
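/**
* # Example
*
* A minimal sketch of calling the trait methods through a [Tensor], which implements TensorRef;
* the dimension names and data here are arbitrary and fully qualified syntax is used to avoid
* any ambiguity with inherent methods:
*
* ```
* use easy_ml::tensors::Tensor;
* use easy_ml::tensors::views::TensorRef;
*
* let tensor = Tensor::from([("r", 2), ("c", 2)], vec![1, 2, 3, 4]);
* assert_eq!(TensorRef::view_shape(&tensor), [("r", 2), ("c", 2)]);
* assert_eq!(TensorRef::get_reference(&tensor, [1, 0]), Some(&3));
* assert_eq!(TensorRef::get_reference(&tensor, [2, 0]), None);
* ```
*/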
pub unsafe trait TensorRef<T, const D: usize> {
    /**
     * Gets a reference to the value at the index if the index is in range. Otherwise returns None.
     */
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T>;

    /**
     * The shape this tensor has. See [dimensions] for an overview.
     * The product of the lengths in the pairs defines how many elements are in the tensor
     * (or the portion of it that is visible).
     */
    fn view_shape(&self) -> [(Dimension, usize); D];

    /**
     * Gets a reference to the value at the index without doing any bounds checking. For a safe
     * alternative see [get_reference](TensorRef::get_reference).
     *
     * # Safety
     *
     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
     * resulting reference is not used. Valid indexes are defined as in [TensorRef].
     *
     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
     * [TensorRef]: TensorRef
     */
    #[allow(clippy::missing_safety_doc)] // it's not missing
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T;

    /**
     * The way the data in this tensor is laid out in memory. In particular,
     * [`Linear`](DataLayout) has several requirements on what is returned that must be upheld
     * by implementations of this trait.
     *
     * For a [Tensor] this would return `DataLayout::Linear` in the same order as the `view_shape`,
     * since the data is in a single line and the `view_shape` is in most significant dimension
     * to least. Many views however, create shapes that do not correspond to linear memory, either
     * by combining non array data with tensors, or hiding dimensions. Similarly, a row major
     * Tensor might be transposed to a column major TensorView, so the view shape could be reversed
     * compared to the order of significance of the dimensions in memory.
     *
     * In general, an array of dimension names matching the view shape order is big endian, and
     * will be iterated through efficiently by [TensorIterator], and an array of dimension names
     * in reverse of the view shape is in little endian order.
     *
     * The implementation of this trait must ensure that if it returns `DataLayout::Linear`
     * the set of dimension names are returned in the order of most significant to least and match
     * the set of the dimension names returned by `view_shape`.
     */
    fn data_layout(&self) -> DataLayout<D>;
}

/**
 * How the data in the tensor is laid out in memory.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DataLayout<const D: usize> {
    /**
     * The data is laid out in linear storage in memory, such that we could take a slice over the
     * entire data specified by our `view_shape`.
     *
     * The `D` length array specifies the dimensions in the `view_shape` in the order of most
     * significant dimension (in memory) to least.
     *
     * In general, an array of dimension names in the same order as the view shape is big endian
     * order (implying the order of dimensions in the view shape is most significant to least),
     * and will be iterated through efficiently by [TensorIterator], and an array of
     * dimension names in reverse of the view shape is in little endian order
     * (implying the order of dimensions in the view shape is least significant to most).
     *
     * In memory, the data will have some order such that if we repeatedly take 1 step
     * through memory from the first value to the last, there will be a most significant dimension
     * that always goes up, through to a least significant dimension with the most rapidly varying
     * index.
     *
     * In most of Easy ML's Tensors, the `view_shape` dimensions would already be in the order of
     * most significant dimension to least (since [Tensor] stores its data in big endian order),
     * so the list of dimension names will just match the order of the view shape.
     *
     * For example, a tensor with a view shape of `[("batch", 2), ("row", 2), ("column", 3)]` that
     * stores its data in most significant to least would be (indexed) like:
     * ```ignore
     * [
     *   (0,0,0), (0,0,1), (0,0,2),
     *   (0,1,0), (0,1,1), (0,1,2),
     *   (1,0,0), (1,0,1), (1,0,2),
     *   (1,1,0), (1,1,1), (1,1,2)
     * ]
     * ```
     *
     * To take one step in memory, we would increment the right most dimension index ("column"),
     * counting our way up through to the left most dimension index ("batch"). If we changed
     * this tensor to `[("column", 3), ("row", 2), ("batch", 2)]` so that the `view_shape` was
     * swapped to least significant dimension to most but the data remained in the same order,
     * our tensor would still have a DataLayout of `Linear(["batch", "row", "column"])`, since
     * the indexes in the transposed `view_shape` correspond to an actual memory layout that's
     * completely reversed. Alternatively, you could say we reversed the view shape but the
     * memory layout never changed:
     * ```ignore
     * [
     *   (0,0,0), (1,0,0), (2,0,0),
     *   (0,1,0), (1,1,0), (2,1,0),
     *   (0,0,1), (1,0,1), (2,0,1),
     *   (0,1,1), (1,1,1), (2,1,1)
     * ]
     * ```
     *
     * To take one step in memory, we now need to increment the left most dimension index (on
     * the view shape) ("column"), counting our way in reverse to the right most dimension
     * index ("batch").
     *
     * That `["batch", "row", "column"]` is also exactly the order you would need to swap your
     * dimensions on the `view_shape` to get back to most significant to least. A [TensorAccess]
     * could reorder the tensor by this array order to get back to most significant to least
     * ordering on the `view_shape` in order to iterate through the data efficiently.
     */
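    /**
     * A minimal sketch, relying on the documented behaviour of [Tensor] above; the dimension
     * names are arbitrary:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{DataLayout, TensorRef};
     *
     * let tensor = Tensor::from([("batch", 2), ("row", 2), ("column", 3)], vec![0; 12]);
     * assert_eq!(
     *     TensorRef::data_layout(&tensor),
     *     DataLayout::Linear(["batch", "row", "column"])
     * );
     * ```
     */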
    Linear([Dimension; D]),
    /**
     * The data is not laid out in linear storage in memory.
     */
    NonLinear,
    /**
     * The data is not laid out in a linear or non linear way, or we don't know how it's laid
     * out.
     */
    Other,
}

/**
 * A unique/mutable reference to a tensor (or a portion of it) of some type.
 *
 * # Safety
 *
 * See [TensorRef].
 */
pub unsafe trait TensorMut<T, const D: usize>: TensorRef<T, D> {
    /**
     * Gets a mutable reference to the value at the index, if the index is in range. Otherwise
     * returns None.
     */
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>;

    /**
     * Gets a mutable reference to the value at the index without doing any bounds checking.
     * For a safe alternative see [get_reference_mut](TensorMut::get_reference_mut).
     *
     * # Safety
     *
     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
     * resulting reference is not used. Valid indexes are defined as in [TensorRef].
     *
     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
     * [TensorRef]: TensorRef
     */
    #[allow(clippy::missing_safety_doc)] // it's not missing
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T;
}

/**
 * A view into some or all of a tensor.
 *
 * A TensorView has a similar relationship to a [`Tensor`] as a `&str` has to a `String`, or an
 * array slice to an array. A TensorView cannot resize its source, and may span only a portion
 * of the source Tensor in each dimension.
 *
 * However a TensorView is generic not only over the type of the data in the Tensor,
 * but also over the way the Tensor is 'sliced' and the two are orthogonal to each other.
 *
 * TensorView closely mirrors the API of Tensor.
 * Methods that create a new tensor do not return a TensorView, they return a Tensor.
 */
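/**
 * # Example
 *
 * A minimal sketch of wrapping a [Tensor] in a TensorView and reading from it; the dimension
 * names and data are arbitrary:
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::TensorView;
 *
 * let tensor = Tensor::from([("a", 2), ("b", 3)], vec![1, 2, 3, 4, 5, 6]);
 * let view = TensorView::from(&tensor);
 * assert_eq!(view.shape(), [("a", 2), ("b", 3)]);
 * assert_eq!(view.iter().sum::<i32>(), 21);
 * ```
 */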
#[derive(Clone)]
pub struct TensorView<T, S, const D: usize> {
    source: S,
    _type: PhantomData<T>,
}

/**
 * TensorView methods which require only read access via a [TensorRef] source.
 */
impl<T, S, const D: usize> TensorView<T, S, D>
where
    S: TensorRef<T, D>,
{
    /**
     * Creates a TensorView from a source of some type.
     *
     * The lifetime of the source determines the lifetime of the TensorView created. If the
     * TensorView is created from a reference to a Tensor, then the TensorView cannot live
     * longer than the Tensor referenced.
     */
    pub fn from(source: S) -> TensorView<T, S, D> {
        TensorView {
            source,
            _type: PhantomData,
        }
    }

    /**
     * Consumes the tensor view, yielding the source it was created from.
     */
    pub fn source(self) -> S {
        self.source
    }

    /**
     * Gives a reference to the tensor view's source.
     */
    pub fn source_ref(&self) -> &S {
        &self.source
    }

    /**
     * Gives a mutable reference to the tensor view's source.
     */
    pub fn source_ref_mut(&mut self) -> &mut S {
        &mut self.source
    }

    /**
     * The shape of this tensor view. Since Tensors are named Tensors, their shape is not just a
     * list of lengths along each dimension, but instead a list of pairs of names and lengths.
     *
     * Note that a TensorView may have a shape which is different than the Tensor it is providing
     * access to the data of. The TensorView might be [masking dimensions](TensorIndex) or
     * elements from the shape, or exposing [false ones](TensorExpansion).
     *
     * See also
     * - [dimensions]
     * - [indexing](crate::tensors::indexing)
     */
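    /**
     * For example (a minimal sketch; the dimension names are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![0; 6]));
     * assert_eq!(view.shape(), [("a", 2), ("b", 3)]);
     * assert_eq!(view.length_of("b"), Some(3));
     * assert_eq!(view.length_of("c"), None);
     * ```
     */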
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.source.view_shape()
    }

    /**
     * Returns the length of the dimension name provided, if one is present in the tensor view.
     *
     * See also
     * - [dimensions]
     * - [indexing](crate::tensors::indexing)
     */
    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::length_of(&self.source.view_shape(), dimension)
    }

    /**
     * Returns the last index of the dimension name provided, if one is present in the tensor view.
     *
     * This is always 1 less than the length, the 'index' in this sense is based on what the
     * Tensor's shape is, not any implementation index.
     *
     * See also
     * - [dimensions]
     * - [indexing](crate::tensors::indexing)
     */
    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::last_index_of(&self.source.view_shape(), dimension)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read values from this tensor view.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
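    /**
     * A minimal sketch of indexing in a chosen dimension order (the data and names are
     * arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![1, 2, 3, 4, 5, 6]));
     * let access = view.index_by(["b", "a"]);
     * assert_eq!(access.shape(), [("b", 3), ("a", 2)]);
     * assert_eq!(access.iter().collect::<Vec<_>>(), vec![1, 4, 2, 5, 3, 6]);
     * ```
     */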
    #[track_caller]
    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &S, D> {
        TensorAccess::from(&self.source, dimensions)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor view.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_mut(&mut self, dimensions: [Dimension; D]) -> TensorAccess<T, &mut S, D> {
        TensorAccess::from(&mut self.source, dimensions)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor view.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
        TensorAccess::from(self.source, dimensions)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions of the source this TensorView
     * was created with in the same order as they were declared.
     * See [TensorAccess::from_source_order].
     */
    pub fn index(&self) -> TensorAccess<T, &S, D> {
        TensorAccess::from_source_order(&self.source)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions of the source this TensorView
     * was created with in the same order as they were declared. The TensorAccess mutably borrows
     * the source, and can therefore mutate it if it implements TensorMut.
     * See [TensorAccess::from_source_order].
     */
    pub fn index_mut(&mut self) -> TensorAccess<T, &mut S, D> {
        TensorAccess::from_source_order(&mut self.source)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions of the source this TensorView
     * was created with in the same order as they were declared. The TensorAccess takes ownership
     * of the source, and can therefore mutate it if it implements TensorMut.
     * See [TensorAccess::from_source_order].
     */
    pub fn index_owned(self) -> TensorAccess<T, S, D> {
        TensorAccess::from_source_order(self.source)
    }

    /**
     * Returns an iterator over references to the data in this TensorView.
     */
    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, S, D> {
        TensorReferenceIterator::from(&self.source)
    }

    /**
     * Returns a TensorView with the dimension names of the shape renamed to the provided
     * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::rename_view`](Tensor::rename_view).
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
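    /**
     * For example (a minimal sketch; the names are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![0; 6]));
     * let renamed = view.rename_view(["x", "y"]);
     * assert_eq!(renamed.shape(), [("x", 2), ("y", 3)]);
     * ```
     */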
    #[track_caller]
    pub fn rename_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorRename<T, &S, D>, D> {
        TensorView::from(TensorRename::from(&self.source, dimensions))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::reshape_view`](Tensor::reshape_view).
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     */
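    /**
     * A minimal sketch (the names are arbitrary); the data order is unchanged, only the shape
     * differs:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![1, 2, 3, 4, 5, 6]));
     * let reshaped = view.reshape([("r", 3), ("c", 2)]);
     * assert_eq!(reshaped.shape(), [("r", 3), ("c", 2)]);
     * assert_eq!(reshaped.iter().collect::<Vec<_>>(), vec![1, 2, 3, 4, 5, 6]);
     * ```
     */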
    pub fn reshape<const D2: usize>(
        &self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &S, D, D2>, D2> {
        TensorView::from(TensorReshape::from(&self.source, shape))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions. The
     * TensorReshape mutably borrows the source, and can therefore mutate it if it implements
     * TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::reshape_view`](Tensor::reshape_view).
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     */
    pub fn reshape_mut<const D2: usize>(
        &mut self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &mut S, D, D2>, D2> {
        TensorView::from(TensorReshape::from(&mut self.source, shape))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions. The
     * TensorReshape takes ownership of the source, and can therefore mutate it if it implements
     * TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::reshape_view`](Tensor::reshape_view).
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     */
    pub fn reshape_owned<const D2: usize>(
        self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, S, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self.source, shape))
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     */
    pub fn flatten(&self, dimension: Dimension) -> TensorView<T, TensorReshape<T, &S, D, 1>, 1> {
        self.reshape([(dimension, dimensions::elements(&self.shape()))])
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     */
    pub fn flatten_mut(
        &mut self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &mut S, D, 1>, 1> {
        self.reshape_mut([(dimension, dimensions::elements(&self.shape()))])
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     *
     * If you intend to query the tensor a lot after creating the view, consider
     * using [flatten_into_tensor](TensorView::flatten_into_tensor) instead as it will have
     * less overhead to index after creation.
     */
    pub fn flatten_owned(
        self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, S, D, 1>, 1> {
        let length = dimensions::elements(&self.shape());
        self.reshape_owned([(dimension, length)])
    }

    /**
     * Given the dimension name, returns a new tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor and all elements
     * copied into the new tensor.
     */
    pub fn flatten_into_tensor(self, dimension: Dimension) -> Tensor<T, 1>
    where
        T: Clone,
    {
        // TODO: Want a specialisation here to use Tensor::reshape_owned when we
        // know that our source type is Tensor
        let length = dimensions::elements(&self.shape());
        Tensor::from([(dimension, length)], self.iter().collect())
    }

    /**
     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
     * range from view. Error cases are documented on [TensorRange].
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::range`](Tensor::range).
     */
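    /**
     * A minimal sketch, assuming the standard `Range<usize>` into [IndexRange] conversion
     * (the names and data are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![1, 2, 3, 4, 5, 6]));
     * let ranged = view.range([("b", 0..2)]).unwrap();
     * assert_eq!(ranged.shape(), [("a", 2), ("b", 2)]);
     * assert_eq!(ranged.iter().collect::<Vec<_>>(), vec![1, 2, 4, 5]);
     * ```
     */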
    pub fn range<R, const P: usize>(
        &self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, &S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(&self.source, ranges).map(|range| TensorView::from(range))
    }

    /**
     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
     * range from view. Error cases are documented on [TensorRange]. The TensorRange
     * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::range`](Tensor::range).
     */
    pub fn range_mut<R, const P: usize>(
        &mut self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(&mut self.source, ranges).map(|range| TensorView::from(range))
    }

    /**
     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
     * range from view. Error cases are documented on [TensorRange]. The TensorRange
     * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::range`](Tensor::range).
     */
    pub fn range_owned<R, const P: usize>(
        self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self.source, ranges).map(|range| TensorView::from(range))
    }

    /**
     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
     * range from view. Error cases are documented on [TensorMask].
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::mask`](Tensor::mask).
     */
    pub fn mask<R, const P: usize>(
        &self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, &S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(&self.source, masks).map(|mask| TensorView::from(mask))
    }

    /**
     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
     * range from view. Error cases are documented on [TensorMask]. The TensorMask
     * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::mask`](Tensor::mask).
     */
    pub fn mask_mut<R, const P: usize>(
        &mut self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(&mut self.source, masks).map(|mask| TensorView::from(mask))
    }

    /**
     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
     * range from view. Error cases are documented on [TensorMask]. The TensorMask
     * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::mask`](Tensor::mask).
     */
    pub fn mask_owned<R, const P: usize>(
        self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, S, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self.source, masks).map(|mask| TensorView::from(mask))
    }

    /**
     * Returns a TensorView with a mask taken in the provided dimension, hiding
     * all but the start_and_end number of values at the start and end of the
     * dimension from view.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::start_and_end_of`](Tensor::start_and_end_of).
     *
     * # Panics
     *
     * - If the start_and_end value is 0 - this is not a valid mask as it would
     * hide all elements
     * - If the dimension is not in the tensor's shape.
     */
    #[track_caller]
    pub fn start_and_end_of(
        &self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, &S, D>, D> {
        TensorMask::panicking_start_and_end_of(&self.source, dimension, start_and_end)
    }

    /**
     * Returns a TensorView with a mask taken in the provided dimension, hiding
     * all but the start_and_end number of values at the start and end of the
     * dimension from view. The TensorMask mutably borrows the source, and can
     * therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::start_and_end_of_mut`](Tensor::start_and_end_of_mut).
     *
     * # Panics
     *
     * - If the start_and_end value is 0 - this is not a valid mask as it would
     * hide all elements
     * - If the dimension is not in the tensor's shape.
     */
    #[track_caller]
    pub fn start_and_end_of_mut(
        &mut self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, &mut S, D>, D> {
        TensorMask::panicking_start_and_end_of(&mut self.source, dimension, start_and_end)
    }

    /**
     * Returns a TensorView with a mask taken in the provided dimension, hiding
     * all but the start_and_end number of values at the start and end of the
     * dimension from view. The TensorMask takes ownership of the source, and
     * can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::start_and_end_of_owned`](Tensor::start_and_end_of_owned).
     *
     * # Panics
     *
     * - If the start_and_end value is 0 - this is not a valid mask as it would
     * hide all elements
     * - If the dimension is not in the tensor's shape.
     */
    #[track_caller]
    pub fn start_and_end_of_owned(
        self,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, S, D>, D> {
        TensorMask::panicking_start_and_end_of(self.source, dimension, start_and_end)
    }

    /**
     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
     * order. The data of this tensor and the dimension lengths remain unchanged.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::reverse`](Tensor::reverse).
     *
     * # Panics
     *
     * - If a dimension name is not in the tensor's shape or is repeated.
     */
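    /**
     * For example (a minimal sketch; the names and data are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![1, 2, 3, 4, 5, 6]));
     * let reversed = view.reverse(&["b"]);
     * assert_eq!(reversed.shape(), [("a", 2), ("b", 3)]);
     * assert_eq!(reversed.iter().collect::<Vec<_>>(), vec![3, 2, 1, 6, 5, 4]);
     * ```
     */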
    #[track_caller]
    pub fn reverse(&self, dimensions: &[Dimension]) -> TensorView<T, TensorReverse<T, &S, D>, D> {
        TensorView::from(TensorReverse::from(&self.source, dimensions))
    }

    /**
     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
     * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
     * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::reverse`](Tensor::reverse).
     *
     * # Panics
     *
     * - If a dimension name is not in the tensor's shape or is repeated.
     */
    #[track_caller]
    pub fn reverse_mut(
        &mut self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, &mut S, D>, D> {
        TensorView::from(TensorReverse::from(&mut self.source, dimensions))
    }

    /**
     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
     * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
     * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
     *
     * This is a shorthand for constructing the TensorView from this TensorView. See
     * [`Tensor::reverse`](Tensor::reverse).
     *
     * # Panics
     *
     * - If a dimension name is not in the tensor's shape or is repeated.
     */
    #[track_caller]
    pub fn reverse_owned(
        self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, S, D>, D> {
        TensorView::from(TensorReverse::from(self.source, dimensions))
    }

    /**
     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
     * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
     * or need to clone the values anyway, you can use
     * [`TensorView::elementwise`](TensorView::elementwise) instead.
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two tensors have different shapes.
     */
    #[track_caller]
    pub fn elementwise_reference<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        self.elementwise_reference_less_generic(rhs.into(), mapping_function)
    }

    /**
     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
     * mapped by a function. The mapping function also receives each index corresponding to the
     * value pairs. The value pairs are not copied for you, if you're using `Copy` types
     * or need to clone the values anyway, you can use
     * [`TensorView::elementwise_with_index`](TensorView::elementwise_with_index) instead.
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two tensors have different shapes.
     */
    #[track_caller]
    pub fn elementwise_reference_with_index<S2, I, M>(
        &self,
        rhs: I,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
    }

    #[track_caller]
    fn elementwise_reference_less_generic<S2, M>(
        &self,
        rhs: TensorView<T, S2, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S2: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        let mapped = self
            .iter_reference()
            .zip(rhs.iter_reference())
            .map(|(x, y)| mapping_function(x, y))
            .collect();
        Tensor::from(left_shape, mapped)
    }

    #[track_caller]
    fn elementwise_reference_less_generic_with_index<S2, M>(
        &self,
        rhs: TensorView<T, S2, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S2: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        // we just checked both shapes were the same, so we don't need to propagate indexes
        // for both tensors because they'll be identical
        let mapped = self
            .iter_reference()
            .with_index()
            .zip(rhs.iter_reference())
            .map(|((i, x), y)| mapping_function(i, x, y))
            .collect();
        Tensor::from(left_shape, mapped)
    }

    /**
     * Returns a TensorView which makes the order of the data in this tensor appear to be in
     * a different order. The order of the dimension names is unchanged, although their lengths
     * may swap.
     *
     * This is a shorthand for constructing the TensorView from this TensorView.
     *
     * See also: [transpose](TensorView::transpose), [TensorTranspose]
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match.
     */
    pub fn transpose_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorTranspose<T, &S, D>, D> {
        TensorView::from(TensorTranspose::from(&self.source, dimensions))
    }

    /// Unverified constructor for internal use when we know the dimensions/data/strides are
    /// the same as the existing instance and don't need reverification
    pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> TensorView<T, Tensor<T, D>, D> {
        let shape = self.shape();
        let strides = crate::tensors::compute_strides(&shape);
        TensorView::from(Tensor {
            data,
            shape,
            strides,
        })
    }
}

impl<T, S, const D: usize> TensorView<T, S, D>
where
    S: TensorMut<T, D>,
{
    /**
     * Returns an iterator over mutable references to the data in this TensorView.
     */
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, S, D> {
        TensorReferenceMutIterator::from(&mut self.source)
    }

    /**
     * Creates an iterator over the values in this TensorView.
     */
    pub fn iter_owned(self) -> TensorOwnedIterator<T, S, D>
    where
        T: Default,
    {
        TensorOwnedIterator::from(self.source())
    }
}

/**
 * TensorView methods which require only read access via a [TensorRef] source
 * and a clonable type.
 */
impl<T, S, const D: usize> TensorView<T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
    /**
     * Gets a copy of the first value in this tensor.
     * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
     * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
     */
    pub fn first(&self) -> T {
        self.iter()
            .next()
            .expect("Tensors always have at least 1 element")
    }

    /**
     * Returns a new Tensor which has the same data as this tensor, but with the order of data
     * changed. The order of the dimension names is unchanged, although their lengths may swap.
     *
     * For example, with a `[("x", x), ("y", y)]` tensor you could call
     * `transpose(["y", "x"])` which would return a new tensor with a shape of
     * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
     *
     * This method need not shift *all* the dimensions though, you could also swap the width
     * and height of images in a tensor with a shape of
     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
     * which would return a new tensor where all the images have been swapped over the diagonal.
     *
     * See also: [TensorAccess], [reorder](TensorView::reorder)
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match (and if the order does match, this function is just an expensive
     * clone).
     */
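    /**
     * For example (a minimal sketch; the names and data are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("x", 2), ("y", 3)], vec![
     *     1, 2, 3,
     *     4, 5, 6
     * ]));
     * assert_eq!(
     *     view.transpose(["y", "x"]),
     *     Tensor::from([("x", 3), ("y", 2)], vec![
     *         1, 4,
     *         2, 5,
     *         3, 6
     *     ])
     * );
     * ```
     */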
    #[track_caller]
    pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let shape = self.shape();
        let mut reordered = self.reorder(dimensions);
        // Transposition is essentially reordering, but we retain the dimension name ordering
        // of the original order, this means we may swap dimension lengths, but the dimensions
        // will not change order.
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            reordered.shape[d].0 = shape[d].0;
        }
        reordered
    }

    /**
     * Returns a new Tensor which has the same data as this tensor, but with the order of the
     * dimensions and corresponding order of data changed.
     *
     * For example, with a `[("x", x), ("y", y)]` tensor you could call
     * `reorder(["y", "x"])` which would return a new tensor with a shape of
     * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
     *
     * This method need not shift *all* the dimensions though, you could also swap the width
     * and height of images in a tensor with a shape of
     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
     * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
     * in the original.
     *
     * See also: [TensorAccess], [transpose](TensorView::transpose)
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match (and if the order does match, this function is just an expensive
     * clone).
     */
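    /**
     * For example (a minimal sketch; the names and data are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("x", 2), ("y", 3)], vec![
     *     1, 2, 3,
     *     4, 5, 6
     * ]));
     * assert_eq!(
     *     view.reorder(["y", "x"]),
     *     Tensor::from([("y", 3), ("x", 2)], vec![
     *         1, 4,
     *         2, 5,
     *         3, 6
     *     ])
     * );
     * ```
     */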
    #[track_caller]
    pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let reordered = match TensorAccess::try_from(&self.source, dimensions) {
            Ok(reordered) => reordered,
            Err(_error) => panic!(
                "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
                dimensions,
                self.shape(),
            ),
        };
        let reordered_shape = reordered.shape();
        Tensor::from(reordered_shape, reordered.iter().collect())
    }

    /**
     * Creates and returns a new tensor with all values from the original with the
     * function applied to each. This can be used to change the type of the tensor
     * such as creating a mask:
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     * let x = TensorView::from(Tensor::from([("a", 2), ("b", 2)], vec![
     *    0.0, 1.2,
     *    5.8, 6.9
     * ]));
     * let y = x.map(|element| element > 2.0);
     * let result = Tensor::from([("a", 2), ("b", 2)], vec![
     *    false, false,
     *    true, true
     * ]);
     * assert_eq!(&y, &result);
     * ```
     */
    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
        let mapped = self.iter().map(mapping_function).collect();
        Tensor::from(self.shape(), mapped)
    }

    /**
     * Creates and returns a new tensor with all values from the original and
     * the index of each value mapped by a function.
     */
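    /**
     * For example (a minimal sketch; the names and the mapping are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let view = TensorView::from(Tensor::from([("r", 2), ("c", 2)], vec![0, 0, 0, 0]));
     * let indexed = view.map_with_index(|[r, c], x| x + (r * 2 + c) as i32);
     * assert_eq!(indexed, Tensor::from([("r", 2), ("c", 2)], vec![0, 1, 2, 3]));
     * ```
     */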
    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
        let mapped = self
            .iter()
            .with_index()
            .map(|(i, x)| mapping_function(i, x))
            .collect();
        Tensor::from(self.shape(), mapped)
    }

    /**
     * Returns an iterator over copies of the data in this TensorView.
     */
    pub fn iter(&self) -> TensorIterator<'_, T, S, D> {
        TensorIterator::from(&self.source)
    }

    /**
     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
     * mapped by a function.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     * let lhs = TensorView::from(Tensor::from([("a", 4)], vec![1, 2, 3, 4]));
     * let rhs = TensorView::from(Tensor::from([("a", 4)], vec![0, 1, 2, 3]));
     * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
     * assert_eq!(
     *     multiplied,
     *     Tensor::from([("a", 4)], vec![0, 2, 6, 12])
     * );
     * ```
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two tensors have different shapes.
     */
    #[track_caller]
    pub fn elementwise<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn(T, T) -> T,
    {
        self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
            mapping_function(lhs.clone(), rhs.clone())
        })
    }

    /**
     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
     * mapped by a function. The mapping function also receives each index corresponding to the
     * value pairs.
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two tensors have different shapes.
     */
    #[track_caller]
    pub fn elementwise_with_index<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S2, D>>,
        S2: TensorRef<T, D>,
        M: Fn([usize; D], T, T) -> T,
    {
        self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
            mapping_function(i, lhs.clone(), rhs.clone())
        })
    }
}

/**
 * TensorView methods which require only read access via a scalar [TensorRef] source
 * and a clonable type.
 */
impl<T, S> TensorView<T, S, 0>
where
    T: Clone,
    S: TensorRef<T, 0>,
{
    /**
     * Returns a copy of the sole element in the 0 dimensional tensor.
     */
    pub fn scalar(&self) -> T {
        self.source.get_reference([]).unwrap().clone()
    }
}

/**
 * TensorView methods which require mutable access via a [TensorMut] source and a [Default].
 */
impl<T, S> TensorView<T, S, 0>
where
    T: Default,
    S: TensorMut<T, 0>,
{
    /**
     * Returns the sole element in the 0 dimensional tensor.
     */
    pub fn into_scalar(self) -> T {
        TensorOwnedIterator::from(self.source).next().unwrap()
    }
}

/**
 * TensorView methods which require mutable access via a [TensorMut] source.
 */
impl<T, S, const D: usize> TensorView<T, S, D>
where
    T: Clone,
    S: TensorMut<T, D>,
{
    /**
     * Applies a function to all values in the tensor view, modifying
     * the tensor in place.
     */
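    /**
     * A minimal sketch over an owned source (the names and data are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let mut view = TensorView::from(Tensor::from([("a", 3)], vec![1, 2, 3]));
     * view.map_mut(|x| x * 10);
     * assert_eq!(view.iter().collect::<Vec<_>>(), vec![10, 20, 30]);
     * ```
     */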
    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
        self.iter_reference_mut()
            .for_each(|x| *x = mapping_function(x.clone()));
    }

    /**
     * Applies a function to all values and each value's index in the tensor view, modifying
     * the tensor view in place.
     */
    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
        self.iter_reference_mut()
            .with_index()
            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
    }
}

impl<T, S> TensorView<T, S, 1>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
    S: TensorRef<T, 1>,
{
    /**
     * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
     * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     * let tensor_view = TensorView::from(Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]));
     * assert_eq!(tensor_view.scalar_product(&tensor_view), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
     * ```
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two vectors are not of equal length or their dimension names do not match.
     */
    // Would like this impl block to be in operations.rs too but then it would show first in the
    // TensorView docs which isn't ideal
    pub fn scalar_product<S2, I>(&self, rhs: I) -> T
    where
        I: Into<TensorView<T, S2, 1>>,
        S2: TensorRef<T, 1>,
    {
        self.scalar_product_less_generic(rhs.into())
    }
}

impl<T, S> TensorView<T, S, 2>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
    S: TensorRef<T, 2>,
{
    /**
     * Returns the determinant of this square matrix, or None if the matrix
     * does not have a determinant. See [`linear_algebra`](crate::linear_algebra::determinant_tensor())
     */
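    /**
     * For example (the names are arbitrary):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorView;
     *
     * let matrix = TensorView::from(Tensor::from([("r", 2), ("c", 2)], vec![
     *     1.0, 2.0,
     *     3.0, 4.0
     * ]));
     * // 1*4 - 2*3 = -2
     * assert_eq!(matrix.determinant(), Some(-2.0));
     * ```
     */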
    pub fn determinant(&self) -> Option<T> {
        linear_algebra::determinant_tensor::<T, _, _>(self)
    }

    /**
     * Computes the inverse of a matrix provided that it exists. To have an inverse a
     * matrix must be square (same number of rows and columns) and it must also have a
     * non zero determinant. See [`linear_algebra`](crate::linear_algebra::inverse_tensor())
     */
    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
        linear_algebra::inverse_tensor::<T, _, _>(self)
    }

    /**
     * Computes the covariance matrix for this feature matrix along the specified feature
     * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
     */
    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
    }
}

macro_rules! tensor_view_select_impl {
    (impl TensorView $d:literal 1) => {
        impl<T, S> TensorView<T, S, $d>
        where
            S: TensorRef<T, $d>,
        {
            /**
             * Selects the provided dimension name and index pairs in this TensorView, returning a
             * TensorView which has fewer dimensions than this TensorView, with the removed dimensions
             * always indexed as the provided values.
             *
             * This is a shorthand for manually constructing the TensorView and
             * [TensorIndex]
             *
             * Note: due to limitations in Rust's const generics support, this method is only
             * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
             * back to manual construction to create `TensorIndex`es with multiple provided
             * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
             */
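            /**
             * For example, with `D` = 2 (a minimal sketch; the names and data are arbitrary):
             *
             * ```
             * use easy_ml::tensors::Tensor;
             * use easy_ml::tensors::views::TensorView;
             *
             * let view = TensorView::from(Tensor::from([("a", 2), ("b", 3)], vec![1, 2, 3, 4, 5, 6]));
             * let selected = view.select([("a", 1)]);
             * assert_eq!(selected.shape(), [("b", 3)]);
             * assert_eq!(selected.iter().collect::<Vec<_>>(), vec![4, 5, 6]);
             * ```
             */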
            #[track_caller]
            pub fn select(
                &self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &S, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(&self.source, provided_indexes))
            }

            /**
             * Selects the provided dimension name and index pairs in this TensorView, returning a
             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
             * always indexed as the provided values. The TensorIndex mutably borrows this
             * Tensor, and can therefore mutate it if it implements TensorMut.
             *
             * See [select](TensorView::select)
             */
            #[track_caller]
            pub fn select_mut(
                &mut self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &mut S, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(&mut self.source, provided_indexes))
            }

            /**
             * Selects the provided dimension name and index pairs in this TensorView, returning a
             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
             * always indexed as the provided values. The TensorIndex takes ownership of this
             * Tensor, and can therefore mutate it if it implements TensorMut.
             *
             * See [select](TensorView::select)
             */
            #[track_caller]
            pub fn select_owned(
                self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, S, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(self.source, provided_indexes))
            }
        }
    };
}

tensor_view_select_impl!(impl TensorView 6 1);
tensor_view_select_impl!(impl TensorView 5 1);
tensor_view_select_impl!(impl TensorView 4 1);
tensor_view_select_impl!(impl TensorView 3 1);
tensor_view_select_impl!(impl TensorView 2 1);
tensor_view_select_impl!(impl TensorView 1 1);

macro_rules! tensor_view_expand_impl {
    (impl Tensor $d:literal 1) => {
        impl<T, S> TensorView<T, S, $d>
        where
            S: TensorRef<T, $d>,
        {
            /**
             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this TensorView.
             *
             * This is a shorthand for manually constructing the TensorView and
             * [TensorExpansion]
             *
             * Note: due to limitations in Rust's const generics support, this method is only
             * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
             * fall back to manual construction to create `TensorExpansion`s with multiple provided
             * indexes if you need to increase dimensionality by more than 1 dimension at a time.
             */
1374            #[track_caller]
1375            pub fn expand(
1376                &self,
1377                extra_dimension_names: [(usize, Dimension); 1],
1378            ) -> TensorView<T, TensorExpansion<T, &S, $d, 1>, { $d + 1 }> {
1379                TensorView::from(TensorExpansion::from(&self.source, extra_dimension_names))
1380            }
1381
            /**
             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this TensorView. The TensorExpansion mutably borrows the source
             * of this TensorView, and can therefore mutate it if it implements TensorMut.
             *
             * See [expand](TensorView::expand)
             */
            #[track_caller]
            pub fn expand_mut(
                &mut self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &mut S, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(
                    &mut self.source,
                    extra_dimension_names,
                ))
            }

            /**
             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this TensorView. The TensorExpansion takes ownership of this
             * TensorView's source, and can therefore mutate it if it implements TensorMut.
             *
             * See [expand](TensorView::expand)
             */
            #[track_caller]
            pub fn expand_owned(
                self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, S, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(self.source, extra_dimension_names))
            }
        }
    };
}

tensor_view_expand_impl!(impl TensorView 0 1);
tensor_view_expand_impl!(impl TensorView 1 1);
tensor_view_expand_impl!(impl TensorView 2 1);
tensor_view_expand_impl!(impl TensorView 3 1);
tensor_view_expand_impl!(impl TensorView 4 1);
tensor_view_expand_impl!(impl TensorView 5 1);
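
// A minimal usage sketch, not part of the original file: `expand` inserts a new length 1
// dimension into the shape at the given position, turning a 1 dimensional view into a 2
// dimensional one. The name "rows" and the position semantics asserted here are assumptions
// based on the documentation above; `TensorView::shape` is assumed from elsewhere in this module.
#[test]
fn expand_usage_sketch() {
    let tensor = Tensor::from([("columns", 3)], vec![1, 2, 3]);
    let view = TensorView::from(&tensor);
    // (0, "rows") inserts a dimension called "rows" of length 1 at position 0 in the shape.
    let expanded = view.expand([(0, "rows")]);
    assert_eq!(expanded.shape(), [("rows", 1), ("columns", 3)]);
}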

/**
 * Debug implementations for TensorView additionally show the visible data and visible dimensions
 * reported by the source as fields. This is in addition to the recursive debug content of the
 * actual source.
 */
impl<T, S, const D: usize> std::fmt::Debug for TensorView<T, S, D>
where
    T: std::fmt::Debug,
    S: std::fmt::Debug + TensorRef<T, D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("TensorView")
            .field("visible", &DebugSourceVisible::from(&self.source))
            .field("shape", &self.source.view_shape())
            .field("source", &self.source)
            .finish()
    }
}

struct DebugSourceVisible<T, S, const D: usize> {
    source: S,
    _type: PhantomData<T>,
}

impl<T, S, const D: usize> DebugSourceVisible<T, S, D>
where
    T: std::fmt::Debug,
    S: std::fmt::Debug + TensorRef<T, D>,
{
    fn from(source: S) -> DebugSourceVisible<T, S, D> {
        DebugSourceVisible {
            source,
            _type: PhantomData,
        }
    }
}

impl<T, S, const D: usize> std::fmt::Debug for DebugSourceVisible<T, S, D>
where
    T: std::fmt::Debug,
    S: std::fmt::Debug + TensorRef<T, D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_list()
            .entries(TensorReferenceIterator::from(&self.source))
            .finish()
    }
}

#[test]
fn test_debug() {
    let x = Tensor::from([("rows", 3), ("columns", 4)], (0..12).collect());
    let view = TensorView::from(&x);
    let debugged = format!("{:?}\n{:?}", x, view);
    assert_eq!(
        debugged,
        r#"Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] }
TensorView { visible: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], source: Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] } }"#
    )
}

#[test]
fn test_debug_clipped() {
    let x = Tensor::from([("rows", 2), ("columns", 3)], (0..6).collect());
    let view = TensorView::from(&x)
        .range_owned([("columns", IndexRange::new(1, 2))])
        .unwrap();
    let debugged = format!("{:#?}\n{:#?}", x, view);
    println!("{:#?}\n{:#?}", x, view);
    assert_eq!(
        debugged,
        r#"Tensor {
    data: [
        0,
        1,
        2,
        3,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            3,
        ),
    ],
    strides: [
        3,
        1,
    ],
}
TensorView {
    visible: [
        1,
        2,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            2,
        ),
    ],
    source: TensorRange {
        source: Tensor {
            data: [
                0,
                1,
                2,
                3,
                4,
                5,
            ],
            shape: [
                (
                    "rows",
                    2,
                ),
                (
                    "columns",
                    3,
                ),
            ],
            strides: [
                3,
                1,
            ],
        },
        range: [
            IndexRange {
                start: 0,
                length: 2,
            },
            IndexRange {
                start: 1,
                length: 2,
            },
        ],
        _type: PhantomData<i32>,
    },
}"#
    )
}

/**
 * Any tensor view of a Displayable type implements Display
 *
 * You can control the precision of the formatting using format arguments, i.e.
 * `format!("{:.3}", tensor)`
 */
impl<T, S, const D: usize> std::fmt::Display for TensorView<T, S, D>
where
    T: std::fmt::Display,
    S: TensorRef<T, D>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        crate::tensors::display::format_view(&self.source, f)
    }
}
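
// Illustrative sketch, not part of the original file: per the documentation above, the precision
// format argument is forwarded to each element, so `{:.3}` should render 1.23456 as 1.235. Only a
// substring is checked since the exact grid layout is decided by crate::tensors::display.
#[test]
fn display_precision_sketch() {
    let tensor = Tensor::from([("rows", 2), ("columns", 2)], vec![1.23456, 2.0, 3.5, 4.0]);
    let view = TensorView::from(&tensor);
    let formatted = format!("{:.3}", view);
    assert!(formatted.contains("1.235"));
}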

/**
 * A Tensor can be converted to a TensorView of that Tensor.
 */
impl<T, const D: usize> From<Tensor<T, D>> for TensorView<T, Tensor<T, D>, D> {
    fn from(tensor: Tensor<T, D>) -> TensorView<T, Tensor<T, D>, D> {
        TensorView::from(tensor)
    }
}

/**
 * A reference to a Tensor can be converted to a TensorView of that referenced Tensor.
 */
impl<'a, T, const D: usize> From<&'a Tensor<T, D>> for TensorView<T, &'a Tensor<T, D>, D> {
    fn from(tensor: &Tensor<T, D>) -> TensorView<T, &Tensor<T, D>, D> {
        TensorView::from(tensor)
    }
}

/**
 * A mutable reference to a Tensor can be converted to a TensorView of that mutably referenced
 * Tensor.
 */
impl<'a, T, const D: usize> From<&'a mut Tensor<T, D>> for TensorView<T, &'a mut Tensor<T, D>, D> {
    fn from(tensor: &mut Tensor<T, D>) -> TensorView<T, &mut Tensor<T, D>, D> {
        TensorView::from(tensor)
    }
}
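
// Illustrative sketch, not part of the original file: the From impls above allow `.into()`
// conversions from a Tensor, or a reference to one, into a TensorView. `TensorView::shape` is
// assumed from elsewhere in this module.
#[test]
fn from_conversions_sketch() {
    let tensor = Tensor::from([("rows", 2), ("columns", 2)], vec![1, 2, 3, 4]);
    // Borrowing conversion: the view's source is &Tensor, so `tensor` remains usable afterwards.
    let by_reference: TensorView<i32, &Tensor<i32, 2>, 2> = (&tensor).into();
    assert_eq!(by_reference.shape(), [("rows", 2), ("columns", 2)]);
    // Owning conversion: the view takes the Tensor itself as its source.
    let owned: TensorView<i32, Tensor<i32, 2>, 2> = tensor.into();
    assert_eq!(owned.shape(), [("rows", 2), ("columns", 2)]);
}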

/**
 * A reference to a TensorView can be converted to an owned TensorView with a reference to the
 * source type of that first TensorView.
 */
impl<'a, T, S, const D: usize> From<&'a TensorView<T, S, D>> for TensorView<T, &'a S, D>
where
    S: TensorRef<T, D>,
{
    fn from(tensor_view: &TensorView<T, S, D>) -> TensorView<T, &S, D> {
        TensorView::from(tensor_view.source_ref())
    }
}

/**
 * A mutable reference to a TensorView can be converted to an owned TensorView with a mutable
 * reference to the source type of that first TensorView.
 */
impl<'a, T, S, const D: usize> From<&'a mut TensorView<T, S, D>> for TensorView<T, &'a mut S, D>
where
    S: TensorRef<T, D>,
{
    fn from(tensor_view: &mut TensorView<T, S, D>) -> TensorView<T, &mut S, D> {
        TensorView::from(tensor_view.source_ref_mut())
    }
}
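
// Illustrative sketch, not part of the original file: converting a reference to a TensorView
// produces a new TensorView whose source is a reference to the original view's source, via
// `source_ref` as in the impl above. `TensorView::shape` is assumed from elsewhere in this module.
#[test]
fn view_reference_conversion_sketch() {
    let tensor = Tensor::from([("x", 2)], vec![10, 20]);
    let view = TensorView::from(tensor);
    // The new view borrows the original view's source (the owned Tensor) by reference.
    let borrowed: TensorView<i32, &Tensor<i32, 1>, 1> = (&view).into();
    assert_eq!(borrowed.shape(), [("x", 2)]);
}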