easy_ml/tensors/mod.rs

/*!
 * Generic N dimensional [named tensors](http://nlp.seas.harvard.edu/NamedTensor).
 *
 * Tensors are generic over some type `T` and some usize `D`. If `T` is [Numeric](super::numeric)
 * then the tensor can be used in a mathematical way. `D` is the number of dimensions in the tensor
 * and a compile time constant. Each tensor also carries `D` dimension name and length pairs.
 */
use crate::linear_algebra;
use crate::numeric::extra::{Real, RealRef};
use crate::numeric::{Numeric, NumericRef};
use crate::tensors::indexing::{
    ShapeIterator, TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
    TensorReferenceMutIterator, TensorTranspose,
};
use crate::tensors::views::{
    DataLayout, IndexRange, IndexRangeValidationError, TensorExpansion, TensorIndex, TensorMask,
    TensorMut, TensorRange, TensorRef, TensorRename, TensorReshape, TensorReverse, TensorView,
};

use std::error::Error;
use std::fmt;

#[cfg(feature = "serde")]
use serde::Serialize;

pub mod dimensions;
mod display;
pub mod indexing;
pub mod operations;
pub mod views;

#[cfg(feature = "serde")]
pub use serde_impls::TensorDeserialize;

/**
 * Dimension names are represented as static string references.
 *
 * This allows you to use string literals to refer to named dimensions, for example you might want
 * to construct a tensor with a shape of
 * `[("batch", 1000), ("height", 100), ("width", 100), ("rgba", 4)]`.
 *
 * Alternatively you can define the strings once as constants and refer to your dimension
 * names by the constant identifiers.
 *
 * ```
 * const BATCH: &'static str = "batch";
 * const HEIGHT: &'static str = "height";
 * const WIDTH: &'static str = "width";
 * const RGBA: &'static str = "rgba";
 * ```
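 *
 * These constants can then be used when constructing shapes, for example (a small sketch
 * reusing the hypothetical constants above):
 *
 * ```
 * # const BATCH: &'static str = "batch";
 * # const HEIGHT: &'static str = "height";
 * # const WIDTH: &'static str = "width";
 * # const RGBA: &'static str = "rgba";
 * use easy_ml::tensors::Tensor;
 * let images = Tensor::empty([(BATCH, 2), (HEIGHT, 10), (WIDTH, 10), (RGBA, 4)], 0.0);
 * assert_eq!(images.length_of(WIDTH), Some(10));
 * ```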
 *
 * Although `Dimension` is interchangeable with `&'static str` as it is just a type alias, Easy ML
 * uses `Dimension` whenever dimension names are expected, to distinguish the types from just
 * strings.
 */
pub type Dimension = &'static str;

/**
 * An error indicating failure to do something with a Tensor because the requested shape
 * is not valid.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidShapeError<const D: usize> {
    shape: [(Dimension, usize); D],
}

impl<const D: usize> InvalidShapeError<D> {
    /**
     * Checks if this shape is valid. This is mainly for internal library use but may also be
     * useful for unit testing.
     *
     * Note: in some functions and methods, an InvalidShapeError may be returned which is a valid
     * shape, but not the right size for the quantity of data provided.
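     *
     * For example (a small doc test sketch; duplicate names or zero length dimensions make a
     * shape invalid):
     *
     * ```
     * use easy_ml::tensors::InvalidShapeError;
     * let valid = InvalidShapeError::new([("x", 2), ("y", 3)]);
     * assert!(valid.is_valid());
     * let invalid = InvalidShapeError::new([("x", 2), ("x", 0)]);
     * assert!(!invalid.is_valid());
     * ```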
     */
    pub fn is_valid(&self) -> bool {
        !crate::tensors::dimensions::has_duplicates(&self.shape)
            && !self.shape.iter().any(|d| d.1 == 0)
    }

    /**
     * Constructs an InvalidShapeError for assistance with unit testing. Note that you can
     * construct an InvalidShapeError that *is* a valid shape in this way.
     */
    pub fn new(shape: [(Dimension, usize); D]) -> InvalidShapeError<D> {
        InvalidShapeError { shape }
    }

    /**
     * Returns the shape this error was constructed with.
     */
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.shape
    }

    /**
     * Returns a reference to the shape this error was constructed with.
     */
    pub fn shape_ref(&self) -> &[(Dimension, usize); D] {
        &self.shape
    }

    // Panics if the shape is invalid for any reason with the appropriate error message.
    #[track_caller]
    #[inline]
    fn validate_dimensions_or_panic(shape: &[(Dimension, usize); D], data_len: usize) {
        let elements = crate::tensors::dimensions::elements(shape);
        if data_len != elements {
            panic!(
                "Product of dimension lengths must match size of data. {} != {}",
                elements, data_len
            );
        }
        if crate::tensors::dimensions::has_duplicates(shape) {
            panic!("Dimension names must all be unique: {:?}", &shape);
        }
        if shape.iter().any(|d| d.1 == 0) {
            panic!("No dimension can have 0 elements: {:?}", &shape);
        }
    }

    // Returns true if the shape is valid and matches the data length
    fn validate_dimensions(shape: &[(Dimension, usize); D], data_len: usize) -> bool {
        let elements = crate::tensors::dimensions::elements(shape);
        data_len == elements
            && !crate::tensors::dimensions::has_duplicates(shape)
            && !shape.iter().any(|d| d.1 == 0)
    }
}

impl<const D: usize> fmt::Display for InvalidShapeError<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Dimensions must all be at least length 1 with unique names: {:?}",
            self.shape
        )
    }
}

impl<const D: usize> Error for InvalidShapeError<D> {}

/**
 * An error indicating failure to do something with a Tensor because the dimension names that
 * were provided did not match with the dimension names that were valid.
 *
 * Typically this would be due to the same dimension name being provided multiple times, or a
 * dimension name being provided that is not present in the shape of the Tensor in use.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidDimensionsError<const D: usize, const P: usize> {
    valid: [Dimension; D],
    provided: [Dimension; P],
}

impl<const D: usize, const P: usize> InvalidDimensionsError<D, P> {
    /**
     * Checks if the provided dimensions have duplicate names. This is mainly for internal library
     * use but may also be useful for unit testing.
     */
    pub fn has_duplicates(&self) -> bool {
        crate::tensors::dimensions::has_duplicates_names(&self.provided)
    }

    // TODO: method to check provided is a subset of valid

    /**
     * Constructs an InvalidDimensionsError for assistance with unit testing.
     */
    pub fn new(provided: [Dimension; P], valid: [Dimension; D]) -> InvalidDimensionsError<D, P> {
        InvalidDimensionsError { valid, provided }
    }

    /**
     * Returns the dimension names that were provided.
     */
    pub fn provided_names(&self) -> [Dimension; P] {
        self.provided
    }

    /**
     * Returns a reference to the dimension names that were provided.
     */
    pub fn provided_names_ref(&self) -> &[Dimension; P] {
        &self.provided
    }

    /**
     * Returns the dimension names that were valid.
     */
    pub fn valid_names(&self) -> [Dimension; D] {
        self.valid
    }

    /**
     * Returns a reference to the dimension names that were valid.
     */
    pub fn valid_names_ref(&self) -> &[Dimension; D] {
        &self.valid
    }
}

impl<const D: usize, const P: usize> fmt::Display for InvalidDimensionsError<D, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if P > 0 {
            write!(
                f,
                "Dimension names {:?} were incorrect, valid dimensions in this context are: {:?}",
                self.provided, self.valid
            )
        } else {
            write!(f, "Dimension names {:?} were incorrect", self.provided)
        }
    }
}

impl<const D: usize, const P: usize> Error for InvalidDimensionsError<D, P> {}

#[test]
fn test_sync() {
    fn assert_sync<T: Sync>() {}
    assert_sync::<InvalidShapeError<2>>();
    assert_sync::<InvalidDimensionsError<2, 2>>();
}

#[test]
fn test_send() {
    fn assert_send<T: Send>() {}
    assert_send::<InvalidShapeError<2>>();
    assert_send::<InvalidDimensionsError<2, 2>>();
}

/**
 * A [named tensor](http://nlp.seas.harvard.edu/NamedTensor) of some type `T` and number of
 * dimensions `D`.
 *
 * Tensors are a generalisation of matrices; whereas [Matrix](crate::matrices::Matrix) only
 * supports 2 dimensions, and vectors are represented in Matrix by making either the rows or
 * columns have a length of one, [Tensor] supports an arbitrary number of dimensions,
 * with 0 through 6 having full API support. A `Tensor<T, 2>` is very similar to a `Matrix<T>`
 * except that this type associates each dimension with a name, and favors names over index
 * order to refer to dimensions.
 *
 * Like Matrix, the type of the data in this Tensor may implement no traits, in which case the
 * tensor will be rather useless. If the type implements Clone most storage and accessor methods
 * are defined, and if the type implements Numeric then the tensor can be used in a mathematical
 * way.
 *
 * Like Matrix, a Tensor must always contain at least one element, and it may not have more
 * elements than `std::isize::MAX`. Concerned readers should note that on a 64 bit computer this
 * maximum value is 9,223,372,036,854,775,807, so running out of memory is likely to occur first.
 *
 * When doing numeric operations with Tensors you should be careful to not consume a tensor by
 * accidentally using it by value. All the operations are also defined on references to tensors,
 * so you should favor `&x + &y` style notation for tensors you intend to continue using.
 *
 * See also:
 * - [indexing]
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Tensor<T, const D: usize> {
    data: Vec<T>,
    #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
    shape: [(Dimension, usize); D],
    #[cfg_attr(feature = "serde", serde(skip))]
    strides: [usize; D],
}

impl<T, const D: usize> Tensor<T, D> {
    /**
     * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
     *
     * The product of the dimension lengths corresponds to the number of elements the Tensor
     * will store. Elements are stored in what would be row major order for a Matrix.
     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
     * for the 25th and final element.
     *
     * # Panics
     *
     * - If the number of provided elements does not match the product of the dimension lengths.
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
     * a single element (since the product of an empty list is 1).
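     *
     * A small sketch of the row major order described above:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("row", 2), ("column", 3)], vec![
     *     1, 2, 3,
     *     4, 5, 6
     * ]);
     * assert_eq!(tensor.shape(), [("row", 2), ("column", 3)]);
     * ```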
     */
    #[track_caller]
    pub fn from(shape: [(Dimension, usize); D], data: Vec<T>) -> Self {
        InvalidShapeError::validate_dimensions_or_panic(&shape, data.len());
        let strides = compute_strides(&shape);
        Tensor {
            data,
            shape,
            strides,
        }
    }

    /**
     * Creates a Tensor with a particular shape initialised from a function.
     *
     * The product of the dimension lengths corresponds to the number of elements the Tensor
     * will store. Elements are stored in what would be row major order for a Matrix.
     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
     * generated would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
     * for the 25th and final element. These same indexes will be passed to the producer function
     * to initialise the values for the Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from_fn([("rows", 4), ("columns", 4)], |[r, c]| r * c);
     * assert_eq!(
     *     tensor,
     *     Tensor::from([("rows", 4), ("columns", 4)], vec![
     *         0, 0, 0, 0,
     *         0, 1, 2, 3,
     *         0, 2, 4, 6,
     *         0, 3, 6, 9,
     *     ])
     * );
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
     * a single element (since the product of an empty list is 1).
     */
    #[track_caller]
    pub fn from_fn<F>(shape: [(Dimension, usize); D], mut producer: F) -> Self
    where
        F: FnMut([usize; D]) -> T,
    {
        let length = dimensions::elements(&shape);
        let mut data = Vec::with_capacity(length);
        let iterator = ShapeIterator::from(shape);
        for index in iterator {
            data.push(producer(index));
        }
        Tensor::from(shape, data)
    }

    /**
     * The shape of this tensor. Since Tensors are named Tensors, their shape is not just a
     * list of lengths along each dimension, but instead a list of pairs of names and lengths.
     *
     * See also
     * - [dimensions]
     * - [indexing]
     */
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.shape
    }

    /**
     * Returns the length of the dimension name provided, if one is present in the Tensor.
     *
     * See also
     * - [dimensions]
     * - [indexing]
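     *
     * A short sketch:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![0; 6]);
     * assert_eq!(tensor.length_of("y"), Some(3));
     * assert_eq!(tensor.length_of("z"), None);
     * ```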
     */
    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::length_of(&self.shape, dimension)
    }

    /**
     * Returns the last index of the dimension name provided, if one is present in the Tensor.
     *
     * This is always 1 less than the length; the 'index' in this sense is based on what the
     * Tensor's shape is, not any implementation index.
     *
     * See also
     * - [dimensions]
     * - [indexing]
     */
    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::last_index_of(&self.shape, dimension)
    }

    /**
     * A non panicking version of [from](Tensor::from) which returns `Result::Err` if the input
     * is invalid.
     *
     * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
     *
     * The product of the dimension lengths corresponds to the number of elements the Tensor
     * will store. Elements are stored in what would be row major order for a Matrix.
     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
     * for the 25th and final element.
     *
     * Returns the Err variant if
     * - the number of provided elements does not match the product of the dimension lengths,
     * - a dimension name is not unique, or
     * - any dimension has 0 elements.
     *
     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
     * a single element (since the product of an empty list is 1).
     */
    pub fn try_from(
        shape: [(Dimension, usize); D],
        data: Vec<T>,
    ) -> Result<Self, InvalidShapeError<D>> {
        let valid = InvalidShapeError::validate_dimensions(&shape, data.len());
        if !valid {
            return Err(InvalidShapeError::new(shape));
        }
        let strides = compute_strides(&shape);
        Ok(Tensor {
            data,
            shape,
            strides,
        })
    }

    /// Unverified constructor for internal use when we know the dimensions/data/strides are
    /// unchanged and don't need reverification
    pub(crate) fn direct_from(
        data: Vec<T>,
        shape: [(Dimension, usize); D],
        strides: [usize; D],
    ) -> Self {
        Tensor {
            data,
            shape,
            strides,
        }
    }

    /// Unverified constructor for internal use when we know the dimensions/data/strides are
    /// the same as the existing instance and don't need reverification
    #[allow(dead_code)] // pretty sure something else will want this in the future
    pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> Self {
        Tensor {
            data,
            shape: self.shape,
            strides: self.strides,
        }
    }
}

impl<T> Tensor<T, 0> {
    /**
     * Creates a 0 dimensional tensor from some scalar.
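     *
     * A short sketch:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let scalar = Tensor::from_scalar(5);
     * assert_eq!(scalar.into_scalar(), 5);
     * ```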
     */
    pub fn from_scalar(value: T) -> Tensor<T, 0> {
        Tensor {
            data: vec![value],
            shape: [],
            strides: [],
        }
    }

    /**
     * Returns the sole element of the 0 dimensional tensor.
     */
    pub fn into_scalar(self) -> T {
        self.data
            .into_iter()
            .next()
            .expect("Tensors always have at least 1 element")
    }
}

impl<T> Tensor<T, 0>
where
    T: Clone,
{
    /**
     * Returns a copy of the sole element in the 0 dimensional tensor.
     */
    pub fn scalar(&self) -> T {
        self.data
            .first()
            .expect("Tensors always have at least 1 element")
            .clone()
    }
}

impl<T> From<T> for Tensor<T, 0> {
    fn from(scalar: T) -> Tensor<T, 0> {
        Tensor::from_scalar(scalar)
    }
}
// TODO: See if we can find a way to write the reverse Tensor<T, 0> -> T conversion using From or Into (doesn't seem like we can?)

// # Safety
//
// We promise to never implement interior mutability for Tensor.
/**
 * A Tensor implements TensorRef.
 */
unsafe impl<T, const D: usize> TensorRef<T, D> for Tensor<T, D> {
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
        self.data.get(i)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        Tensor::shape(self)
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // The point of get_reference_unchecked is no bounds checking, and therefore
            // it does not make any sense to just use `unwrap` here. The trait documents that
            // it's undefined behaviour to call this method with an out of bounds index, so we
            // can assume the None case will never happen.
            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
            self.data.get_unchecked(i)
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // We always have our memory in most significant to least
        DataLayout::Linear(std::array::from_fn(|i| self.shape[i].0))
    }
}

// # Safety
//
// We promise to never implement interior mutability for Tensor.
unsafe impl<T, const D: usize> TensorMut<T, D> for Tensor<T, D> {
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
        self.data.get_mut(i)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // The point of get_reference_unchecked_mut is no bounds checking, and therefore
            // it does not make any sense to just use `unwrap` here. The trait documents that
            // it's undefined behaviour to call this method with an out of bounds index, so we
            // can assume the None case will never happen.
            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
            self.data.get_unchecked_mut(i)
        }
    }
}

/**
 * Any tensor of a Cloneable type implements Clone.
 */
impl<T: Clone, const D: usize> Clone for Tensor<T, D> {
    fn clone(&self) -> Self {
        self.map(|element| element)
    }
}

/**
 * Any tensor of a Displayable type implements Display.
 *
 * You can control the precision of the formatting using format arguments, i.e.
 * `format!("{:.3}", tensor)`
 */
impl<T: std::fmt::Display, const D: usize> std::fmt::Display for Tensor<T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        crate::tensors::display::format_view(self, f)
    }
}

/**
 * Any 2 dimensional tensor can be converted to a matrix with rows equal to the length of the
 * first dimension in the tensor, and columns equal to the length of the second.
 */
impl<T> From<Tensor<T, 2>> for crate::matrices::Matrix<T> {
    fn from(tensor: Tensor<T, 2>) -> Self {
        crate::matrices::Matrix::from_flat_row_major(
            (tensor.shape[0].1, tensor.shape[1].1),
            tensor.data,
        )
    }
}

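// Computes row major strides from a shape: each stride is the product of all dimension lengths
// to its right. For example, a hypothetical shape of [("x", 2), ("y", 3), ("z", 4)] yields
// strides [12, 4, 1], so stepping one "x" skips over 12 elements of data.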
pub(crate) fn compute_strides<const D: usize>(shape: &[(Dimension, usize); D]) -> [usize; D] {
    std::array::from_fn(|d| shape.iter().skip(d + 1).map(|&(_, length)| length).product())
}

/// Returns the 1 dimensional index to use to get the requested index into some tensor
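/// (e.g. with a hypothetical shape of [("row", 2), ("column", 3)] and strides [3, 1], the
/// indexes [1, 2] map to (1 * 3) + (2 * 1) = 5 in the underlying data)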
#[inline]
pub(crate) fn get_index_direct<const D: usize>(
    // indexes to use
    indexes: &[usize; D],
    // strides for indexing into the tensor
    strides: &[usize; D],
    // shape of the tensor to index into
    shape: &[(Dimension, usize); D],
) -> Option<usize> {
    let mut index = 0;
    for d in 0..D {
        let n = indexes[d];
        if n >= shape[d].1 {
            return None;
        }
        index += n * strides[d];
    }
    Some(index)
}

/// Returns the 1 dimensional index to use to get the requested index into some tensor, without
/// checking the indexes are within bounds for the shape.
#[inline]
fn get_index_direct_unchecked<const D: usize>(
    // indexes to use
    indexes: &[usize; D],
    // strides for indexing into the tensor
    strides: &[usize; D],
) -> usize {
    let mut index = 0;
    for d in 0..D {
        let n = indexes[d];
        index += n * strides[d];
    }
    index
}

impl<T, const D: usize> Tensor<T, D> {
    /**
     * Returns a TensorView referencing this Tensor.
     */
    pub fn view(&self) -> TensorView<T, &Tensor<T, D>, D> {
        TensorView::from(self)
    }

    /**
     * Returns a TensorView mutably referencing this Tensor, which can therefore mutate it.
     */
    pub fn view_mut(&mut self) -> TensorView<T, &mut Tensor<T, D>, D> {
        TensorView::from(self)
    }

    /**
     * Returns a TensorView taking ownership of this Tensor, which can therefore mutate it.
     */
    pub fn view_owned(self) -> TensorView<T, Tensor<T, D>, D> {
        TensorView::from(self)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read values from this tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied does not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied does not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_mut(
        &mut self,
        dimensions: [Dimension; D],
    ) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied does not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was created with
     * in the same order as they were provided. See [TensorAccess::from_source_order].
     */
    pub fn index(&self) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was
     * created with in the same order as they were provided. The TensorAccess mutably borrows
     * the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
     */
    pub fn index_mut(&mut self) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was
     * created with in the same order as they were provided. The TensorAccess takes ownership
     * of the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
     */
    pub fn index_owned(self) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /**
     * Returns an iterator over references to the data in this Tensor.
     */
    pub fn iter_reference(&self) -> TensorReferenceIterator<T, Tensor<T, D>, D> {
        TensorReferenceIterator::from(self)
    }

    /**
     * Returns an iterator over mutable references to the data in this Tensor.
     */
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<T, Tensor<T, D>, D> {
        TensorReferenceMutIterator::from(self)
    }

    /**
     * Creates an iterator over the values in this Tensor.
     */
    pub fn iter_owned(self) -> TensorOwnedIterator<T, Tensor<T, D>, D>
    where
        T: Default,
    {
        TensorOwnedIterator::from(self)
    }

    // Non public index order reference iterator since we don't want to expose our implementation
    // details in the public API, since then we could never change them.
    pub(crate) fn direct_iter_reference(&self) -> std::slice::Iter<T> {
        self.data.iter()
    }

    // Non public index order reference iterator since we don't want to expose our implementation
    // details in the public API, since then we could never change them.
    pub(crate) fn direct_iter_reference_mut(&mut self) -> std::slice::IterMut<T> {
        self.data.iter_mut()
    }

    /**
     * Renames the dimension names of the tensor without changing the lengths of the dimensions
     * in the tensor or moving any data around.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let mut tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6]);
     * tensor.rename(["y", "z"]);
     * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename(&mut self, dimensions: [Dimension; D]) {
        if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
            panic!("Dimension names must all be unique: {:?}", &dimensions);
        }
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            self.shape[d].0 = dimensions[d];
        }
    }

    /**
     * Renames the dimension names of the tensor and returns it without changing the lengths
     * of the dimensions in the tensor or moving any data around.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6])
     *     .rename_owned(["y", "z"]);
     * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename_owned(mut self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        self.rename(dimensions);
        self
    }

    /**
     * Returns a TensorView with the dimension names of the shape renamed to the provided
     * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRename};
     * let abc = Tensor::from([("a", 3), ("b", 3), ("c", 3)], (0..27).collect());
     * let xyz = abc.rename_view(["x", "y", "z"]);
     * let also_xyz = TensorView::from(TensorRename::from(&abc, ["x", "y", "z"]));
     * assert_eq!(xyz, also_xyz);
     * assert_eq!(xyz, Tensor::from([("x", 3), ("y", 3), ("z", 3)], (0..27).collect()));
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorRename<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorRename::from(self, dimensions))
    }

    /**
     * Changes the shape of the tensor without changing the number of dimensions or moving any
     * data around.
     *
     * # Panics
     *
     * - If the product of the dimension lengths in the new shape does not match the number
     *   of elements in the existing tensor
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let mut tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * tensor.reshape_mut([("batch", 1), ("image", 4)]);
     * assert_eq!(tensor, Tensor::from([("batch", 1), ("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     */
    #[track_caller]
    pub fn reshape_mut(&mut self, shape: [(Dimension, usize); D]) {
        InvalidShapeError::validate_dimensions_or_panic(&shape, self.data.len());
        let strides = compute_strides(&shape);
        self.shape = shape;
        self.strides = strides;
    }

    /**
     * Consumes the tensor and changes the shape of the tensor without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * # Panics
     *
     * - If the product of the dimension lengths in the new shape does not match the number
     *   of elements in the existing tensor
     * - If a dimension name is not unique
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * let flattened = tensor.reshape_owned([("image", 4)]);
     * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     *
     * See also [reshape_view_owned](Tensor::reshape_view_owned)
     */
    #[track_caller]
    pub fn reshape_owned<const D2: usize>(self, shape: [(Dimension, usize); D2]) -> Tensor<T, D2> {
        Tensor::from(shape, self.data)
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * # Panics
     *
     * - If the product of the dimension lengths in the new shape does not match the number
     *   of elements in the existing tensor
     * - If a dimension name is not unique
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * let flattened = tensor.reshape_view([("image", 4)]);
     * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     */
    pub fn reshape_view<const D2: usize>(
        &self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
     * mutably borrows this Tensor, and can therefore mutate it.
     *
     * # Panics
     *
     * - If the product of the dimension lengths in the new shape does not match the number
     *   of elements in the existing tensor
     * - If a dimension name is not unique
     */
    pub fn reshape_view_mut<const D2: usize>(
        &mut self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
     * takes ownership of this Tensor, and can therefore mutate it.
     *
     * # Panics
     *
     * - If the product of the dimension lengths in the new shape does not match the number
     *   of elements in the existing tensor
     * - If a dimension name is not unique
     *
     * See also [reshape_owned](Tensor::reshape_owned)
     */
    pub fn reshape_view_owned<const D2: usize>(
        self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     */
    pub fn flatten_view(
        &self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, 1>, 1> {
        self.reshape_view([(dimension, dimensions::elements(&self.shape))])
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     */
    pub fn flatten_view_mut(
        &mut self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, 1>, 1> {
        self.reshape_view_mut([(dimension, dimensions::elements(&self.shape))])
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     *
     * If you intend to query the tensor a lot after creating the view, consider
     * using [flatten](Tensor::flatten) instead as it will have less overhead to
     * index after creation.
     */
    pub fn flatten_view_owned(
        self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, 1>, 1> {
        let length = dimensions::elements(&self.shape);
        self.reshape_view_owned([(dimension, length)])
    }

    /**
     * Given the dimension name, returns a new tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
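     *
     * A short sketch:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 2)], vec![1, 2, 3, 4]);
     * assert_eq!(tensor.flatten("i"), Tensor::from([("i", 4)], vec![1, 2, 3, 4]));
     * ```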
     */
    pub fn flatten(self, dimension: Dimension) -> Tensor<T, 1> {
        let length = dimensions::elements(&self.shape);
        self.reshape_owned([(dimension, length)])
    }

    /**
     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
     * range from view. Error cases are documented on [TensorRange].
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRange, IndexRange};
     * # use easy_ml::tensors::views::IndexRangeValidationError;
     * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
     * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
     * let cropped = samples.range([("x", IndexRange::new(1, 5)), ("y", IndexRange::new(1, 5))])?;
     * let also_cropped = TensorView::from(
     *     TensorRange::from(&samples, [("x", 1..6), ("y", 1..6)])?
     * );
     * assert_eq!(cropped, also_cropped);
     * assert_eq!(
     *     cropped.select([("batch", 0)]),
     *     Tensor::from([("x", 5), ("y", 5)], vec![
     *          8,  9, 10, 11, 12,
     *         15, 16, 17, 18, 19,
     *         22, 23, 24, 25, 26,
     *         29, 30, 31, 32, 33,
     *         36, 37, 38, 39, 40
     *     ])
     * );
     * # Ok(())
     * # }
     * ```
     */
    pub fn range<R, const P: usize>(
        &self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
    }

    /**
     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
     * range from view. Error cases are documented on [TensorRange]. The TensorRange
     * mutably borrows this Tensor, and can therefore mutate it.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     */
    pub fn range_mut<R, const P: usize>(
        &mut self,
        ranges: [(Dimension, R); P],
    ) -> Result<
        TensorView<T, TensorRange<T, &mut Tensor<T, D>, D>, D>,
        IndexRangeValidationError<D, P>,
    >
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
    }

    /**
     * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
     * range from view. Error cases are documented on [TensorRange]. The TensorRange
     * takes ownership of this Tensor, and can therefore mutate it.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     */
    pub fn range_owned<R, const P: usize>(
        self,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorRange<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::from(self, ranges).map(|range| TensorView::from(range))
    }

    /**
     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
     * range from view. Error cases are documented on [TensorMask].
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorMask, IndexRange};
     * # use easy_ml::tensors::views::IndexRangeValidationError;
     * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
     * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
     * let corners = samples.mask([("x", IndexRange::new(3, 2)), ("y", IndexRange::new(3, 2))])?;
     * let also_corners = TensorView::from(
     *     TensorMask::from(&samples, [("x", 3..5), ("y", 3..5)])?
     * );
     * assert_eq!(corners, also_corners);
     * assert_eq!(
     *     corners.select([("batch", 0)]),
     *     Tensor::from([("x", 5), ("y", 5)], vec![
     *          0,  1,  2,    5, 6,
     *          7,  8,  9,   12, 13,
     *         14, 15, 16,   19, 20,
     *
     *         35, 36, 37,   40, 41,
     *         42, 43, 44,   47, 48
     *     ])
     * );
     * # Ok(())
     * # }
     * ```
     */
    pub fn mask<R, const P: usize>(
        &self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
    }

    /**
     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
     * range from view. Error cases are documented on [TensorMask]. The TensorMask
     * mutably borrows this Tensor, and can therefore mutate it.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     */
    pub fn mask_mut<R, const P: usize>(
        &mut self,
        masks: [(Dimension, R); P],
    ) -> Result<
        TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D>,
        IndexRangeValidationError<D, P>,
    >
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
    }

    /**
     * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
     * range from view. Error cases are documented on [TensorMask]. The TensorMask
     * takes ownership of this Tensor, and can therefore mutate it.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     */
    pub fn mask_owned<R, const P: usize>(
        self,
        masks: [(Dimension, R); P],
    ) -> Result<TensorView<T, TensorMask<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
    }

    /**
     * Returns a TensorView where the provided dimension names of the shape are reversed in
     * iteration order. The data of this tensor and the dimension lengths remain unchanged.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorReverse};
     * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
     * let reversed = ab.reverse(&["a"]);
     * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
     * assert_eq!(reversed, also_reversed);
     * assert_eq!(
     *     reversed,
     *     Tensor::from(
     *         [("a", 2), ("b", 3)],
     *         vec![
     *             3, 4, 5,
     *             0, 1, 2,
     *         ]
     *     )
     * );
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not in the tensor's shape or is repeated.
     */
    #[track_caller]
    pub fn reverse(
        &self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorReverse::from(self, dimensions))
    }

    /**
     * Returns a TensorView where the provided dimension names of the shape are reversed in
     * iteration order. The data of this tensor and the dimension lengths remain unchanged.
     * The TensorReverse mutably borrows this Tensor, and can therefore mutate it.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * # Panics
     *
     * - If a dimension name is not in the tensor's shape or is repeated.
     */
    #[track_caller]
    pub fn reverse_mut(
        &mut self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, &mut Tensor<T, D>, D>, D> {
        TensorView::from(TensorReverse::from(self, dimensions))
    }

    /**
     * Returns a TensorView where the provided dimension names of the shape are reversed in
     * iteration order. The data of this tensor and the dimension lengths remain unchanged.
     * The TensorReverse takes ownership of this Tensor, and can therefore mutate it.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * # Panics
     *
     * - If a dimension name is not in the tensor's shape or is repeated.
     */
    #[track_caller]
    pub fn reverse_owned(
        self,
        dimensions: &[Dimension],
    ) -> TensorView<T, TensorReverse<T, Tensor<T, D>, D>, D> {
        TensorView::from(TensorReverse::from(self, dimensions))
    }

    /**
     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
     * mapped by a function. The value pairs are not copied for you; if you're using `Copy` types
     * or need to clone the values anyway, you can use
     * [`Tensor::elementwise`](Tensor::elementwise) instead.
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two tensors have different shapes.
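     *
     * A short sketch:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let x = Tensor::from([("a", 3)], vec![1, 2, 3]);
     * let y = Tensor::from([("a", 3)], vec![10, 20, 30]);
     * let sum = x.elementwise_reference(&y, |a, b| a + b);
     * assert_eq!(sum, Tensor::from([("a", 3)], vec![11, 22, 33]));
     * ```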
     */
    #[track_caller]
    pub fn elementwise_reference<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S, D>>,
        S: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        self.elementwise_reference_less_generic(rhs.into(), mapping_function)
    }

    /**
     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
     * mapped by a function. The mapping function also receives each index corresponding to the
     * value pairs. The value pairs are not copied for you; if you're using `Copy` types
     * or need to clone the values anyway, you can use
     * [`Tensor::elementwise_with_index`](Tensor::elementwise_with_index) instead.
     *
     * # Generics
     *
     * This method can be called with any right hand side that can be converted to a TensorView,
     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
     *
     * # Panics
     *
     * If the two tensors have different shapes.
     */
    #[track_caller]
    pub fn elementwise_reference_with_index<S, I, M>(
        &self,
        rhs: I,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        I: Into<TensorView<T, S, D>>,
        S: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
    }

    #[track_caller]
    fn elementwise_reference_less_generic<S, M>(
        &self,
        rhs: TensorView<T, S, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S: TensorRef<T, D>,
        M: Fn(&T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        let mapped = self
            .direct_iter_reference()
            .zip(rhs.iter_reference())
            .map(|(x, y)| mapping_function(x, y))
            .collect();
        // We're not changing the shape of the Tensor, so don't need to revalidate
        Tensor::direct_from(mapped, self.shape, self.strides)
    }

    #[track_caller]
    fn elementwise_reference_less_generic_with_index<S, M>(
        &self,
        rhs: TensorView<T, S, D>,
        mapping_function: M,
    ) -> Tensor<T, D>
    where
        S: TensorRef<T, D>,
        M: Fn([usize; D], &T, &T) -> T,
    {
        let left_shape = self.shape();
        let right_shape = rhs.shape();
        if left_shape != right_shape {
            panic!(
                "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
                left_shape, right_shape
            );
        }
        // we just checked both shapes were the same, so we don't need to propagate indexes
        // for both tensors because they'll be identical
        let mapped = self
            .direct_iter_reference()
            .zip(rhs.iter_reference().with_index())
            .map(|(x, (i, y))| mapping_function(i, x, y))
            .collect();
        // We're not changing the shape of the Tensor, so don't need to revalidate
        Tensor::direct_from(mapped, self.shape, self.strides)
    }

    /**
     * Returns a TensorView which makes the data in this tensor appear to be in
     * a different order. The order of the dimension names is unchanged, although their lengths
     * may swap.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * See also: [transpose](Tensor::transpose), [TensorTranspose]
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match.
     */
    pub fn transpose_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorTranspose<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorTranspose::from(self, dimensions))
    }
}

impl<T, const D: usize> Tensor<T, D>
where
    T: Clone,
{
    /**
     * Creates a tensor with a particular number of dimensions and length in each dimension
     * with all elements initialised to the provided value.
     *
     * # Panics
     *
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
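     *
     * A short sketch:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let zeros = Tensor::empty([("x", 2), ("y", 2)], 0);
     * assert_eq!(zeros, Tensor::from([("x", 2), ("y", 2)], vec![0, 0, 0, 0]));
     * ```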
     */
    #[track_caller]
    pub fn empty(shape: [(Dimension, usize); D], value: T) -> Self {
        let elements = crate::tensors::dimensions::elements(&shape);
        Tensor::from(shape, vec![value; elements])
    }

    /**
     * Gets a copy of the first value in this tensor.
     * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
     * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
     */
    pub fn first(&self) -> T {
        self.data
            .first()
            .expect("Tensors always have at least 1 element")
            .clone()
    }

    /**
     * Returns a new Tensor which has the same data as this tensor, but with the order of data
     * changed. The order of the dimension names is unchanged, although their lengths may swap.
     *
     * For example, with a `[("x", x), ("y", y)]` tensor you could call
     * `transpose(["y", "x"])` which would return a new tensor with a shape of
     * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
     *
     * This method need not shift *all* the dimensions though, you could also swap the width
     * and height of images in a tensor with a shape of
     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
     * which would return a new tensor where all the images have been swapped over the diagonal.
     *
     * See also: [TensorAccess], [reorder](Tensor::reorder)
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match (and if the order does match, this function is just an expensive
     * clone).
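     *
     * A small sketch of the swap described above:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![
     *     1, 2, 3,
     *     4, 5, 6
     * ]);
     * assert_eq!(
     *     tensor.transpose(["y", "x"]),
     *     Tensor::from([("x", 3), ("y", 2)], vec![
     *         1, 4,
     *         2, 5,
     *         3, 6
     *     ])
     * );
     * ```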
     */
    #[track_caller]
    pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let shape = self.shape;
        let mut reordered = self.reorder(dimensions);
        // Transposition is essentially reordering, but we retain the dimension name ordering
        // of the original order, this means we may swap dimension lengths, but the dimensions
        // will not change order.
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            reordered.shape[d].0 = shape[d].0;
        }
        reordered
    }

    /**
     * Modifies this tensor to have the same data as before, but with the order of data changed.
     * The order of the dimension names is unchanged, although their lengths may swap.
     *
     * For example, with a `[("x", x), ("y", y)]` tensor you could call
     * `transpose_mut(["y", "x"])` which would edit the tensor, updating its shape to
     * `[("x", y), ("y", x)]`, so every (x,y) of its data corresponds to (y,x) before the
     * transposition.
     *
     * The order swapping will try to be in place, but this is currently only supported for
     * square tensors with 2 dimensions. Other types of tensors will not be transposed in place.
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match (and if the order does match, this function is just an expensive
     * no-op).
     */
    #[track_caller]
    pub fn transpose_mut(&mut self, dimensions: [Dimension; D]) {
        let shape = self.shape;
        self.reorder_mut(dimensions);
        // Transposition is essentially reordering, but we retain the dimension name ordering
        // we had before, this means we may swap dimension lengths, but the dimensions
        // will not change order.
        #[allow(clippy::needless_range_loop)]
        for d in 0..D {
            self.shape[d].0 = shape[d].0;
        }
    }

    /**
     * Returns a new Tensor which has the same data as this tensor, but with the order of the
     * dimensions and corresponding order of data changed.
     *
     * For example, with a `[("x", x), ("y", y)]` tensor you could call
     * `reorder(["y", "x"])` which would return a new tensor with a shape of
     * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
     *
     * This method need not shift *all* the dimensions though, you could also swap the width
     * and height of images in a tensor with a shape of
     * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
     * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
     * in the original.
     *
     * See also: [TensorAccess], [transpose](Tensor::transpose)
     *
     * # Panics
     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match (and if the order does match, this function is just an expensive
     * clone).
1453     */
1454    #[track_caller]
1455    pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        let reordered = match TensorAccess::try_from(&self, dimensions) {
1457            Ok(reordered) => reordered,
1458            Err(_error) => panic!(
1459                "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1460                dimensions, &self.shape,
1461            ),
1462        };
        let reordered_shape = reordered.shape();
        Tensor::from(reordered_shape, reordered.iter().collect())
1465    }
1466
1467    /**
1468     * Modifies this tensor to have the same data as before, but with the order of the
1469     * dimensions and corresponding order of data changed.
1470     *
1471     * For example, with a `[("x", x), ("y", y)]` tensor you could call
1472     * `reorder_mut(["y", "x"])` which would edit the tensor, updating its shape to
1473     * `[("y", y), ("x", x)]`, so every (y,x) of its data corresponds to (x,y) before the
1474     * transposition.
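     *
     * For example, with a square tensor (which is also the case that can be reordered
     * in place):
     * ```
     * use easy_ml::tensors::Tensor;
     * let mut x = Tensor::from([("x", 2), ("y", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * x.reorder_mut(["y", "x"]);
     * assert_eq!(x, Tensor::from([("y", 2), ("x", 2)], vec![
     *     1, 3,
     *     2, 4
     * ]));
     * ```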
1475     *
     * The order swapping is done in place where possible, but this is currently only
     * supported for square tensors with 2 dimensions. Other tensors allocate new storage
     * instead of being reordered in place.
1478     *
1479     * # Panics
1480     *
     * If the set of dimensions in the tensor does not match the set of dimensions provided. The
     * order need not match (and if the order does match, this function is just an expensive
     * no-op).
1484     */
1485    #[track_caller]
1486    pub fn reorder_mut(&mut self, dimensions: [Dimension; D]) {
1487        use crate::tensors::dimensions::DimensionMappings;
1488        if D == 2 && crate::tensors::dimensions::is_square(&self.shape) {
1489            let dimension_mapping = match DimensionMappings::new(&self.shape, &dimensions) {
1490                Some(dimension_mapping) => dimension_mapping,
1491                None => panic!(
1492                    "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1493                    dimensions, &self.shape,
1494                ),
1495            };
1496
1497            let shape = dimension_mapping.map_shape_to_requested(&self.shape);
1498            let shape_iterator = ShapeIterator::from(shape);
1499
1500            for index in shape_iterator {
1501                let i = index[0];
1502                let j = index[1];
1503                if j >= i {
1504                    let mapped_index = dimension_mapping.map_dimensions_to_source(&index);
1505                    // Swap elements from the upper triangle (using index order of the actual tensor's
1506                    // shape)
1507                    let temp = self.get_reference(index).unwrap().clone();
1508                    // tensor[i,j] becomes tensor[mapping(i,j)]
1509                    *self.get_reference_mut(index).unwrap() =
1510                        self.get_reference(mapped_index).unwrap().clone();
1511                    // tensor[mapping(i,j)] becomes tensor[i,j]
1512                    *self.get_reference_mut(mapped_index).unwrap() = temp;
1513                    // If the mapping is a noop we've assigned i,j to i,j
1514                    // If the mapping is i,j -> j,i we've assigned i,j to j,i and j,i to i,j
1515                }
1516            }
1517
1518            // now update our shape and strides to match
1519            self.shape = shape;
1520            self.strides = compute_strides(&shape);
1521        } else {
1522            // fallback to allocating a new reordered tensor
1523            let reordered = self.reorder(dimensions);
1524            self.data = reordered.data;
1525            self.shape = reordered.shape;
1526            self.strides = reordered.strides;
1527        }
1528    }
1529
1530    /**
1531     * Returns an iterator over copies of the data in this Tensor.
1532     */
1533    pub fn iter(&self) -> TensorIterator<T, Tensor<T, D>, D> {
1534        TensorIterator::from(self)
1535    }
1536
1537    /**
1538     * Creates and returns a new tensor with all values from the original with the
1539     * function applied to each. This can be used to change the type of the tensor
1540     * such as creating a mask:
1541     * ```
1542     * use easy_ml::tensors::Tensor;
1543     * let x = Tensor::from([("a", 2), ("b", 2)], vec![
1544     *    0.0, 1.2,
1545     *    5.8, 6.9
1546     * ]);
1547     * let y = x.map(|element| element > 2.0);
1548     * let result = Tensor::from([("a", 2), ("b", 2)], vec![
1549     *    false, false,
1550     *    true, true
1551     * ]);
1552     * assert_eq!(&y, &result);
1553     * ```
1554     */
1555    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
1556        let mapped = self
1557            .data
1558            .iter()
1559            .map(|x| mapping_function(x.clone()))
1560            .collect();
1561        // We're not changing the shape of the Tensor, so don't need to revalidate
1562        Tensor::direct_from(mapped, self.shape, self.strides)
1563    }
1564
1565    /**
1566     * Creates and returns a new tensor with all values from the original and
1567     * the index of each value mapped by a function.
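     *
     * For example, keeping only the diagonal of a matrix shaped tensor:
     * ```
     * use easy_ml::tensors::Tensor;
     * let x = Tensor::from([("r", 2), ("c", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * let diagonal = x.map_with_index(|[r, c], value| if r == c { value } else { 0 });
     * assert_eq!(diagonal, Tensor::from([("r", 2), ("c", 2)], vec![
     *     1, 0,
     *     0, 4
     * ]));
     * ```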
1568     */
1569    pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1570        let mapped = self
1571            .iter()
1572            .with_index()
1573            .map(|(i, x)| mapping_function(i, x))
1574            .collect();
1575        // We're not changing the shape of the Tensor, so don't need to revalidate
1576        Tensor::direct_from(mapped, self.shape, self.strides)
1577    }
1578
1579    /**
1580     * Applies a function to all values in the tensor, modifying
1581     * the tensor in place.
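     *
     * For example, doubling every value:
     * ```
     * use easy_ml::tensors::Tensor;
     * let mut x = Tensor::from([("a", 3)], vec![1, 2, 3]);
     * x.map_mut(|value| value * 2);
     * assert_eq!(x, Tensor::from([("a", 3)], vec![2, 4, 6]));
     * ```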
1582     */
1583    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1584        for value in self.data.iter_mut() {
1585            *value = mapping_function(value.clone());
1586        }
1587    }
1588
1589    /**
1590     * Applies a function to all values and each value's index in the tensor, modifying
1591     * the tensor in place.
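     *
     * For example, adding each value's index along `"a"` to it:
     * ```
     * use easy_ml::tensors::Tensor;
     * let mut x = Tensor::from([("a", 3)], vec![10, 10, 10]);
     * x.map_mut_with_index(|[i], value| value + (i as i32));
     * assert_eq!(x, Tensor::from([("a", 3)], vec![10, 11, 12]));
     * ```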
1592     */
1593    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1594        self.iter_reference_mut()
1595            .with_index()
1596            .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1597    }
1598
1599    /**
1600     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1601     * mapped by a function.
1602     *
1603     * ```
1604     * use easy_ml::tensors::Tensor;
1605     * let lhs = Tensor::from([("a", 4)], vec![1, 2, 3, 4]);
1606     * let rhs = Tensor::from([("a", 4)], vec![0, 1, 2, 3]);
1607     * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1608     * assert_eq!(
1609     *     multiplied,
1610     *     Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1611     * );
1612     * ```
1613     *
1614     * # Generics
1615     *
1616     * This method can be called with any right hand side that can be converted to a TensorView,
1617     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1618     *
1619     * # Panics
1620     *
1621     * If the two tensors have different shapes.
1622     */
1623    #[track_caller]
1624    pub fn elementwise<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1625    where
1626        I: Into<TensorView<T, S, D>>,
1627        S: TensorRef<T, D>,
1628        M: Fn(T, T) -> T,
1629    {
1630        self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1631            mapping_function(lhs.clone(), rhs.clone())
1632        })
1633    }
1634
1635    /**
1636     * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1637     * mapped by a function. The mapping function also receives each index corresponding to the
1638     * value pairs.
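     *
     * For example:
     * ```
     * use easy_ml::tensors::Tensor;
     * let lhs = Tensor::from([("a", 3)], vec![1, 2, 3]);
     * let rhs = Tensor::from([("a", 3)], vec![4, 5, 6]);
     * let summed = lhs.elementwise_with_index(&rhs, |[i], l, r| l + r + (i as i32));
     * assert_eq!(summed, Tensor::from([("a", 3)], vec![5, 8, 11]));
     * ```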
1639     *
1640     * # Generics
1641     *
1642     * This method can be called with any right hand side that can be converted to a TensorView,
1643     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1644     *
1645     * # Panics
1646     *
1647     * If the two tensors have different shapes.
1648     */
1649    #[track_caller]
1650    pub fn elementwise_with_index<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1651    where
1652        I: Into<TensorView<T, S, D>>,
1653        S: TensorRef<T, D>,
1654        M: Fn([usize; D], T, T) -> T,
1655    {
1656        self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1657            mapping_function(i, lhs.clone(), rhs.clone())
1658        })
1659    }
1660}
1661
1662impl<T> Tensor<T, 1>
1663where
1664    T: Numeric,
1665    for<'a> &'a T: NumericRef<T>,
1666{
1667    /**
1668     * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1669     * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1670     *
1671     * ```
1672     * use easy_ml::tensors::Tensor;
1673     * let tensor = Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]);
1674     * assert_eq!(tensor.scalar_product(&tensor), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1675     * ```
1676     *
1677     * # Generics
1678     *
1679     * This method can be called with any right hand side that can be converted to a TensorView,
1680     * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1681     *
1682     * # Panics
1683     *
1684     * If the two vectors are not of equal length or their dimension names do not match.
1685     */
1686    // Would like this impl block to be in operations.rs too but then it would show first in the
1687    // Tensor docs which isn't ideal
1688    pub fn scalar_product<S, I>(&self, rhs: I) -> T
1689    where
1690        I: Into<TensorView<T, S, 1>>,
1691        S: TensorRef<T, 1>,
1692    {
1693        self.scalar_product_less_generic(rhs.into())
1694    }
1695}
1696
1697impl<T> Tensor<T, 2>
1698where
1699    T: Numeric,
1700    for<'a> &'a T: NumericRef<T>,
1701{
1702    /**
1703     * Returns the determinant of this square matrix, or None if the matrix
1704     * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
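     *
     * For example, the determinant of a 2 x 2 matrix `[[a, b], [c, d]]` is `ad - bc`
     * (compared approximately here since the computation is done in floating point):
     * ```
     * use easy_ml::tensors::Tensor;
     * let matrix = Tensor::from([("row", 2), ("column", 2)], vec![
     *     1.0, 2.0,
     *     3.0, 4.0
     * ]);
     * let determinant = matrix.determinant().unwrap();
     * // 1*4 - 2*3 = -2
     * assert!((determinant + 2.0).abs() < 1e-10);
     * ```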
1705     */
1706    pub fn determinant(&self) -> Option<T> {
1707        linear_algebra::determinant_tensor::<T, _, _>(self)
1708    }
1709
1710    /**
1711     * Computes the inverse of a matrix provided that it exists. To have an inverse a
1712     * matrix must be square (same number of rows and columns) and it must also have a
1713     * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
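     *
     * For example (compared approximately since the computation is done in floating
     * point):
     * ```
     * use easy_ml::tensors::Tensor;
     * let matrix = Tensor::from([("row", 2), ("column", 2)], vec![
     *     1.0, 2.0,
     *     3.0, 4.0
     * ]);
     * let inverse = matrix.inverse().unwrap();
     * // the inverse of [[1, 2], [3, 4]] is [[-2, 1], [1.5, -0.5]]
     * let expected = [-2.0, 1.0, 1.5, -0.5];
     * for (x, e) in inverse.iter().zip(expected.iter()) {
     *     assert!((x - e).abs() < 1e-10);
     * }
     * ```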
1714     */
1715    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
1716        linear_algebra::inverse_tensor::<T, _, _>(self)
1717    }
1718
1719    /**
1720     * Computes the covariance matrix for this feature matrix along the specified feature
1721     * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
1722     */
1723    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
1724        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
1725    }
1726}
1727
1728// FIXME: want this to be callable in the main numeric impl block
1729impl<T> Tensor<T, 2>
1730where
1731    T: Numeric,
1732{
1733    /**
1734     * Creates a diagonal matrix of the provided size with the diagonal elements
1735     * set to the provided value and all other elements in the tensor set to 0.
1736     * A diagonal matrix is always square.
1737     *
1738     * The size is still taken as a shape to facilitate creating a diagonal matrix
1739     * from the dimensionality of an existing one. If the provided value is 1 then
1740     * this will create an identity matrix.
1741     *
1742     * A 3 x 3 identity matrix:
1743     * ```ignore
1744     * [
1745     *   1, 0, 0
1746     *   0, 1, 0
1747     *   0, 0, 1
1748     * ]
1749     * ```
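     *
     * For example:
     * ```
     * use easy_ml::tensors::Tensor;
     * let identity = Tensor::diagonal([("row", 3), ("column", 3)], 1);
     * assert_eq!(identity, Tensor::from([("row", 3), ("column", 3)], vec![
     *     1, 0, 0,
     *     0, 1, 0,
     *     0, 0, 1
     * ]));
     * ```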
1750     *
1751     * # Panics
1752     *
1753     * - If the shape is not square.
     * - If a dimension name is not unique.
     * - If any dimension has 0 elements.
1756     */
1757    #[track_caller]
1758    pub fn diagonal(shape: [(Dimension, usize); 2], value: T) -> Tensor<T, 2> {
1759        if !crate::tensors::dimensions::is_square(&shape) {
1760            panic!("Shape must be square: {:?}", shape);
1761        }
1762        let mut tensor = Tensor::empty(shape, T::zero());
1763        for ([r, c], x) in tensor.iter_reference_mut().with_index() {
1764            if r == c {
1765                *x = value.clone();
1766            }
1767        }
1768        tensor
1769    }
1770}
1771
1772impl<T> Tensor<T, 2> {
1773    /**
1774     * Converts this 2 dimensional Tensor into a Matrix.
1775     *
1776     * This is a wrapper around the `From<Tensor<T, 2>>` implementation.
1777     *
1778     * The Matrix will have the data in the same order, with rows equal to the length of
1779     * the first dimension in the tensor, and columns equal to the length of the second.
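     *
     * For example:
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::matrices::Matrix;
     * let tensor = Tensor::from([("row", 2), ("column", 3)], vec![
     *     1, 2, 3,
     *     4, 5, 6
     * ]);
     * assert_eq!(tensor.into_matrix(), Matrix::from(vec![
     *     vec![1, 2, 3],
     *     vec![4, 5, 6]
     * ]));
     * ```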
1780     */
1781    pub fn into_matrix(self) -> crate::matrices::Matrix<T> {
1782        self.into()
1783    }
1784}
1785
1786/**
1787 * Methods for tensors with numerical real valued types, such as f32 or f64.
1788 *
1789 * This excludes signed and unsigned integers as they do not support decimal
1790 * precision and hence can't be used for operations like square roots.
1791 *
1792 * Third party fixed precision and infinite precision decimal types should
1793 * be able to implement all of the methods for [Real] and then utilise these functions.
1794 */
1795impl<T: Real> Tensor<T, 1>
1796where
1797    for<'a> &'a T: RealRef<T>,
1798{
1799    /**
1800     * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1801     * of this vector, also referred to as the length or magnitude,
1802     * and written as ||x||, or sometimes |x|.
1803     *
1804     * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1805     *
     * This is a shorthand for `x.iter().map(|x| x * x).sum().sqrt()`, i.e.
1807     * the square root of the dot product of a vector with itself.
1808     *
1809     * The euclidean length can be used to compute a
1810     * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1811     * vector with length of 1. This should not be confused with a unit matrix,
1812     * which is another name for an identity matrix.
1813     *
1814     * ```
1815     * use easy_ml::tensors::Tensor;
1816     * let a = Tensor::from([("data", 3)], vec![ 1.0, 2.0, 3.0 ]);
1817     * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1818     * let unit = a.map(|x| x / length);
1819     * assert_eq!(unit.euclidean_length(), 1.0);
1820     * ```
1821     */
1822    // TODO: Scalar ops for tensors
1823    pub fn euclidean_length(&self) -> T {
1824        self.direct_iter_reference()
1825            .map(|x| x * x)
1826            .sum::<T>()
1827            .sqrt()
1828    }
1829}
1830
1831#[cfg(feature = "serde")]
1832mod serde_impls {
1833    use crate::tensors::{Dimension, InvalidShapeError, Tensor};
1834    use serde::Deserialize;
1835    use std::convert::TryFrom;
1836
1837    /**
1838     * Deserialised data for a Tensor. Can be converted into a Tensor by providing `&'static str`
1839     * dimension names.
1840     */
1841    #[derive(Deserialize)]
1842    #[serde(rename = "Tensor")]
1843    pub struct TensorDeserialize<'a, T, const D: usize> {
1844        data: Vec<T>,
1845        #[serde(with = "serde_arrays")]
1846        #[serde(borrow)]
1847        shape: [(&'a str, usize); D],
1848    }
1849
1850    impl<'a, T, const D: usize> TensorDeserialize<'a, T, D> {
1851        /**
1852         * Converts this deserialised Tensor data to a Tensor, using the provided `&'static str`
1853         * dimension names in place of what was serialised (which wouldn't necessarily live
1854         * long enough).
1855         */
1856        pub fn into_tensor(
1857            self,
1858            dimensions: [Dimension; D],
1859        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
1860            let shape = std::array::from_fn(|d| (dimensions[d], self.shape[d].1));
1861            // Safety: Use the normal constructor that performs validation to prevent invalid
1862            // serialized data being created as a Tensor, which would then break all the
1863            // code that's relying on these invariants.
1864            // By never serialising the strides in the first place, we reduce the possibility
            // of creating invalid serialised representations, at the cost of a slight
            // increase in serialisation work.
1867            Tensor::try_from(shape, self.data)
1868        }
1869    }
1870
1871    /**
     * Converts deserialised Tensor data whose dimension names have a static lifetime
     * into a Tensor, using the serialised data as is.
1874     */
1875    impl<T, const D: usize> TryFrom<TensorDeserialize<'static, T, D>> for Tensor<T, D> {
1876        type Error = InvalidShapeError<D>;
1877
1878        fn try_from(value: TensorDeserialize<'static, T, D>) -> Result<Self, Self::Error> {
1879            Tensor::try_from(value.shape, value.data)
1880        }
1881    }
1882}
1883
1884#[cfg(feature = "serde")]
1885#[test]
1886fn test_serialize() {
1887    fn assert_serialize<T: Serialize>() {}
1888    assert_serialize::<Tensor<f64, 3>>();
1889    assert_serialize::<Tensor<f64, 2>>();
1890    assert_serialize::<Tensor<f64, 1>>();
1891    assert_serialize::<Tensor<f64, 0>>();
1892}
1893
1894#[cfg(feature = "serde")]
1895#[test]
1896fn test_deserialize() {
1897    use serde::Deserialize;
1898    fn assert_deserialize<'de, T: Deserialize<'de>>() {}
1899    assert_deserialize::<TensorDeserialize<f64, 3>>();
1900    assert_deserialize::<TensorDeserialize<f64, 2>>();
1901    assert_deserialize::<TensorDeserialize<f64, 1>>();
1902    assert_deserialize::<TensorDeserialize<f64, 0>>();
1903}
1904
1905#[cfg(feature = "serde")]
1906#[test]
1907fn test_serialization_deserialization_loop() {
1908    #[rustfmt::skip]
1909    let tensor = Tensor::from(
1910        [("rows", 3), ("columns", 4)],
1911        vec![
1912            1,  2,  3,  4,
1913            5,  6,  7,  8,
1914            9, 10, 11, 12
1915        ],
1916    );
1917    let encoded = toml::to_string(&tensor).unwrap();
1918    assert_eq!(
1919        encoded,
1920        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
1921shape = [["rows", 3], ["columns", 4]]
1922"#,
1923    );
1924    let parsed: Result<TensorDeserialize<i32, 2>, _> = toml::from_str(&encoded);
1925    assert!(parsed.is_ok());
1926    let result = parsed.unwrap().into_tensor(["rows", "columns"]);
1927    assert!(result.is_ok());
1928    assert_eq!(result.unwrap(), tensor);
1929}
1930
1931#[cfg(feature = "serde")]
1932#[test]
1933fn test_deserialization_validation() {
1934    let parsed: Result<TensorDeserialize<i32, 2>, _> = toml::from_str(
1935        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
1936shape = [["rows", 4], ["columns", 4]]
1937"#,
1938    );
1939    assert!(parsed.is_ok());
1940    let result = parsed.unwrap().into_tensor(["rows", "columns"]);
1941    assert!(result.is_err());
1942}
1943
1944macro_rules! tensor_select_impl {
1945    (impl Tensor $d:literal 1) => {
1946        impl<T> Tensor<T, $d> {
1947            /**
1948             * Selects the provided dimension name and index pairs in this Tensor, returning a
1949             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1950             * always indexed as the provided values.
1951             *
1952             * This is a shorthand for manually constructing the TensorView and
             * [TensorIndex].
1954             *
1955             * Note: due to limitations in Rust's const generics support, this method is only
1956             * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
1957             * back to manual construction to create `TensorIndex`es with multiple provided
1958             * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
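             *
             * For example, selecting one row of a matrix shaped tensor (illustrated here
             * with the 2 dimensional case):
             * ```
             * use easy_ml::tensors::Tensor;
             * let tensor = Tensor::from([("row", 2), ("column", 2)], vec![
             *     1, 2,
             *     3, 4
             * ]);
             * let row = tensor.select([("row", 1)]);
             * assert_eq!(row.shape(), [("column", 2)]);
             * assert_eq!(row.iter().collect::<Vec<_>>(), vec![3, 4]);
             * ```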
1959             */
1960            #[track_caller]
1961            pub fn select(
1962                &self,
1963                provided_indexes: [(Dimension, usize); 1],
1964            ) -> TensorView<T, TensorIndex<T, &Tensor<T, $d>, $d, 1>, { $d - 1 }> {
1965                TensorView::from(TensorIndex::from(self, provided_indexes))
1966            }
1967
1968            /**
1969             * Selects the provided dimension name and index pairs in this Tensor, returning a
1970             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1971             * always indexed as the provided values. The TensorIndex mutably borrows this
             * Tensor, and can therefore mutate it.
1973             *
1974             * See [select](Tensor::select)
1975             */
1976            #[track_caller]
1977            pub fn select_mut(
1978                &mut self,
1979                provided_indexes: [(Dimension, usize); 1],
1980            ) -> TensorView<T, TensorIndex<T, &mut Tensor<T, $d>, $d, 1>, { $d - 1 }> {
1981                TensorView::from(TensorIndex::from(self, provided_indexes))
1982            }
1983
1984            /**
1985             * Selects the provided dimension name and index pairs in this Tensor, returning a
1986             * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1987             * always indexed as the provided values. The TensorIndex takes ownership of this
             * Tensor, and can therefore mutate it.
1989             *
1990             * See [select](Tensor::select)
1991             */
1992            #[track_caller]
1993            pub fn select_owned(
1994                self,
1995                provided_indexes: [(Dimension, usize); 1],
1996            ) -> TensorView<T, TensorIndex<T, Tensor<T, $d>, $d, 1>, { $d - 1 }> {
1997                TensorView::from(TensorIndex::from(self, provided_indexes))
1998            }
1999        }
2000    };
2001}
2002
2003tensor_select_impl!(impl Tensor 6 1);
2004tensor_select_impl!(impl Tensor 5 1);
2005tensor_select_impl!(impl Tensor 4 1);
2006tensor_select_impl!(impl Tensor 3 1);
2007tensor_select_impl!(impl Tensor 2 1);
2008tensor_select_impl!(impl Tensor 1 1);
2009
2010macro_rules! tensor_expand_impl {
2011    (impl Tensor $d:literal 1) => {
2012        impl<T> Tensor<T, $d> {
2013            /**
2014             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2015             * a particular position within the shape, returning a TensorView which has more
2016             * dimensions than this Tensor.
2017             *
2018             * This is a shorthand for manually constructing the TensorView and
             * [TensorExpansion].
2020             *
2021             * Note: due to limitations in Rust's const generics support, this method is only
2022             * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
             * fall back to manual construction to create `TensorExpansion`s with multiple extra
             * dimension names if you need to increase dimensionality by more than 1 dimension at a time.
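             *
             * For example, adding a length 1 dimension to the front of a 2 dimensional
             * tensor:
             * ```
             * use easy_ml::tensors::Tensor;
             * let tensor = Tensor::from([("row", 2), ("column", 2)], vec![
             *     1, 2,
             *     3, 4
             * ]);
             * let expanded = tensor.expand([(0, "batch")]);
             * assert_eq!(expanded.shape(), [("batch", 1), ("row", 2), ("column", 2)]);
             * ```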
2025             */
2026            #[track_caller]
2027            pub fn expand(
2028                &self,
2029                extra_dimension_names: [(usize, Dimension); 1],
2030            ) -> TensorView<T, TensorExpansion<T, &Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2031                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2032            }
2033
2034            /**
2035             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2036             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this Tensor. The TensorExpansion mutably borrows this
             * Tensor, and can therefore mutate it.
2039             *
2040             * See [expand](Tensor::expand)
2041             */
2042            #[track_caller]
2043            pub fn expand_mut(
2044                &mut self,
2045                extra_dimension_names: [(usize, Dimension); 1],
2046            ) -> TensorView<T, TensorExpansion<T, &mut Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2047                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2048            }
2049
2050            /**
2051             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2052             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this Tensor. The TensorExpansion takes ownership of this
             * Tensor, and can therefore mutate it.
2055             *
2056             * See [expand](Tensor::expand)
2057             */
2058            #[track_caller]
2059            pub fn expand_owned(
2060                self,
2061                extra_dimension_names: [(usize, Dimension); 1],
2062            ) -> TensorView<T, TensorExpansion<T, Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2063                TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2064            }
2065        }
2066    };
2067}
2068
2069tensor_expand_impl!(impl Tensor 0 1);
2070tensor_expand_impl!(impl Tensor 1 1);
2071tensor_expand_impl!(impl Tensor 2 1);
2072tensor_expand_impl!(impl Tensor 3 1);
2073tensor_expand_impl!(impl Tensor 4 1);
2074tensor_expand_impl!(impl Tensor 5 1);