Skip to main content

easy_ml/tensors/
mod.rs

1/*!
2 * Generic N dimensional [named tensors](http://nlp.seas.harvard.edu/NamedTensor).
3 *
4 * Tensors are generic over some type `T` and some usize `D`. If `T` is [Numeric](super::numeric)
5 * then the tensor can be used in a mathematical way. `D` is the number of dimensions in the tensor
6 * and a compile time constant. Each tensor also carries `D` dimension name and length pairs.
7 */
8use crate::linear_algebra;
9use crate::numeric::extra::{Real, RealRef};
10use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::indexing::{
12    ShapeIterator, TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
13    TensorReferenceMutIterator, TensorTranspose,
14};
15use crate::tensors::views::{
16    DataLayout, IndexRange, IndexRangeValidationError, TensorExpansion, TensorIndex, TensorMask,
17    TensorMut, TensorRange, TensorRef, TensorRename, TensorReshape, TensorReverse, TensorView,
18};
19
20use std::error::Error;
21use std::fmt;
22
23#[cfg(feature = "serde")]
24use serde::Serialize;
25
26pub mod dimensions;
27mod display;
28pub mod einsum;
29pub mod indexing;
30pub mod operations;
31pub mod views;
32
33#[cfg(feature = "serde")]
34pub use serde_impls::{TensorDeserialize, TensorDeserializeOwned};
35
/**
 * Dimension names are represented as static string references.
 *
 * This allows you to use string literals to refer to named dimensions, for example you might want
 * to construct a tensor with a shape of
 * `[("batch", 1000), ("height", 100), ("width", 100), ("rgba", 4)]`.
 *
 * Alternatively you can define the strings once as constants and refer to your dimension
 * names by the constant identifiers.
 *
 * ```
 * const BATCH: &'static str = "batch";
 * const HEIGHT: &'static str = "height";
 * const WIDTH: &'static str = "width";
 * const RGBA: &'static str = "rgba";
 * ```
 *
 * Although `Dimension` is interchangeable with `&'static str` (it is just a type alias), Easy ML
 * uses `Dimension` wherever dimension names are expected, to distinguish those values from
 * ordinary strings.
 */
pub type Dimension = &'static str;
58
59/**
60 * An error indicating failure to do something with a Tensor because the requested shape
61 * is not valid.
62 */
63#[derive(Clone, Debug, Eq, PartialEq)]
64pub struct InvalidShapeError<const D: usize> {
65    shape: [(Dimension, usize); D],
66}
67
68impl<const D: usize> InvalidShapeError<D> {
69    /**
70     * Checks if this shape is valid. This is mainly for internal library use but may also be
71     * useful for unit testing.
72     *
73     * Note: in some functions and methods, an InvalidShapeError may be returned which is a valid
74     * shape, but not the right size for the quantity of data provided.
75     */
76    pub fn is_valid(&self) -> bool {
77        !crate::tensors::dimensions::has_duplicates(&self.shape)
78            && !self.shape.iter().any(|d| d.1 == 0)
79    }
80
81    /**
82     * Constructs an InvalidShapeError for assistance with unit testing. Note that you can
83     * construct an InvalidShapeError that *is* a valid shape in this way.
84     */
85    pub fn new(shape: [(Dimension, usize); D]) -> InvalidShapeError<D> {
86        InvalidShapeError { shape }
87    }
88
89    pub fn shape(&self) -> [(Dimension, usize); D] {
90        self.shape
91    }
92
93    pub fn shape_ref(&self) -> &[(Dimension, usize); D] {
94        &self.shape
95    }
96
97    // Panics if the shape is invalid for any reason with the appropriate error message.
98    #[track_caller]
99    #[inline]
100    fn validate_dimensions_or_panic(shape: &[(Dimension, usize); D], data_len: usize) {
101        let elements = crate::tensors::dimensions::elements(shape);
102        if data_len != elements {
103            panic!(
104                "Product of dimension lengths must match size of data. {} != {}",
105                elements, data_len
106            );
107        }
108        if crate::tensors::dimensions::has_duplicates(shape) {
109            panic!("Dimension names must all be unique: {:?}", &shape);
110        }
111        if shape.iter().any(|d| d.1 == 0) {
112            panic!("No dimension can have 0 elements: {:?}", &shape);
113        }
114    }
115
116    // Returns true if the shape is valid and matches the data length
117    fn validate_dimensions(shape: &[(Dimension, usize); D], data_len: usize) -> bool {
118        let elements = crate::tensors::dimensions::elements(shape);
119        data_len == elements
120            && !crate::tensors::dimensions::has_duplicates(shape)
121            && !shape.iter().any(|d| d.1 == 0)
122    }
123}
124
125impl<const D: usize> fmt::Display for InvalidShapeError<D> {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(
128            f,
129            "Dimensions must all be at least length 1 with unique names: {:?}",
130            self.shape
131        )
132    }
133}
134
135impl<const D: usize> Error for InvalidShapeError<D> {}
136
137/**
138 * An error indicating failure to do something with a Tensor because the dimension names that
139 * were provided did not match with the dimension names that were valid.
140 *
141 * Typically this would be due to the same dimension name being provided multiple times, or a
142 * dimension name being provided that is not present in the shape of the Tensor in use.
143 */
144#[derive(Clone, Debug, Eq, PartialEq)]
145pub struct InvalidDimensionsError<const D: usize, const P: usize> {
146    valid: [Dimension; D],
147    provided: [Dimension; P],
148}
149
150impl<const D: usize, const P: usize> InvalidDimensionsError<D, P> {
151    /**
152     * Checks if the provided dimensions have duplicate names. This is mainly for internal library
153     * use but may also be useful for unit testing.
154     */
155    pub fn has_duplicates(&self) -> bool {
156        crate::tensors::dimensions::has_duplicates_names(&self.provided)
157    }
158
159    // TODO: method to check provided is a subset of valid
160
161    /**
162     * Constructs an InvalidDimensions for assistance with unit testing.
163     */
164    pub fn new(provided: [Dimension; P], valid: [Dimension; D]) -> InvalidDimensionsError<D, P> {
165        InvalidDimensionsError { valid, provided }
166    }
167
168    pub fn provided_names(&self) -> [Dimension; P] {
169        self.provided
170    }
171
172    pub fn provided_names_ref(&self) -> &[Dimension; P] {
173        &self.provided
174    }
175
176    pub fn valid_names(&self) -> [Dimension; D] {
177        self.valid
178    }
179
180    pub fn valid_names_ref(&self) -> &[Dimension; D] {
181        &self.valid
182    }
183}
184
185impl<const D: usize, const P: usize> fmt::Display for InvalidDimensionsError<D, P> {
186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187        if P > 0 {
188            write!(
189                f,
190                "Dimensions names {:?} were incorrect, valid dimensions in this context are: {:?}",
191                self.provided, self.valid
192            )
193        } else {
194            write!(f, "Dimensions names {:?} were incorrect", self.provided)
195        }
196    }
197}
198
199impl<const D: usize, const P: usize> Error for InvalidDimensionsError<D, P> {}
200
// Compile time check that both error types can be shared between threads.
#[test]
fn test_sync() {
    fn requires_sync<E: Sync>() {}
    requires_sync::<InvalidShapeError<2>>();
    requires_sync::<InvalidDimensionsError<2, 2>>();
}
207
// Compile time check that both error types can be moved between threads.
#[test]
fn test_send() {
    fn requires_send<E: Send>() {}
    requires_send::<InvalidShapeError<2>>();
    requires_send::<InvalidDimensionsError<2, 2>>();
}
214
/**
 * A [named tensor](http://nlp.seas.harvard.edu/NamedTensor) of some type `T` and number of
 * dimensions `D`.
 *
 * Tensors are a generalisation of matrices; whereas [Matrix](crate::matrices::Matrix) only
 * supports 2 dimensions, and vectors are represented in Matrix by making either the rows or
 * columns have a length of one, [Tensor] supports an arbitrary number of dimensions,
 * with 0 through 6 having full API support. A `Tensor<T, 2>` is very similar to a `Matrix<T>`
 * except that this type associates each dimension with a name, and favors names to refer to
 * dimensions instead of index order.
 *
 * Like Matrix, the type of the data in this Tensor may implement no traits, in which case the
 * tensor will be rather useless. If the type implements Clone most storage and accessor methods
 * are defined and if the type implements Numeric then the tensor can be used in a mathematical
 * way.
 *
 * Like Matrix, a Tensor must always contain at least one element, and it may not have more
 * elements than `std::isize::MAX`. Concerned readers should note that on a 64 bit computer this
 * maximum value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
 *
 * When doing numeric operations with Tensors you should be careful to not consume a tensor by
 * accidentally using it by value. All the operations are also defined on references to tensors
 * so you should favor &x + &y style notation for tensors you intend to continue using.
 *
 * See also:
 * - [indexing]
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Tensor<T, const D: usize> {
    // The elements, stored flat in what would be row major order for a Matrix.
    data: Vec<T>,
    // The name and length of each dimension, in the order they were provided.
    #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
    shape: [(Dimension, usize); D],
    // The step through `data` for each dimension, as derived from the shape by
    // compute_strides. Skipped when serialising.
    #[cfg_attr(feature = "serde", serde(skip))]
    strides: [usize; D],
}
251
252impl<T, const D: usize> Tensor<T, D> {
253    /**
254     * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
255     *
256     * The product of the dimension lengths corresponds to the number of elements the Tensor
257     * will store. Elements are stored in what would be row major order for a Matrix.
258     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
259     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
260     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
261     * for the 25th and final element.
262     *
263     * # Panics
264     *
265     * - If the number of provided elements does not match the product of the dimension lengths.
266     * - If a dimension name is not unique
267     * - If any dimension has 0 elements
268     *
269     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
270     * a single element (since the product of an empty list is 1).
271     */
272    #[track_caller]
273    pub fn from(shape: [(Dimension, usize); D], data: Vec<T>) -> Self {
274        InvalidShapeError::validate_dimensions_or_panic(&shape, data.len());
275        let strides = compute_strides(&shape);
276        Tensor {
277            data,
278            shape,
279            strides,
280        }
281    }
282
283    /**
284     * Creates a Tensor with a particular shape initialised from a function.
285     *
286     * The product of the dimension lengths corresponds to the number of elements the Tensor
287     * will store. Elements are stored in what would be row major order for a Matrix.
288     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
289     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
290     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
291     * for the 25th and final element. These same indexes will be passed to the producer function
292     * to initialised the values for the Tensor.
293     *
294     * ```
295     * use easy_ml::tensors::Tensor;
296     * let tensor = Tensor::from_fn([("rows", 4), ("columns", 4)], |[r, c]| r * c);
297     * assert_eq!(
298     *     tensor,
299     *     Tensor::from([("rows", 4), ("columns", 4)], vec![
300     *         0, 0, 0, 0,
301     *         0, 1, 2, 3,
302     *         0, 2, 4, 6,
303     *         0, 3, 6, 9,
304     *     ])
305     * );
306     * ```
307     *
308     * # Panics
309     *
310     * - If a dimension name is not unique
311     * - If any dimension has 0 elements
312     *
313     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
314     * a single element (since the product of an empty list is 1).
315     */
316    #[track_caller]
317    pub fn from_fn<F>(shape: [(Dimension, usize); D], mut producer: F) -> Self
318    where
319        F: FnMut([usize; D]) -> T,
320    {
321        let length = dimensions::elements(&shape);
322        let mut data = Vec::with_capacity(length);
323        let iterator = ShapeIterator::from(shape);
324        for index in iterator {
325            data.push(producer(index));
326        }
327        Tensor::from(shape, data)
328    }
329
330    /**
331     * The shape of this tensor. Since Tensors are named Tensors, their shape is not just a
332     * list of lengths along each dimension, but instead a list of pairs of names and lengths.
333     *
334     * See also
335     * - [dimensions]
336     * - [indexing]
337     */
338    pub fn shape(&self) -> [(Dimension, usize); D] {
339        self.shape
340    }
341
342    /**
343     * Returns the length of the dimension name provided, if one is present in the Tensor.
344     *
345     * See also
346     * - [dimensions]
347     * - [indexing]
348     */
349    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
350        dimensions::length_of(&self.shape, dimension)
351    }
352
353    /**
354     * Returns the last index of the dimension name provided, if one is present in the Tensor.
355     *
356     * This is always 1 less than the length, the 'index' in this sense is based on what the
357     * Tensor's shape is, not any implementation index.
358     *
359     * See also
360     * - [dimensions]
361     * - [indexing]
362     */
363    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
364        dimensions::last_index_of(&self.shape, dimension)
365    }
366
367    /**
368     * A non panicking version of [from](Tensor::from) which returns `Result::Err` if the input
369     * is invalid.
370     *
371     * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
372     *
373     * The product of the dimension lengths corresponds to the number of elements the Tensor
374     * will store. Elements are stored in what would be row major order for a Matrix.
375     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
376     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
377     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
378     * for the 25th and final element.
379     *
380     * Returns the Err variant if
381     * - If the number of provided elements does not match the product of the dimension lengths.
382     * - If a dimension name is not unique
383     * - If any dimension has 0 elements
384     *
385     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
386     * a single element (since the product of an empty list is 1).
387     */
388    pub fn try_from(
389        shape: [(Dimension, usize); D],
390        data: Vec<T>,
391    ) -> Result<Self, InvalidShapeError<D>> {
392        let valid = InvalidShapeError::validate_dimensions(&shape, data.len());
393        if !valid {
394            return Err(InvalidShapeError::new(shape));
395        }
396        let strides = compute_strides(&shape);
397        Ok(Tensor {
398            data,
399            shape,
400            strides,
401        })
402    }
403
404    /// Unverified constructor for interal use when we know the dimensions/data/strides are
405    /// unchanged and don't need reverification
406    pub(crate) fn direct_from(
407        data: Vec<T>,
408        shape: [(Dimension, usize); D],
409        strides: [usize; D],
410    ) -> Self {
411        Tensor {
412            data,
413            shape,
414            strides,
415        }
416    }
417
418    /// Unverified constructor for interal use when we know the dimensions/data/strides are
419    /// the same as the existing instance and don't need reverification
420    #[allow(dead_code)] // pretty sure something else will want this in the future
421    pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> Self {
422        Tensor {
423            data,
424            shape: self.shape,
425            strides: self.strides,
426        }
427    }
428}
429
430impl<T> Tensor<T, 0> {
431    /**
432     * Creates a 0 dimensional tensor from some scalar
433     */
434    pub fn from_scalar(value: T) -> Tensor<T, 0> {
435        Tensor {
436            data: vec![value],
437            shape: [],
438            strides: [],
439        }
440    }
441
442    /**
443     * Returns the sole element of the 0 dimensional tensor.
444     */
445    pub fn into_scalar(self) -> T {
446        self.data
447            .into_iter()
448            .next()
449            .expect("Tensors always have at least 1 element")
450    }
451}
452
453impl<T> Tensor<T, 0>
454where
455    T: Clone,
456{
457    /**
458     * Returns a copy of the sole element in the 0 dimensional tensor.
459     */
460    pub fn scalar(&self) -> T {
461        self.data
462            .first()
463            .expect("Tensors always have at least 1 element")
464            .clone()
465    }
466}
467
468impl<T> From<T> for Tensor<T, 0> {
469    fn from(scalar: T) -> Tensor<T, 0> {
470        Tensor::from_scalar(scalar)
471    }
472}
473// TODO: See if we can find a way to write the reverse Tensor<T, 0> -> T conversion using From or Into (doesn't seem like we can?)
474
475// # Safety
476//
477// We promise to never implement interior mutability for Tensor.
478/**
479 * A Tensor implements TensorRef.
480 */
481unsafe impl<T, const D: usize> TensorRef<T, D> for Tensor<T, D> {
482    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
483        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
484        self.data.get(i)
485    }
486
487    fn view_shape(&self) -> [(Dimension, usize); D] {
488        Tensor::shape(self)
489    }
490
491    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
492        unsafe {
493            // The point of get_reference_unchecked is no bounds checking, and therefore
494            // it does not make any sense to just use `unwrap` here. The trait documents that
495            // it's undefind behaviour to call this method with an out of bounds index, so we
496            // can assume the None case will never happen.
497            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
498            self.data.get_unchecked(i)
499        }
500    }
501
502    fn data_layout(&self) -> DataLayout<D> {
503        // We always have our memory in most significant to least
504        DataLayout::Linear(std::array::from_fn(|i| self.shape[i].0))
505    }
506}
507
508// # Safety
509//
510// We promise to never implement interior mutability for Tensor.
511unsafe impl<T, const D: usize> TensorMut<T, D> for Tensor<T, D> {
512    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
513        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
514        self.data.get_mut(i)
515    }
516
517    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
518        unsafe {
519            // The point of get_reference_unchecked_mut is no bounds checking, and therefore
520            // it does not make any sense to just use `unwrap` here. The trait documents that
521            // it's undefind behaviour to call this method with an out of bounds index, so we
522            // can assume the None case will never happen.
523            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
524            self.data.get_unchecked_mut(i)
525        }
526    }
527}
528
529/**
530 * Any tensor of a Cloneable type implements Clone.
531 */
532impl<T: Clone, const D: usize> Clone for Tensor<T, D> {
533    fn clone(&self) -> Self {
534        self.map(|element| element)
535    }
536}
537
538/**
539 * Any tensor of a Displayable type implements Display
540 *
541 * You can control the precision of the formatting using format arguments, i.e.
542 * `format!("{:.3}", tensor)`
543 */
544impl<T: std::fmt::Display, const D: usize> std::fmt::Display for Tensor<T, D> {
545    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
546        crate::tensors::display::format_view(self, f)
547    }
548}
549
550/**
551 * Any 2 dimensional tensor can be converted to a matrix with rows equal to the length of the
552 * first dimension in the tensor, and columns equal to the length of the second.
553 */
554impl<T> From<Tensor<T, 2>> for crate::matrices::Matrix<T> {
555    fn from(tensor: Tensor<T, 2>) -> Self {
556        crate::matrices::Matrix::from_flat_row_major(
557            (tensor.shape[0].1, tensor.shape[1].1),
558            tensor.data,
559        )
560    }
561}
562
563pub(crate) fn compute_strides<const D: usize>(shape: &[(Dimension, usize); D]) -> [usize; D] {
564    std::array::from_fn(|d| shape.iter().skip(d + 1).map(|d| d.1).product())
565}
566
567/// returns the 1 dimensional index to use to get the requested index into some tensor
568#[inline]
569pub(crate) fn get_index_direct<const D: usize>(
570    // indexes to use
571    indexes: &[usize; D],
572    // strides for indexing into the tensor
573    strides: &[usize; D],
574    // shape of the tensor to index into
575    shape: &[(Dimension, usize); D],
576) -> Option<usize> {
577    let mut index = 0;
578    for d in 0..D {
579        let n = indexes[d];
580        if n >= shape[d].1 {
581            return None;
582        }
583        index += n * strides[d];
584    }
585    Some(index)
586}
587
/// Returns the 1 dimensional index to use to get the requested index into some tensor, without
/// checking the indexes are within bounds for the shape. The result is the sum over every
/// dimension of the index multiplied by the stride.
#[inline]
fn get_index_direct_unchecked<const D: usize>(
    // indexes to use
    indexes: &[usize; D],
    // strides for indexing into the tensor
    strides: &[usize; D],
) -> usize {
    indexes
        .iter()
        .zip(strides)
        .map(|(&n, &stride)| n * stride)
        .sum()
}
604
605impl<T, const D: usize> Tensor<T, D> {
    /**
     * Returns a TensorView over a shared reference to this Tensor.
     */
    pub fn view(&self) -> TensorView<T, &Tensor<T, D>, D> {
        TensorView::from(self)
    }
609
    /**
     * Returns a TensorView over a mutable reference to this Tensor.
     */
    pub fn view_mut(&mut self) -> TensorView<T, &mut Tensor<T, D>, D> {
        TensorView::from(self)
    }
613
    /**
     * Consumes this Tensor and returns a TensorView which owns it.
     */
    pub fn view_owned(self) -> TensorView<T, Tensor<T, D>, D> {
        TensorView::from(self)
    }
617
    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read values from this tensor. The TensorAccess borrows this Tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }
630
    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor. The TensorAccess mutably borrows this Tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_mut(
        &mut self,
        dimensions: [Dimension; D],
    ) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }
646
    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor. The TensorAccess takes ownership of this Tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }
659
    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was created with
     * in the same order as they were provided. The TensorAccess borrows the Tensor.
     * See [TensorAccess::from_source_order].
     */
    pub fn index(&self) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }
667
    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was
     * created with in the same order as they were provided. The TensorAccess mutably borrows
     * the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
     */
    pub fn index_mut(&mut self) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        // The mutable borrow of self is handed straight to the TensorAccess.
        TensorAccess::from_source_order(self)
    }
676
    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was
     * created with in the same order as they were provided. The TensorAccess takes ownership
     * of the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
     */
    pub fn index_owned(self) -> TensorAccess<T, Tensor<T, D>, D> {
        // self is moved into the TensorAccess.
        TensorAccess::from_source_order(self)
    }
685
    /**
     * Returns an iterator over references to the data in this Tensor. The iterator borrows
     * the Tensor.
     */
    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, Tensor<T, D>, D> {
        TensorReferenceIterator::from(self)
    }
692
    /**
     * Returns an iterator over mutable references to the data in this Tensor. The iterator
     * mutably borrows the Tensor.
     */
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, Tensor<T, D>, D> {
        TensorReferenceMutIterator::from(self)
    }
699
    /**
     * Creates an iterator over the values in this Tensor, consuming the Tensor.
     *
     * This method is only available when `T` implements Default.
     */
    pub fn iter_owned(self) -> TensorOwnedIterator<T, Tensor<T, D>, D>
    where
        T: Default,
    {
        TensorOwnedIterator::from(self)
    }
709
    // Non public reference iterator that walks the backing Vec directly, in storage order.
    // It is kept crate private since we don't want to expose our implementation details to
    // public API, as then we could never change them.
    pub(crate) fn direct_iter_reference(&self) -> std::slice::Iter<'_, T> {
        self.data.iter()
    }
715
    // Non public mutable reference iterator that walks the backing Vec directly, in storage
    // order. It is kept crate private since we don't want to expose our implementation details
    // to public API, as then we could never change them.
    pub(crate) fn direct_iter_reference_mut(&mut self) -> std::slice::IterMut<'_, T> {
        self.data.iter_mut()
    }
721
722    /**
723     * Renames the dimension names of the tensor without changing the lengths of the dimensions
724     * in the tensor or moving any data around.
725     *
726     * ```
727     * use easy_ml::tensors::Tensor;
728     * let mut tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6]);
729     * tensor.rename(["y", "z"]);
730     * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
731     * ```
732     *
733     * # Panics
734     *
735     * - If a dimension name is not unique
736     */
737    #[track_caller]
738    pub fn rename(&mut self, dimensions: [Dimension; D]) {
739        if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
740            panic!("Dimension names must all be unique: {:?}", &dimensions);
741        }
742        #[allow(clippy::needless_range_loop)]
743        for d in 0..D {
744            self.shape[d].0 = dimensions[d];
745        }
746    }
747
    /**
     * Renames the dimension names of the tensor and returns it without changing the lengths
     * of the dimensions in the tensor or moving any data around.
     *
     * This is the by value counterpart of [rename](Tensor::rename), which it calls.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6])
     *     .rename_owned(["y", "z"]);
     * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename_owned(mut self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        self.rename(dimensions);
        self
    }
768
    /**
     * Returns a TensorView with the dimension names of the shape renamed to the provided
     * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
     *
     * This is a shorthand for constructing the TensorView from a TensorRename which borrows
     * this Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRename};
     * let abc = Tensor::from([("a", 3), ("b", 3), ("c", 3)], (0..27).collect());
     * let xyz = abc.rename_view(["x", "y", "z"]);
     * let also_xyz = TensorView::from(TensorRename::from(&abc, ["x", "y", "z"]));
     * assert_eq!(xyz, also_xyz);
     * assert_eq!(xyz, Tensor::from([("x", 3), ("y", 3), ("z", 3)], (0..27).collect()));
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorRename<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorRename::from(self, dimensions))
    }
796
797    /**
798     * Changes the shape of the tensor without changing the number of dimensions or moving any
799     * data around.
800     *
801     * # Panics
802     *
803     * - If the number of provided elements in the new shape does not match the product of the
804     * dimension lengths in the existing tensor's shape.
805     * - If a dimension name is not unique
806     * - If any dimension has 0 elements
807     *
808     * ```
809     * use easy_ml::tensors::Tensor;
810     * let mut tensor = Tensor::from([("width", 2), ("height", 2)], vec![
811     *     1, 2,
812     *     3, 4
813     * ]);
814     * tensor.reshape_mut([("batch", 1), ("image", 4)]);
815     * assert_eq!(tensor, Tensor::from([("batch", 1), ("image", 4)], vec![ 1, 2, 3, 4 ]));
816     * ```
817     */
818    #[track_caller]
819    pub fn reshape_mut(&mut self, shape: [(Dimension, usize); D]) {
820        InvalidShapeError::validate_dimensions_or_panic(&shape, self.data.len());
821        let strides = compute_strides(&shape);
822        self.shape = shape;
823        self.strides = strides;
824    }
825
    /**
     * Consumes the tensor and changes the shape of the tensor without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * The new Tensor is built with [from](Tensor::from) using the existing data, so the same
     * validation applies.
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * let flattened = tensor.reshape_owned([("image", 4)]);
     * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     *
     * See also [reshape_view_owned](Tensor::reshape_view_owned)
     */
    #[track_caller]
    pub fn reshape_owned<const D2: usize>(self, shape: [(Dimension, usize); D2]) -> Tensor<T, D2> {
        Tensor::from(shape, self.data)
    }
852
853    /**
854     * Returns a TensorView with the dimensions changed to the provided shape without moving any
855     * data around. The new Tensor may also have a different number of dimensions.
856     *
857     * This is a shorthand for constructing the TensorView from this Tensor.
858     *
859     * # Panics
860     *
861     * - If the number of provided elements in the new shape does not match the product of the
862     * dimension lengths in the existing tensor's shape.
863     * - If a dimension name is not unique
864     *
865     * ```
866     * use easy_ml::tensors::Tensor;
867     * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
868     *     1, 2,
869     *     3, 4
870     * ]);
871     * let flattened = tensor.reshape_view([("image", 4)]);
872     * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
873     * ```
874     */
875    pub fn reshape_view<const D2: usize>(
876        &self,
877        shape: [(Dimension, usize); D2],
878    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, D2>, D2> {
879        TensorView::from(TensorReshape::from(self, shape))
880    }
881
882    /**
883     * Returns a TensorView with the dimensions changed to the provided shape without moving any
884     * data around. The new Tensor may also have a different number of dimensions.
885     *
886     * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
887     * mutably borrows this Tensor, and can therefore mutate it
888     *
889     * # Panics
890     *
891     * - If the number of provided elements in the new shape does not match the product of the
892     * dimension lengths in the existing tensor's shape.
893     * - If a dimension name is not unique
894     */
895    pub fn reshape_view_mut<const D2: usize>(
896        &mut self,
897        shape: [(Dimension, usize); D2],
898    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, D2>, D2> {
899        TensorView::from(TensorReshape::from(self, shape))
900    }
901
902    /**
903     * Returns a TensorView with the dimensions changed to the provided shape without moving any
904     * data around. The new Tensor may also have a different number of dimensions.
905     *
906     * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
907     * takes ownership of this Tensor, and can therefore mutate it
908     *
909     * # Panics
910     *
911     * - If the number of provided elements in the new shape does not match the product of the
912     * dimension lengths in the existing tensor's shape.
913     * - If a dimension name is not unique
914     *
915     * See also [reshape_owned](Tensor::reshape_owned)
916     */
917    pub fn reshape_view_owned<const D2: usize>(
918        self,
919        shape: [(Dimension, usize); D2],
920    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, D2>, D2> {
921        TensorView::from(TensorReshape::from(self, shape))
922    }
923
924    /**
925     * Given the dimension name, returns a view of this tensor reshaped to one dimension
926     * with a length equal to the number of elements in this tensor.
927     */
928    pub fn flatten_view(
929        &self,
930        dimension: Dimension,
931    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, 1>, 1> {
932        self.reshape_view([(dimension, dimensions::elements(&self.shape))])
933    }
934
935    /**
936     * Given the dimension name, returns a view of this tensor reshaped to one dimension
937     * with a length equal to the number of elements in this tensor.
938     */
939    pub fn flatten_view_mut(
940        &mut self,
941        dimension: Dimension,
942    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, 1>, 1> {
943        self.reshape_view_mut([(dimension, dimensions::elements(&self.shape))])
944    }
945
946    /**
947     * Given the dimension name, returns a view of this tensor reshaped to one dimension
948     * with a length equal to the number of elements in this tensor.
949     *
950     * If you intend to query the tensor a lot after creating the view, consider
951     * using [flatten](Tensor::flatten) instead as it will have less overhead to
952     * index after creation.
953     */
954    pub fn flatten_view_owned(
955        self,
956        dimension: Dimension,
957    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, 1>, 1> {
958        let length = dimensions::elements(&self.shape);
959        self.reshape_view_owned([(dimension, length)])
960    }
961
962    /**
963     * Given the dimension name, returns a new tensor reshaped to one dimension
964     * with a length equal to the number of elements in this tensor.
965     */
966    pub fn flatten(self, dimension: Dimension) -> Tensor<T, 1> {
967        let length = dimensions::elements(&self.shape);
968        self.reshape_owned([(dimension, length)])
969    }
970
971    /**
972     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
973     * range from view. Error cases are documented on [TensorRange].
974     *
975     * This is a shorthand for constructing the TensorView from this Tensor.
976     *
977     * ```
978     * use easy_ml::tensors::Tensor;
979     * use easy_ml::tensors::views::{TensorView, TensorRange, IndexRange};
980     * # use easy_ml::tensors::views::IndexRangeValidationError;
981     * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
982     * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
983     * let cropped = samples.range([("x", IndexRange::new(1, 5)), ("y", IndexRange::new(1, 5))])?;
984     * let also_cropped = TensorView::from(
985     *     TensorRange::from(&samples, [("x", 1..6), ("y", 1..6)])?
986     * );
987     * assert_eq!(cropped, also_cropped);
988     * assert_eq!(
989     *     cropped.select([("batch", 0)]),
990     *     Tensor::from([("x", 5), ("y", 5)], vec![
991     *          8,  9, 10, 11, 12,
992     *         15, 16, 17, 18, 19,
993     *         22, 23, 24, 25, 26,
994     *         29, 30, 31, 32, 33,
995     *         36, 37, 38, 39, 40
996     *     ])
997     * );
998     * # Ok(())
999     * # }
1000     * ```
1001     */
1002    pub fn range<R, const P: usize>(
1003        &self,
1004        ranges: [(Dimension, R); P],
1005    ) -> Result<TensorView<T, TensorRange<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1006    where
1007        R: Into<IndexRange>,
1008    {
1009        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1010    }
1011
1012    /**
1013     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1014     * range from view. Error cases are documented on [TensorRange]. The TensorRange
1015     * mutably borrows this Tensor, and can therefore mutate it
1016     *
1017     * This is a shorthand for constructing the TensorView from this Tensor.
1018     */
1019    pub fn range_mut<R, const P: usize>(
1020        &mut self,
1021        ranges: [(Dimension, R); P],
1022    ) -> Result<
1023        TensorView<T, TensorRange<T, &mut Tensor<T, D>, D>, D>,
1024        IndexRangeValidationError<D, P>,
1025    >
1026    where
1027        R: Into<IndexRange>,
1028    {
1029        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1030    }
1031
1032    /**
1033     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1034     * range from view. Error cases are documented on [TensorRange]. The TensorRange
1035     * takes ownership of this Tensor, and can therefore mutate it
1036     *
1037     * This is a shorthand for constructing the TensorView from this Tensor.
1038     */
1039    pub fn range_owned<R, const P: usize>(
1040        self,
1041        ranges: [(Dimension, R); P],
1042    ) -> Result<TensorView<T, TensorRange<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1043    where
1044        R: Into<IndexRange>,
1045    {
1046        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1047    }
1048
1049    /**
1050     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1051     * range from view. Error cases are documented on [TensorMask].
1052     *
1053     * This is a shorthand for constructing the TensorView from this Tensor.
1054     *
1055     * ```
1056     * use easy_ml::tensors::Tensor;
1057     * use easy_ml::tensors::views::{TensorView, TensorMask, IndexRange};
1058     * # use easy_ml::tensors::views::IndexRangeValidationError;
1059     * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
1060     * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
1061     * let corners = samples.mask([("x", IndexRange::new(3, 2)), ("y", IndexRange::new(3, 2))])?;
1062     * let also_corners = TensorView::from(
1063     *     TensorMask::from(&samples, [("x", 3..5), ("y", 3..5)])?
1064     * );
1065     * assert_eq!(corners, also_corners);
1066     * assert_eq!(
1067     *     corners.select([("batch", 0)]),
1068     *     Tensor::from([("x", 5), ("y", 5)], vec![
1069     *          0,  1,  2,    5, 6,
1070     *          7,  8,  9,   12, 13,
1071     *         14, 15, 16,   19, 20,
1072     *
1073     *         35, 36, 37,   40, 41,
1074     *         42, 43, 44,   47, 48
1075     *     ])
1076     * );
1077     * # Ok(())
1078     * # }
1079     * ```
1080     */
1081    pub fn mask<R, const P: usize>(
1082        &self,
1083        masks: [(Dimension, R); P],
1084    ) -> Result<TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1085    where
1086        R: Into<IndexRange>,
1087    {
1088        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1089    }
1090
1091    /**
1092     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1093     * range from view. Error cases are documented on [TensorMask]. The TensorMask
1094     * mutably borrows this Tensor, and can therefore mutate it
1095     *
1096     * This is a shorthand for constructing the TensorView from this Tensor.
1097     */
1098    pub fn mask_mut<R, const P: usize>(
1099        &mut self,
1100        masks: [(Dimension, R); P],
1101    ) -> Result<
1102        TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D>,
1103        IndexRangeValidationError<D, P>,
1104    >
1105    where
1106        R: Into<IndexRange>,
1107    {
1108        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1109    }
1110
1111    /**
1112     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1113     * range from view. Error cases are documented on [TensorMask]. The TensorMask
1114     * takes ownership of this Tensor, and can therefore mutate it
1115     *
1116     * This is a shorthand for constructing the TensorView from this Tensor.
1117     */
1118    pub fn mask_owned<R, const P: usize>(
1119        self,
1120        masks: [(Dimension, R); P],
1121    ) -> Result<TensorView<T, TensorMask<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1122    where
1123        R: Into<IndexRange>,
1124    {
1125        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1126    }
1127
1128    /**
1129     * Returns a TensorView with a mask taken in the provided dimension, hiding
1130     * all but the start_and_end number of values at the start and end of the
1131     * dimension from view.
1132     *
1133     * This is a shorthand for constructing the TensorView from this Tensor.
1134     *
1135     * # Panics
1136     *
1137     * - If the start_and_end value is 0 - this is not a valid mask as it would
1138     * hide all elements
1139     * - If the dimension is not in the tensor's shape.
1140     *
1141     * ```
1142     * use easy_ml::tensors::Tensor;
1143     * let samples = Tensor::from([("batch", 5), ("x", 2), ("y", 2)], (0..20).collect());
1144     * let shortlist = samples.start_and_end_of("batch", 1);
1145     * assert_eq!(
1146     *     shortlist,
1147     *     Tensor::from([("batch", 2), ("x", 2), ("y", 2)], vec![
1148     *         0, 1,
1149     *         2, 3,
1150     *
1151     *         16, 17,
1152     *         18, 19,
1153     *     ])
1154     * )
1155     * ```
1156     */
1157    #[track_caller]
1158    pub fn start_and_end_of(
1159        &self,
1160        dimension: Dimension,
1161        start_and_end: usize,
1162    ) -> TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D> {
1163        TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1164    }
1165
1166    /**
1167     * Returns a TensorView with a mask taken in the provided dimension, hiding
1168     * all but the start_and_end number of values at the start and end of the
1169     * dimension from view. The TensorMask mutably borrows this Tensor, and can
1170     * therefore mutate it
1171     *
1172     * This is a shorthand for constructing the TensorView from this Tensor.
1173     *
1174     * # Panics
1175     *
1176     * - If the start_and_end value is 0 - this is not a valid mask as it would
1177     * hide all elements
1178     * - If the dimension is not in the tensor's shape.
1179     */
1180    #[track_caller]
1181    pub fn start_and_end_of_mut(
1182        &mut self,
1183        dimension: Dimension,
1184        start_and_end: usize,
1185    ) -> TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D> {
1186        TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1187    }
1188
1189    /**
1190     * Returns a TensorView with a mask taken in the provided dimension, hiding
1191     * all but the start_and_end number of values at the start and end of the
1192     * dimension from view. The TensorMask takes ownership of this Tensor, and
1193     * can therefore mutate it
1194     *
1195     * This is a shorthand for constructing the TensorView from this Tensor.
1196     *
1197     * # Panics
1198     *
1199     * - If the start_and_end value is 0 - this is not a valid mask as it would
1200     * hide all elements
1201     * - If the dimension is not in the tensor's shape.
1202     */
1203    #[track_caller]
1204    pub fn start_and_end_of_owned(
1205        self,
1206        dimension: Dimension,
1207        start_and_end: usize,
1208    ) -> TensorView<T, TensorMask<T, Tensor<T, D>, D>, D> {
1209        TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1210    }
1211
1212    /**
1213     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1214     * order. The data of this tensor and the dimension lengths remain unchanged.
1215     *
1216     * This is a shorthand for constructing the TensorView from this Tensor.
1217     *
1218     * ```
1219     * use easy_ml::tensors::Tensor;
1220     * use easy_ml::tensors::views::{TensorView, TensorReverse};
1221     * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
1222     * let reversed = ab.reverse(&["a"]);
1223     * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
1224     * assert_eq!(reversed, also_reversed);
1225     * assert_eq!(
1226     *     reversed,
1227     *     Tensor::from(
1228     *         [("a", 2), ("b", 3)],
1229     *         vec![
1230     *             3, 4, 5,
1231     *             0, 1, 2,
1232     *         ]
1233     *     )
1234     * );
1235     * ```
1236     *
1237     * # Panics
1238     *
1239     * - If a dimension name is not in the tensor's shape or is repeated.
1240     */
1241    #[track_caller]
1242    pub fn reverse(
1243        &self,
1244        dimensions: &[Dimension],
1245    ) -> TensorView<T, TensorReverse<T, &Tensor<T, D>, D>, D> {
1246        TensorView::from(TensorReverse::from(self, dimensions))
1247    }
1248
1249    /**
1250     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1251     * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1252     * mutably borrows this Tensor, and can therefore mutate it
1253     *
1254     * This is a shorthand for constructing the TensorView from this Tensor.
1255     *
1256     * # Panics
1257     *
1258     * - If a dimension name is not in the tensor's shape or is repeated.
1259     */
1260    #[track_caller]
1261    pub fn reverse_mut(
1262        &mut self,
1263        dimensions: &[Dimension],
1264    ) -> TensorView<T, TensorReverse<T, &mut Tensor<T, D>, D>, D> {
1265        TensorView::from(TensorReverse::from(self, dimensions))
1266    }
1267
1268    /**
1269     * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1270     * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1271     * takes ownership of this Tensor, and can therefore mutate it
1272     *
1273     * This is a shorthand for constructing the TensorView from this Tensor.
1274     *
1275     * # Panics
1276     *
1277     * - If a dimension name is not in the tensor's shape or is repeated.
1278     */
1279    #[track_caller]
1280    pub fn reverse_owned(
1281        self,
1282        dimensions: &[Dimension],
1283    ) -> TensorView<T, TensorReverse<T, Tensor<T, D>, D>, D> {
1284        TensorView::from(TensorReverse::from(self, dimensions))
1285    }
1286
1287    /**
1288     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1289     * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
1290     * or need to clone the values anyway, you can use
1291     * [`Tensor::elementwise`](Tensor::elementwise) instead.
1292     *
1293     * # Generics
1294     *
1295     * This method can be called with any right hand side that can be converted to a TensorView,
1296     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1297     *
1298     * # Panics
1299     *
1300     * If the two tensors have different shapes.
1301     */
1302    #[track_caller]
1303    pub fn elementwise_reference<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1304    where
1305        I: Into<TensorView<T, S, D>>,
1306        S: TensorRef<T, D>,
1307        M: Fn(&T, &T) -> T,
1308    {
1309        self.elementwise_reference_less_generic(rhs.into(), mapping_function)
1310    }
1311
1312    /**
1313     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1314     * mapped by a function. The mapping function also receives each index corresponding to the
1315     * value pairs. The value pairs are not copied for you, if you're using `Copy` types
1316     * or need to clone the values anyway, you can use
1317     * [`Tensor::elementwise_with_index`](Tensor::elementwise_with_index) instead.
1318     *
1319     * # Generics
1320     *
1321     * This method can be called with any right hand side that can be converted to a TensorView,
1322     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1323     *
1324     * # Panics
1325     *
1326     * If the two tensors have different shapes.
1327     */
1328    #[track_caller]
1329    pub fn elementwise_reference_with_index<S, I, M>(
1330        &self,
1331        rhs: I,
1332        mapping_function: M,
1333    ) -> Tensor<T, D>
1334    where
1335        I: Into<TensorView<T, S, D>>,
1336        S: TensorRef<T, D>,
1337        M: Fn([usize; D], &T, &T) -> T,
1338    {
1339        self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
1340    }
1341
1342    #[track_caller]
1343    fn elementwise_reference_less_generic<S, M>(
1344        &self,
1345        rhs: TensorView<T, S, D>,
1346        mapping_function: M,
1347    ) -> Tensor<T, D>
1348    where
1349        S: TensorRef<T, D>,
1350        M: Fn(&T, &T) -> T,
1351    {
1352        let left_shape = self.shape();
1353        let right_shape = rhs.shape();
1354        if left_shape != right_shape {
1355            panic!(
1356                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1357                left_shape, right_shape
1358            );
1359        }
1360        let mapped = self
1361            .direct_iter_reference()
1362            .zip(rhs.iter_reference())
1363            .map(|(x, y)| mapping_function(x, y))
1364            .collect();
1365        // We're not changing the shape of the Tensor, so don't need to revalidate
1366        Tensor::direct_from(mapped, self.shape, self.strides)
1367    }
1368
1369    #[track_caller]
1370    fn elementwise_reference_less_generic_with_index<S, M>(
1371        &self,
1372        rhs: TensorView<T, S, D>,
1373        mapping_function: M,
1374    ) -> Tensor<T, D>
1375    where
1376        S: TensorRef<T, D>,
1377        M: Fn([usize; D], &T, &T) -> T,
1378    {
1379        let left_shape = self.shape();
1380        let right_shape = rhs.shape();
1381        if left_shape != right_shape {
1382            panic!(
1383                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1384                left_shape, right_shape
1385            );
1386        }
1387        // we just checked both shapes were the same, so we don't need to propagate indexes
1388        // for both tensors because they'll be identical
1389        let mapped = self
1390            .direct_iter_reference()
1391            .zip(rhs.iter_reference().with_index())
1392            .map(|(x, (i, y))| mapping_function(i, x, y))
1393            .collect();
1394        // We're not changing the shape of the Tensor, so don't need to revalidate
1395        Tensor::direct_from(mapped, self.shape, self.strides)
1396    }
1397
1398    /**
1399     * Returns a TensorView which makes the order of the data in this tensor appear to be in
1400     * a different order. The order of the dimension names is unchanged, although their lengths
1401     * may swap.
1402     *
1403     * This is a shorthand for constructing the TensorView from this Tensor.
1404     *
1405     * See also: [transpose](Tensor::transpose), [TensorTranspose]
1406     *
1407     * # Panics
1408     *
1409     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1410     * order need not match.
1411     */
1412    pub fn transpose_view(
1413        &self,
1414        dimensions: [Dimension; D],
1415    ) -> TensorView<T, TensorTranspose<T, &Tensor<T, D>, D>, D> {
1416        TensorView::from(TensorTranspose::from(self, dimensions))
1417    }
1418}
1419
1420impl<T, const D: usize> Tensor<T, D>
1421where
1422    T: Clone,
1423{
1424    /**
1425     * Creates a tensor with a particular number of dimensions and length in each dimension
1426     * with all elements initialised to the provided value.
1427     *
1428     * # Panics
1429     *
1430     * - If a dimension name is not unique
1431     * - If any dimension has 0 elements
1432     */
1433    #[track_caller]
1434    pub fn empty(shape: [(Dimension, usize); D], value: T) -> Self {
1435        let elements = crate::tensors::dimensions::elements(&shape);
1436        Tensor::from(shape, vec![value; elements])
1437    }
1438
1439    /**
1440     * Gets a copy of the first value in this tensor.
1441     * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
1442     * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
1443     */
1444    pub fn first(&self) -> T {
1445        self.data
1446            .first()
1447            .expect("Tensors always have at least 1 element")
1448            .clone()
1449    }
1450
1451    /**
1452     * Returns a new Tensor which has the same data as this tensor, but with the order of data
1453     * changed. The order of the dimension names is unchanged, although their lengths may swap.
1454     *
1455     * For example, with a `[("x", x), ("y", y)]` tensor you could call
1456     * `transpose(["y", "x"])` which would return a new tensor with a shape of
1457     * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
1458     *
1459     * This method need not shift *all* the dimensions though, you could also swap the width
1460     * and height of images in a tensor with a shape of
1461     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
1462     * which would return a new tensor where all the images have been swapped over the diagonal.
1463     *
1464     * See also: [TensorAccess], [reorder](Tensor::reorder)
1465     *
1466     * # Panics
1467     *
1468     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1469     * order need not match (and if the order does match, this function is just an expensive
1470     * clone).
1471     */
1472    #[track_caller]
1473    pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1474        let shape = self.shape;
1475        let mut reordered = self.reorder(dimensions);
1476        // Transposition is essentially reordering, but we retain the dimension name ordering
1477        // of the original order, this means we may swap dimension lengths, but the dimensions
1478        // will not change order.
1479        #[allow(clippy::needless_range_loop)]
1480        for d in 0..D {
1481            reordered.shape[d].0 = shape[d].0;
1482        }
1483        reordered
1484    }
1485
1486    /**
1487     * Modifies this tensor to have the same data as before, but with the order of data changed.
1488     * The order of the dimension names is unchanged, although their lengths may swap.
1489     *
1490     * For example, with a `[("x", x), ("y", y)]` tensor you could call
1491     * `transpose_mut(["y", "x"])` which would edit the tensor, updating its shape to
1492     * `[("x", y), ("y", x)]`, so every (x,y) of its data corresponds to (y,x) before the
1493     * transposition.
1494     *
1495     * The order swapping will try to be in place, but this is currently only supported for
1496     * square tensors with 2 dimensions. Other types of tensors will not be transposed in place.
1497     *
1498     * # Panics
1499     *
1500     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1501     * order need not match (and if the order does match, this function is just an expensive
1502     * clone).
1503     */
1504    #[track_caller]
1505    pub fn transpose_mut(&mut self, dimensions: [Dimension; D]) {
1506        let shape = self.shape;
1507        self.reorder_mut(dimensions);
1508        // Transposition is essentially reordering, but we retain the dimension name ordering
1509        // we had before, this means we may swap dimension lengths, but the dimensions
1510        // will not change order.
1511        #[allow(clippy::needless_range_loop)]
1512        for d in 0..D {
1513            self.shape[d].0 = shape[d].0;
1514        }
1515    }
1516
1517    /**
1518     * Returns a new Tensor which has the same data as this tensor, but with the order of the
1519     * dimensions and corresponding order of data changed.
1520     *
1521     * For example, with a `[("x", x), ("y", y)]` tensor you could call
1522     * `reorder(["y", "x"])` which would return a new tensor with a shape of
1523     * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
1524     *
1525     * This method need not shift *all* the dimensions though, you could also swap the width
1526     * and height of images in a tensor with a shape of
1527     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
1528     * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
1529     * in the original.
1530     *
1531     * See also: [TensorAccess], [transpose](Tensor::transpose)
1532     *
1533     * # Panics
1534     *
1535     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1536     * order need not match (and if the order does match, this function is just an expensive
1537     * clone).
1538     */
1539    #[track_caller]
1540    pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1541        let reorderd = match TensorAccess::try_from(&self, dimensions) {
1542            Ok(reordered) => reordered,
1543            Err(_error) => panic!(
1544                "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1545                dimensions, &self.shape,
1546            ),
1547        };
1548        let reorderd_shape = reorderd.shape();
1549        Tensor::from(reorderd_shape, reorderd.iter().collect())
1550    }
1551
1552    /**
1553     * Modifies this tensor to have the same data as before, but with the order of the
1554     * dimensions and corresponding order of data changed.
1555     *
1556     * For example, with a `[("x", x), ("y", y)]` tensor you could call
1557     * `reorder_mut(["y", "x"])` which would edit the tensor, updating its shape to
1558     * `[("y", y), ("x", x)]`, so every (y,x) of its data corresponds to (x,y) before the
1559     * transposition.
1560     *
1561     * The order swapping will try to be in place, but this is currently only supported for
1562     * square tensors with 2 dimensions. Other types of tensors will not be reordered in place.
1563     *
1564     * # Panics
1565     *
1566     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1567     * order need not match (and if the order does match, this function is just an expensive
1568     * clone).
1569     */
1570    #[track_caller]
1571    pub fn reorder_mut(&mut self, dimensions: [Dimension; D]) {
1572        use crate::tensors::dimensions::DimensionMappings;
1573        if D == 2 && crate::tensors::dimensions::is_square(&self.shape) {
1574            let dimension_mapping = match DimensionMappings::new(&self.shape, &dimensions) {
1575                Some(dimension_mapping) => dimension_mapping,
1576                None => panic!(
1577                    "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1578                    dimensions, &self.shape,
1579                ),
1580            };
1581
1582            let shape = dimension_mapping.map_shape_to_requested(&self.shape);
1583            let shape_iterator = ShapeIterator::from(shape);
1584
1585            for index in shape_iterator {
1586                let i = index[0];
1587                let j = index[1];
1588                if j >= i {
1589                    let mapped_index = dimension_mapping.map_dimensions_to_source(&index);
1590                    // Swap elements from the upper triangle (using index order of the actual tensor's
1591                    // shape)
1592                    let temp = self.get_reference(index).unwrap().clone();
1593                    // tensor[i,j] becomes tensor[mapping(i,j)]
1594                    *self.get_reference_mut(index).unwrap() =
1595                        self.get_reference(mapped_index).unwrap().clone();
1596                    // tensor[mapping(i,j)] becomes tensor[i,j]
1597                    *self.get_reference_mut(mapped_index).unwrap() = temp;
1598                    // If the mapping is a noop we've assigned i,j to i,j
1599                    // If the mapping is i,j -> j,i we've assigned i,j to j,i and j,i to i,j
1600                }
1601            }
1602
1603            // now update our shape and strides to match
1604            self.shape = shape;
1605            self.strides = compute_strides(&shape);
1606        } else {
1607            // fallback to allocating a new reordered tensor
1608            let reordered = self.reorder(dimensions);
1609            self.data = reordered.data;
1610            self.shape = reordered.shape;
1611            self.strides = reordered.strides;
1612        }
1613    }
1614
1615    /**
1616     * Returns an iterator over copies of the data in this Tensor.
1617     */
1618    pub fn iter(&self) -> TensorIterator<'_, T, Tensor<T, D>, D> {
1619        TensorIterator::from(self)
1620    }
1621
1622    /**
1623     * Creates and returns a new tensor with all values from the original with the
1624     * function applied to each. This can be used to change the type of the tensor
1625     * such as creating a mask:
1626     * ```
1627     * use easy_ml::tensors::Tensor;
1628     * let x = Tensor::from([("a", 2), ("b", 2)], vec![
1629     *    0.0, 1.2,
1630     *    5.8, 6.9
1631     * ]);
1632     * let y = x.map(|element| element > 2.0);
1633     * let result = Tensor::from([("a", 2), ("b", 2)], vec![
1634     *    false, false,
1635     *    true, true
1636     * ]);
1637     * assert_eq!(&y, &result);
1638     * ```
1639     */
1640    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
1641        let mapped = self
1642            .data
1643            .iter()
1644            .map(|x| mapping_function(x.clone()))
1645            .collect();
1646        // We're not changing the shape of the Tensor, so don't need to revalidate
1647        Tensor::direct_from(mapped, self.shape, self.strides)
1648    }
1649
1650    /**
1651     * Creates and returns a new tensor with all values from the original and
1652     * the index of each value mapped by a function.
1653     */
1654    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1655        let mapped = self
1656            .iter()
1657            .with_index()
1658            .map(|(i, x)| mapping_function(i, x))
1659            .collect();
1660        // We're not changing the shape of the Tensor, so don't need to revalidate
1661        Tensor::direct_from(mapped, self.shape, self.strides)
1662    }
1663
1664    /**
1665     * Applies a function to all values in the tensor, modifying
1666     * the tensor in place.
1667     */
1668    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1669        for value in self.data.iter_mut() {
1670            *value = mapping_function(value.clone());
1671        }
1672    }
1673
1674    /**
1675     * Applies a function to all values and each value's index in the tensor, modifying
1676     * the tensor in place.
1677     */
1678    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1679        self.iter_reference_mut()
1680            .with_index()
1681            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1682    }
1683
1684    /**
1685     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1686     * mapped by a function.
1687     *
1688     * ```
1689     * use easy_ml::tensors::Tensor;
1690     * let lhs = Tensor::from([("a", 4)], vec![1, 2, 3, 4]);
1691     * let rhs = Tensor::from([("a", 4)], vec![0, 1, 2, 3]);
1692     * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1693     * assert_eq!(
1694     *     multiplied,
1695     *     Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1696     * );
1697     * ```
1698     *
1699     * # Generics
1700     *
1701     * This method can be called with any right hand side that can be converted to a TensorView,
1702     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1703     *
1704     * # Panics
1705     *
1706     * If the two tensors have different shapes.
1707     */
1708    #[track_caller]
1709    pub fn elementwise<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1710    where
1711        I: Into<TensorView<T, S, D>>,
1712        S: TensorRef<T, D>,
1713        M: Fn(T, T) -> T,
1714    {
1715        self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1716            mapping_function(lhs.clone(), rhs.clone())
1717        })
1718    }
1719
1720    /**
1721     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1722     * mapped by a function. The mapping function also receives each index corresponding to the
1723     * value pairs.
1724     *
1725     * # Generics
1726     *
1727     * This method can be called with any right hand side that can be converted to a TensorView,
1728     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1729     *
1730     * # Panics
1731     *
1732     * If the two tensors have different shapes.
1733     */
1734    #[track_caller]
1735    pub fn elementwise_with_index<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1736    where
1737        I: Into<TensorView<T, S, D>>,
1738        S: TensorRef<T, D>,
1739        M: Fn([usize; D], T, T) -> T,
1740    {
1741        self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1742            mapping_function(i, lhs.clone(), rhs.clone())
1743        })
1744    }
1745}
1746
1747impl<T> Tensor<T, 1>
1748where
1749    T: Numeric,
1750    for<'a> &'a T: NumericRef<T>,
1751{
1752    /**
1753     * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1754     * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1755     *
1756     * ```
1757     * use easy_ml::tensors::Tensor;
1758     * let tensor = Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]);
1759     * assert_eq!(tensor.scalar_product(&tensor), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1760     * ```
1761     *
1762     * # Generics
1763     *
1764     * This method can be called with any right hand side that can be converted to a TensorView,
1765     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1766     *
1767     * # Panics
1768     *
1769     * If the two vectors are not of equal length or their dimension names do not match.
1770     */
1771    // Would like this impl block to be in operations.rs too but then it would show first in the
1772    // Tensor docs which isn't ideal
1773    pub fn scalar_product<S, I>(&self, rhs: I) -> T
1774    where
1775        I: Into<TensorView<T, S, 1>>,
1776        S: TensorRef<T, 1>,
1777    {
1778        self.scalar_product_less_generic(rhs.into())
1779    }
1780}
1781
1782impl<T> Tensor<T, 2>
1783where
1784    T: Numeric,
1785    for<'a> &'a T: NumericRef<T>,
1786{
1787    /**
1788     * Returns the determinant of this square matrix, or None if the matrix
1789     * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
1790     */
1791    pub fn determinant(&self) -> Option<T> {
1792        linear_algebra::determinant_tensor::<T, _, _>(self)
1793    }
1794
1795    /**
1796     * Computes the inverse of a matrix provided that it exists. To have an inverse a
1797     * matrix must be square (same number of rows and columns) and it must also have a
1798     * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
1799     */
1800    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
1801        linear_algebra::inverse_tensor::<T, _, _>(self)
1802    }
1803
1804    /**
1805     * Computes the covariance matrix for this feature matrix along the specified feature
1806     * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
1807     */
1808    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
1809        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
1810    }
1811}
1812
1813// FIXME: want this to be callable in the main numeric impl block
1814impl<T> Tensor<T, 2>
1815where
1816    T: Numeric,
1817{
1818    /**
1819     * Creates a diagonal matrix of the provided size with the diagonal elements
1820     * set to the provided value and all other elements in the tensor set to 0.
1821     * A diagonal matrix is always square.
1822     *
1823     * The size is still taken as a shape to facilitate creating a diagonal matrix
1824     * from the dimensionality of an existing one. If the provided value is 1 then
1825     * this will create an identity matrix.
1826     *
1827     * A 3 x 3 identity matrix:
1828     * ```ignore
1829     * [
1830     *   1, 0, 0
1831     *   0, 1, 0
1832     *   0, 0, 1
1833     * ]
1834     * ```
1835     *
1836     * # Panics
1837     *
1838     * - If the shape is not square.
1839     * - If a dimension name is not unique
1840     * - If any dimension has 0 elements
1841     */
1842    #[track_caller]
1843    pub fn diagonal(shape: [(Dimension, usize); 2], value: T) -> Tensor<T, 2> {
1844        if !crate::tensors::dimensions::is_square(&shape) {
1845            panic!("Shape must be square: {:?}", shape);
1846        }
1847        let mut tensor = Tensor::empty(shape, T::zero());
1848        for ([r, c], x) in tensor.iter_reference_mut().with_index() {
1849            if r == c {
1850                *x = value.clone();
1851            }
1852        }
1853        tensor
1854    }
1855}
1856
1857impl<T> Tensor<T, 2> {
1858    /**
1859     * Converts this 2 dimensional Tensor into a Matrix.
1860     *
1861     * This is a wrapper around the `From<Tensor<T, 2>>` implementation.
1862     *
1863     * The Matrix will have the data in the same order, with rows equal to the length of
1864     * the first dimension in the tensor, and columns equal to the length of the second.
1865     */
1866    pub fn into_matrix(self) -> crate::matrices::Matrix<T> {
1867        self.into()
1868    }
1869}
1870
1871/**
1872 * Methods for tensors with numerical real valued types, such as f32 or f64.
1873 *
1874 * This excludes signed and unsigned integers as they do not support decimal
1875 * precision and hence can't be used for operations like square roots.
1876 *
1877 * Third party fixed precision and infinite precision decimal types should
1878 * be able to implement all of the methods for [Real] and then utilise these functions.
1879 */
1880impl<T: Real> Tensor<T, 1>
1881where
1882    for<'a> &'a T: RealRef<T>,
1883{
1884    /**
1885     * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1886     * of this vector, also referred to as the length or magnitude,
1887     * and written as ||x||, or sometimes |x|.
1888     *
1889     * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1890     *
1891     * This is a shorthand for `(x.iter().map(|x| x * x).sum().sqrt()`, ie
1892     * the square root of the dot product of a vector with itself.
1893     *
1894     * The euclidean length can be used to compute a
1895     * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1896     * vector with length of 1. This should not be confused with a unit matrix,
1897     * which is another name for an identity matrix.
1898     *
1899     * ```
1900     * use easy_ml::tensors::Tensor;
1901     * let a = Tensor::from([("data", 3)], vec![ 1.0, 2.0, 3.0 ]);
1902     * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1903     * let unit = a.map(|x| x / length);
1904     * assert_eq!(unit.euclidean_length(), 1.0);
1905     * ```
1906     */
1907    // TODO: Scalar ops for tensors
1908    pub fn euclidean_length(&self) -> T {
1909        self.direct_iter_reference()
1910            .map(|x| x * x)
1911            .sum::<T>()
1912            .sqrt()
1913    }
1914}
1915
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::tensors::{Dimension, InvalidShapeError, Tensor};
    use serde::Deserialize;
    use std::convert::TryFrom;

    /**
     * The deserialized form of a Tensor, which becomes a Tensor once `&'static str`
     * dimension names are provided.
     *
     * The dimension names in this struct are borrowed from the input, so it can only be
     * used with serde libraries that support deserializing to borrowed types. With such a
     * library, deserializing from a static input yields a TensorDeserialize with a static
     * lifetime, which can be converted to a Tensor through the [TryFrom](TryFrom)
     * implementation instead of having to supply dimension names via
     * [into_tensor](TensorDeserialize::into_tensor).
     */
    #[derive(Deserialize, Debug)]
    #[serde(rename = "Tensor")]
    pub struct TensorDeserialize<'a, T, const D: usize> {
        data: Vec<T>,
        #[serde(with = "serde_arrays")]
        #[serde(borrow)]
        shape: [(&'a str, usize); D],
    }

    /**
     * The deserialized form of a Tensor, which becomes a Tensor once `&'static str`
     * dimension names are provided.
     *
     * The dimension names in this struct are owned, so it works with serde libraries that
     * can't deserialize to borrowed types. The tradeoff is that a `String` can't be turned
     * into a `&'static str` without leaking, so there is no TryFrom implementation to
     * convert to a Tensor, and static dimension names must always be supplied via
     * [into_tensor](TensorDeserializeOwned::into_tensor).
     */
    #[derive(Deserialize, Debug)]
    #[serde(rename = "Tensor")]
    pub struct TensorDeserializeOwned<T, const D: usize> {
        data: Vec<T>,
        #[serde(with = "serde_arrays")]
        shape: [(String, usize); D],
    }

    impl<'a, T, const D: usize> TensorDeserialize<'a, T, D> {
        /**
         * Converts this deserialised data into a Tensor. The provided `&'static str`
         * dimension names replace the names that were serialised (which wouldn't
         * necessarily live long enough); the deserialised lengths are kept.
         */
        pub fn into_tensor(
            self,
            dimensions: [Dimension; D],
        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
            let Self {
                data,
                shape: deserialized,
            } = self;
            // Pair each provided name with the length that was deserialised at that position.
            let shape = std::array::from_fn(|d| (dimensions[d], deserialized[d].1));
            // Go through the validating constructor so that invalid serialised data can never
            // become a Tensor and break code relying on the Tensor invariants. Strides are
            // never serialised, which leaves fewer ways to write an invalid representation
            // at the cost of slightly more work here.
            Tensor::try_from(shape, data)
        }
    }

    impl<T, const D: usize> TensorDeserializeOwned<T, D> {
        /**
         * Converts this deserialised data into a Tensor. The provided `&'static str`
         * dimension names replace the names that were serialised (which wouldn't
         * live long enough); the deserialised lengths are kept.
         */
        pub fn into_tensor(
            self,
            dimensions: [Dimension; D],
        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
            let Self {
                data,
                shape: deserialized,
            } = self;
            // Pair each provided name with the length that was deserialised at that position.
            let shape = std::array::from_fn(|d| (dimensions[d], deserialized[d].1));
            // Go through the validating constructor so that invalid serialised data can never
            // become a Tensor and break code relying on the Tensor invariants. Strides are
            // never serialised, which leaves fewer ways to write an invalid representation
            // at the cost of slightly more work here.
            Tensor::try_from(shape, data)
        }
    }

    /**
     * Converts deserialised Tensor data whose dimension names have a static lifetime
     * into a Tensor, keeping the dimension names that were serialised.
     */
    impl<T, const D: usize> TryFrom<TensorDeserialize<'static, T, D>> for Tensor<T, D> {
        type Error = InvalidShapeError<D>;

        fn try_from(value: TensorDeserialize<'static, T, D>) -> Result<Self, Self::Error> {
            let TensorDeserialize { data, shape } = value;
            Tensor::try_from(shape, data)
        }
    }
}
2016
// Compile time check that tensors of several dimensionalities implement Serialize.
#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    fn assert_serialize<T>()
    where
        T: Serialize,
    {
    }
    assert_serialize::<Tensor<f64, 0>>();
    assert_serialize::<Tensor<f64, 1>>();
    assert_serialize::<Tensor<f64, 2>>();
    assert_serialize::<Tensor<f64, 3>>();
}
2026
// Compile time check that TensorDeserialize of several dimensionalities implements Deserialize.
#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    fn assert_deserialize<'de, T>()
    where
        T: Deserialize<'de>,
    {
    }
    assert_deserialize::<TensorDeserialize<f64, 0>>();
    assert_deserialize::<TensorDeserialize<f64, 1>>();
    assert_deserialize::<TensorDeserialize<f64, 2>>();
    assert_deserialize::<TensorDeserialize<f64, 3>>();
}
2037
// Round trips a tensor through TOML: checks the exact serialised text, then parses it back
// with the owned deserialize type and compares the rebuilt tensor to the original.
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop_toml() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            1,  2,  3,  4,
            5,  6,  7,  8,
            9, 10, 11, 12
        ],
    );
    let encoded = toml::to_string(&tensor).expect("tensor should serialise to TOML");
    assert_eq!(
        encoded,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 3], ["columns", 4]]
"#,
    );
    // `expect` is used rather than `assert!(is_ok())` so that a failure reports the
    // underlying error instead of only "assertion failed".
    let parsed: TensorDeserializeOwned<i32, 2> =
        toml::from_str(&encoded).expect("serialised tensor should parse back from TOML");
    let result = parsed
        .into_tensor(["rows", "columns"])
        .expect("deserialised data should form a valid tensor");
    assert_eq!(result, tensor);
}
2063
// Round trips a tensor through JSON: checks the exact serialised text, then parses it back
// with the borrowing deserialize type and compares the rebuilt tensor to the original.
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop_json() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            1,  2,  3,  4,
            5,  6,  7,  8,
            9, 10, 11, 12
        ],
    );
    let encoded = serde_json::ser::to_string(&tensor).expect("tensor should serialise to JSON");
    assert_eq!(
        encoded,
        r#"{"data":[1,2,3,4,5,6,7,8,9,10,11,12],"shape":[["rows",3],["columns",4]]}"#,
    );
    // `expect` is used rather than `assert!(is_ok())` so that a failure reports the
    // underlying error instead of only "assertion failed".
    let parsed: TensorDeserialize<i32, 2> = serde_json::de::from_str(&encoded)
        .expect("serialised tensor should parse back from JSON");
    let result = parsed
        .into_tensor(["rows", "columns"])
        .expect("deserialised data should form a valid tensor");
    assert_eq!(result, tensor);
}
2087
// Checks that well formed input which describes an invalid tensor (12 values for a shape
// of 4 x 4) parses successfully but is rejected when converting to a Tensor.
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    // `expect` is used rather than `assert!(is_ok())` so that a parse failure reports the
    // underlying error instead of only "assertion failed".
    let parsed: TensorDeserializeOwned<i32, 2> = toml::from_str(
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 4], ["columns", 4]]
"#,
    )
    .expect("well formed TOML should parse even though the shape is wrong");
    let result = parsed.into_tensor(["rows", "columns"]);
    assert!(result.is_err());
}
2100
// JSON fixture with static lifetime: a 3 x 4 tensor with values counting down from 12 to 1.
// References in a `const` are always `'static`, so the lifetime is not spelled out.
#[cfg(feature = "serde")]
#[cfg(test)]
const TENSOR_DATA: &str = r#"
{
    "data": [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
    "shape": [["rows", 3], ["columns", 4]]
}
"#;
2109
// Deserializes from a static input so that the borrowed dimension names are `'static`,
// which allows converting to a Tensor through TryFrom without supplying dimension names.
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_static_data() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            12, 11, 10, 9,
            8,   7,  6,  5,
            4,   3,  2, 1,
        ],
    );
    // To test TensorDeserialize we can't use toml because later
    // versions don't support parsing borrowed data, so we use
    // serde_json instead.
    // `expect` is used rather than `assert!(is_ok())` so that a failure reports the
    // underlying error instead of only "assertion failed".
    let parsed: TensorDeserialize<i32, 2> =
        serde_json::de::from_str(TENSOR_DATA).expect("static JSON fixture should parse");
    let result: Tensor<i32, 2> = parsed
        .try_into()
        .expect("deserialised data should form a valid tensor");
    assert_eq!(result, tensor);
}
2131
2132macro_rules! tensor_select_impl {
2133    (impl Tensor $d:literal 1) => {
2134        impl<T> Tensor<T, $d> {
2135            /**
2136             * Selects the provided dimension name and index pairs in this Tensor, returning a
2137             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2138             * always indexed as the provided values.
2139             *
2140             * This is a shorthand for manually constructing the TensorView and
2141             * [TensorIndex]
2142             *
2143             * Note: due to limitations in Rust's const generics support, this method is only
2144             * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
2145             * back to manual construction to create `TensorIndex`es with multiple provided
2146             * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
2147             */
2148            #[track_caller]
2149            pub fn select(
2150                &self,
2151                provided_indexes: [(Dimension, usize); 1],
2152            ) -> TensorView<T, TensorIndex<T, &Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2153                TensorView::from(TensorIndex::from(self, provided_indexes))
2154            }
2155
2156            /**
2157             * Selects the provided dimension name and index pairs in this Tensor, returning a
2158             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2159             * always indexed as the provided values. The TensorIndex mutably borrows this
2160             * Tensor, and can therefore mutate it
2161             *
2162             * See [select](Tensor::select)
2163             */
2164            #[track_caller]
2165            pub fn select_mut(
2166                &mut self,
2167                provided_indexes: [(Dimension, usize); 1],
2168            ) -> TensorView<T, TensorIndex<T, &mut Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2169                TensorView::from(TensorIndex::from(self, provided_indexes))
2170            }
2171
2172            /**
2173             * Selects the provided dimension name and index pairs in this Tensor, returning a
2174             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2175             * always indexed as the provided values. The TensorIndex takes ownership of this
2176             * Tensor, and can therefore mutate it
2177             *
2178             * See [select](Tensor::select)
2179             */
2180            #[track_caller]
2181            pub fn select_owned(
2182                self,
2183                provided_indexes: [(Dimension, usize); 1],
2184            ) -> TensorView<T, TensorIndex<T, Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2185                TensorView::from(TensorIndex::from(self, provided_indexes))
2186            }
2187        }
2188    };
2189}
2190
2191tensor_select_impl!(impl Tensor 6 1);
2192tensor_select_impl!(impl Tensor 5 1);
2193tensor_select_impl!(impl Tensor 4 1);
2194tensor_select_impl!(impl Tensor 3 1);
2195tensor_select_impl!(impl Tensor 2 1);
2196tensor_select_impl!(impl Tensor 1 1);
2197
2198macro_rules! tensor_expand_impl {
2199    (impl Tensor $d:literal 1) => {
2200        impl<T> Tensor<T, $d> {
2201            /**
2202             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2203             * a particular position within the shape, returning a TensorView which has more
2204             * dimensions than this Tensor.
2205             *
2206             * This is a shorthand for manually constructing the TensorView and
2207             * [TensorExpansion]
2208             *
2209             * Note: due to limitations in Rust's const generics support, this method is only
2210             * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
2211             * fall back to manual construction to create `TensorExpansion`s with multiple provided
2212             * indexes if you need to increase dimensionality by more than 1 dimension at a time.
2213             */
2214            #[track_caller]
2215            pub fn expand(
2216                &self,
2217                extra_dimension_names: [(usize, Dimension); 1],
2218            ) -> TensorView<T, TensorExpansion<T, &Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2219                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2220            }
2221
2222            /**
2223             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2224             * a particular position within the shape, returning a TensorView which has more
2225             * dimensions than this Tensor. The TensorIndex mutably borrows this
2226             * Tensor, and can therefore mutate it
2227             *
2228             * See [expand](Tensor::expand)
2229             */
2230            #[track_caller]
2231            pub fn expand_mut(
2232                &mut self,
2233                extra_dimension_names: [(usize, Dimension); 1],
2234            ) -> TensorView<T, TensorExpansion<T, &mut Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2235                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2236            }
2237
2238            /**
2239             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2240             * a particular position within the shape, returning a TensorView which has more
2241             * dimensions than this Tensor. The TensorIndex takes ownership of this
2242             * Tensor, and can therefore mutate it
2243             *
2244             * See [expand](Tensor::expand)
2245             */
2246            #[track_caller]
2247            pub fn expand_owned(
2248                self,
2249                extra_dimension_names: [(usize, Dimension); 1],
2250            ) -> TensorView<T, TensorExpansion<T, Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2251                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2252            }
2253        }
2254    };
2255}
2256
2257tensor_expand_impl!(impl Tensor 0 1);
2258tensor_expand_impl!(impl Tensor 1 1);
2259tensor_expand_impl!(impl Tensor 2 1);
2260tensor_expand_impl!(impl Tensor 3 1);
2261tensor_expand_impl!(impl Tensor 4 1);
2262tensor_expand_impl!(impl Tensor 5 1);