easy_ml/tensors/mod.rs
1/*!
2 * Generic N dimensional [named tensors](http://nlp.seas.harvard.edu/NamedTensor).
3 *
4 * Tensors are generic over some type `T` and some usize `D`. If `T` is [Numeric](super::numeric)
5 * then the tensor can be used in a mathematical way. `D` is the number of dimensions in the tensor
6 * and a compile time constant. Each tensor also carries `D` dimension name and length pairs.
7 */
8use crate::linear_algebra;
9use crate::numeric::extra::{Real, RealRef};
10use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::indexing::{
12 ShapeIterator, TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
13 TensorReferenceMutIterator, TensorTranspose,
14};
15use crate::tensors::views::{
16 DataLayout, IndexRange, IndexRangeValidationError, TensorExpansion, TensorIndex, TensorMask,
17 TensorMut, TensorRange, TensorRef, TensorRename, TensorReshape, TensorReverse, TensorView,
18};
19
20use std::error::Error;
21use std::fmt;
22
23#[cfg(feature = "serde")]
24use serde::Serialize;
25
26pub mod dimensions;
27mod display;
28pub mod indexing;
29pub mod operations;
30pub mod views;
31
32#[cfg(feature = "serde")]
33pub use serde_impls::TensorDeserialize;
34
/**
 * Dimension names are represented as static string references.
 *
 * This allows you to use string literals to refer to named dimensions, for example you might want
 * to construct a tensor with a shape of
 * `[("batch", 1000), ("height", 100), ("width", 100), ("rgba", 4)]`.
 *
 * Alternatively you can define the strings once as constants and refer to your dimension
 * names by the constant identifiers.
 *
 * ```
 * const BATCH: &'static str = "batch";
 * const HEIGHT: &'static str = "height";
 * const WIDTH: &'static str = "width";
 * const RGBA: &'static str = "rgba";
 * ```
 *
 * Although `Dimension` is interchangeable with `&'static str` as it is just a type alias, Easy ML
 * uses `Dimension` whenever dimension names are expected to distinguish the types from just
 * strings.
 */
pub type Dimension = &'static str;
57
58/**
59 * An error indicating failure to do something with a Tensor because the requested shape
60 * is not valid.
61 */
62#[derive(Clone, Debug, Eq, PartialEq)]
63pub struct InvalidShapeError<const D: usize> {
64 shape: [(Dimension, usize); D],
65}
66
67impl<const D: usize> InvalidShapeError<D> {
68 /**
69 * Checks if this shape is valid. This is mainly for internal library use but may also be
70 * useful for unit testing.
71 *
72 * Note: in some functions and methods, an InvalidShapeError may be returned which is a valid
73 * shape, but not the right size for the quantity of data provided.
74 */
75 pub fn is_valid(&self) -> bool {
76 !crate::tensors::dimensions::has_duplicates(&self.shape)
77 && !self.shape.iter().any(|d| d.1 == 0)
78 }
79
80 /**
81 * Constructs an InvalidShapeError for assistance with unit testing. Note that you can
82 * construct an InvalidShapeError that *is* a valid shape in this way.
83 */
84 pub fn new(shape: [(Dimension, usize); D]) -> InvalidShapeError<D> {
85 InvalidShapeError { shape }
86 }
87
88 pub fn shape(&self) -> [(Dimension, usize); D] {
89 self.shape
90 }
91
92 pub fn shape_ref(&self) -> &[(Dimension, usize); D] {
93 &self.shape
94 }
95
96 // Panics if the shape is invalid for any reason with the appropriate error message.
97 #[track_caller]
98 #[inline]
99 fn validate_dimensions_or_panic(shape: &[(Dimension, usize); D], data_len: usize) {
100 let elements = crate::tensors::dimensions::elements(shape);
101 if data_len != elements {
102 panic!(
103 "Product of dimension lengths must match size of data. {} != {}",
104 elements, data_len
105 );
106 }
107 if crate::tensors::dimensions::has_duplicates(shape) {
108 panic!("Dimension names must all be unique: {:?}", &shape);
109 }
110 if shape.iter().any(|d| d.1 == 0) {
111 panic!("No dimension can have 0 elements: {:?}", &shape);
112 }
113 }
114
115 // Returns true if the shape is valid and matches the data length
116 fn validate_dimensions(shape: &[(Dimension, usize); D], data_len: usize) -> bool {
117 let elements = crate::tensors::dimensions::elements(shape);
118 data_len == elements
119 && !crate::tensors::dimensions::has_duplicates(shape)
120 && !shape.iter().any(|d| d.1 == 0)
121 }
122}
123
124impl<const D: usize> fmt::Display for InvalidShapeError<D> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(
127 f,
128 "Dimensions must all be at least length 1 with unique names: {:?}",
129 self.shape
130 )
131 }
132}
133
134impl<const D: usize> Error for InvalidShapeError<D> {}
135
136/**
137 * An error indicating failure to do something with a Tensor because the dimension names that
138 * were provided did not match with the dimension names that were valid.
139 *
140 * Typically this would be due to the same dimension name being provided multiple times, or a
141 * dimension name being provided that is not present in the shape of the Tensor in use.
142 */
143#[derive(Clone, Debug, Eq, PartialEq)]
144pub struct InvalidDimensionsError<const D: usize, const P: usize> {
145 valid: [Dimension; D],
146 provided: [Dimension; P],
147}
148
149impl<const D: usize, const P: usize> InvalidDimensionsError<D, P> {
150 /**
151 * Checks if the provided dimensions have duplicate names. This is mainly for internal library
152 * use but may also be useful for unit testing.
153 */
154 pub fn has_duplicates(&self) -> bool {
155 crate::tensors::dimensions::has_duplicates_names(&self.provided)
156 }
157
158 // TODO: method to check provided is a subset of valid
159
160 /**
161 * Constructs an InvalidDimensions for assistance with unit testing.
162 */
163 pub fn new(provided: [Dimension; P], valid: [Dimension; D]) -> InvalidDimensionsError<D, P> {
164 InvalidDimensionsError { valid, provided }
165 }
166
167 pub fn provided_names(&self) -> [Dimension; P] {
168 self.provided
169 }
170
171 pub fn provided_names_ref(&self) -> &[Dimension; P] {
172 &self.provided
173 }
174
175 pub fn valid_names(&self) -> [Dimension; D] {
176 self.valid
177 }
178
179 pub fn valid_names_ref(&self) -> &[Dimension; D] {
180 &self.valid
181 }
182}
183
184impl<const D: usize, const P: usize> fmt::Display for InvalidDimensionsError<D, P> {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 if P > 0 {
187 write!(
188 f,
189 "Dimensions names {:?} were incorrect, valid dimensions in this context are: {:?}",
190 self.provided, self.valid
191 )
192 } else {
193 write!(f, "Dimensions names {:?} were incorrect", self.provided)
194 }
195 }
196}
197
198impl<const D: usize, const P: usize> Error for InvalidDimensionsError<D, P> {}
199
#[test]
fn test_sync() {
    // Purely a compile time check: this only builds if each listed type is Sync.
    fn requires_sync<S: Sync>() {}
    requires_sync::<InvalidDimensionsError<2, 2>>();
    requires_sync::<InvalidShapeError<2>>();
}
206
#[test]
fn test_send() {
    // Purely a compile time check: this only builds if each listed type is Send.
    fn requires_send<S: Send>() {}
    requires_send::<InvalidDimensionsError<2, 2>>();
    requires_send::<InvalidShapeError<2>>();
}
213
214/**
215 * A [named tensor](http://nlp.seas.harvard.edu/NamedTensor) of some type `T` and number of
216 * dimensions `D`.
217 *
218 * Tensors are a generalisation of matrices; whereas [Matrix](crate::matrices::Matrix) only
219 * supports 2 dimensions, and vectors are represented in Matrix by making either the rows or
220 * columns have a length of one, [Tensor] supports an arbitary number of dimensions,
221 * with 0 through 6 having full API support. A `Tensor<T, 2>` is very similar to a `Matrix<T>`
222 * except that this type associates each dimension with a name, and favors names to refer to
223 * dimensions instead of index order.
224 *
225 * Like Matrix, the type of the data in this Tensor may implement no traits, in which case the
226 * tensor will be rather useless. If the type implements Clone most storage and accessor methods
227 * are defined and if the type implements Numeric then the tensor can be used in a mathematical
228 * way.
229 *
230 * Like Matrix, a Tensor must always contain at least one element, and it may not not have more
231 * elements than `std::isize::MAX`. Concerned readers should note that on a 64 bit computer this
232 * maximum value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
233 *
234 * When doing numeric operations with Tensors you should be careful to not consume a tensor by
235 * accidentally using it by value. All the operations are also defined on references to tensors
236 * so you should favor &x + &y style notation for tensors you intend to continue using.
237 *
238 * See also:
239 * - [indexing]
240 */
241#[derive(Debug)]
242#[cfg_attr(feature = "serde", derive(Serialize))]
243pub struct Tensor<T, const D: usize> {
244 data: Vec<T>,
245 #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
246 shape: [(Dimension, usize); D],
247 #[cfg_attr(feature = "serde", serde(skip))]
248 strides: [usize; D],
249}
250
251impl<T, const D: usize> Tensor<T, D> {
252 /**
253 * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
254 *
255 * The product of the dimension lengths corresponds to the number of elements the Tensor
256 * will store. Elements are stored in what would be row major order for a Matrix.
257 * Each step in memory through the N dimensions corresponds to incrementing the rightmost
258 * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
259 * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
260 * for the 25th and final element.
261 *
262 * # Panics
263 *
264 * - If the number of provided elements does not match the product of the dimension lengths.
265 * - If a dimension name is not unique
266 * - If any dimension has 0 elements
267 *
268 * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
269 * a single element (since the product of an empty list is 1).
270 */
271 #[track_caller]
272 pub fn from(shape: [(Dimension, usize); D], data: Vec<T>) -> Self {
273 InvalidShapeError::validate_dimensions_or_panic(&shape, data.len());
274 let strides = compute_strides(&shape);
275 Tensor {
276 data,
277 shape,
278 strides,
279 }
280 }
281
282 /**
283 * Creates a Tensor with a particular shape initialised from a function.
284 *
285 * The product of the dimension lengths corresponds to the number of elements the Tensor
286 * will store. Elements are stored in what would be row major order for a Matrix.
287 * Each step in memory through the N dimensions corresponds to incrementing the rightmost
288 * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
289 * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
290 * for the 25th and final element. These same indexes will be passed to the producer function
291 * to initialised the values for the Tensor.
292 *
293 * ```
294 * use easy_ml::tensors::Tensor;
295 * let tensor = Tensor::from_fn([("rows", 4), ("columns", 4)], |[r, c]| r * c);
296 * assert_eq!(
297 * tensor,
298 * Tensor::from([("rows", 4), ("columns", 4)], vec![
299 * 0, 0, 0, 0,
300 * 0, 1, 2, 3,
301 * 0, 2, 4, 6,
302 * 0, 3, 6, 9,
303 * ])
304 * );
305 * ```
306 *
307 * # Panics
308 *
309 * - If a dimension name is not unique
310 * - If any dimension has 0 elements
311 *
312 * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
313 * a single element (since the product of an empty list is 1).
314 */
315 #[track_caller]
316 pub fn from_fn<F>(shape: [(Dimension, usize); D], mut producer: F) -> Self
317 where
318 F: FnMut([usize; D]) -> T,
319 {
320 let length = dimensions::elements(&shape);
321 let mut data = Vec::with_capacity(length);
322 let iterator = ShapeIterator::from(shape);
323 for index in iterator {
324 data.push(producer(index));
325 }
326 Tensor::from(shape, data)
327 }
328
329 /**
330 * The shape of this tensor. Since Tensors are named Tensors, their shape is not just a
331 * list of lengths along each dimension, but instead a list of pairs of names and lengths.
332 *
333 * See also
334 * - [dimensions]
335 * - [indexing]
336 */
337 pub fn shape(&self) -> [(Dimension, usize); D] {
338 self.shape
339 }
340
341 /**
342 * Returns the length of the dimension name provided, if one is present in the Tensor.
343 *
344 * See also
345 * - [dimensions]
346 * - [indexing]
347 */
348 pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
349 dimensions::length_of(&self.shape, dimension)
350 }
351
352 /**
353 * Returns the last index of the dimension name provided, if one is present in the Tensor.
354 *
355 * This is always 1 less than the length, the 'index' in this sense is based on what the
356 * Tensor's shape is, not any implementation index.
357 *
358 * See also
359 * - [dimensions]
360 * - [indexing]
361 */
362 pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
363 dimensions::last_index_of(&self.shape, dimension)
364 }
365
366 /**
367 * A non panicking version of [from](Tensor::from) which returns `Result::Err` if the input
368 * is invalid.
369 *
370 * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
371 *
372 * The product of the dimension lengths corresponds to the number of elements the Tensor
373 * will store. Elements are stored in what would be row major order for a Matrix.
374 * Each step in memory through the N dimensions corresponds to incrementing the rightmost
375 * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
376 * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
377 * for the 25th and final element.
378 *
379 * Returns the Err variant if
380 * - If the number of provided elements does not match the product of the dimension lengths.
381 * - If a dimension name is not unique
382 * - If any dimension has 0 elements
383 *
384 * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
385 * a single element (since the product of an empty list is 1).
386 */
387 pub fn try_from(
388 shape: [(Dimension, usize); D],
389 data: Vec<T>,
390 ) -> Result<Self, InvalidShapeError<D>> {
391 let valid = InvalidShapeError::validate_dimensions(&shape, data.len());
392 if !valid {
393 return Err(InvalidShapeError::new(shape));
394 }
395 let strides = compute_strides(&shape);
396 Ok(Tensor {
397 data,
398 shape,
399 strides,
400 })
401 }
402
403 /// Unverified constructor for interal use when we know the dimensions/data/strides are
404 /// unchanged and don't need reverification
405 pub(crate) fn direct_from(
406 data: Vec<T>,
407 shape: [(Dimension, usize); D],
408 strides: [usize; D],
409 ) -> Self {
410 Tensor {
411 data,
412 shape,
413 strides,
414 }
415 }
416
417 /// Unverified constructor for interal use when we know the dimensions/data/strides are
418 /// the same as the existing instance and don't need reverification
419 #[allow(dead_code)] // pretty sure something else will want this in the future
420 pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> Self {
421 Tensor {
422 data,
423 shape: self.shape,
424 strides: self.strides,
425 }
426 }
427}
428
429impl<T> Tensor<T, 0> {
430 /**
431 * Creates a 0 dimensional tensor from some scalar
432 */
433 pub fn from_scalar(value: T) -> Tensor<T, 0> {
434 Tensor {
435 data: vec![value],
436 shape: [],
437 strides: [],
438 }
439 }
440
441 /**
442 * Returns the sole element of the 0 dimensional tensor.
443 */
444 pub fn into_scalar(self) -> T {
445 self.data
446 .into_iter()
447 .next()
448 .expect("Tensors always have at least 1 element")
449 }
450}
451
452impl<T> Tensor<T, 0>
453where
454 T: Clone,
455{
456 /**
457 * Returns a copy of the sole element in the 0 dimensional tensor.
458 */
459 pub fn scalar(&self) -> T {
460 self.data
461 .first()
462 .expect("Tensors always have at least 1 element")
463 .clone()
464 }
465}
466
467impl<T> From<T> for Tensor<T, 0> {
468 fn from(scalar: T) -> Tensor<T, 0> {
469 Tensor::from_scalar(scalar)
470 }
471}
472// TODO: See if we can find a way to write the reverse Tensor<T, 0> -> T conversion using From or Into (doesn't seem like we can?)
473
474// # Safety
475//
476// We promise to never implement interior mutability for Tensor.
477/**
478 * A Tensor implements TensorRef.
479 */
480unsafe impl<T, const D: usize> TensorRef<T, D> for Tensor<T, D> {
481 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
482 let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
483 self.data.get(i)
484 }
485
486 fn view_shape(&self) -> [(Dimension, usize); D] {
487 Tensor::shape(self)
488 }
489
490 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
491 unsafe {
492 // The point of get_reference_unchecked is no bounds checking, and therefore
493 // it does not make any sense to just use `unwrap` here. The trait documents that
494 // it's undefind behaviour to call this method with an out of bounds index, so we
495 // can assume the None case will never happen.
496 let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
497 self.data.get_unchecked(i)
498 }
499 }
500
501 fn data_layout(&self) -> DataLayout<D> {
502 // We always have our memory in most significant to least
503 DataLayout::Linear(std::array::from_fn(|i| self.shape[i].0))
504 }
505}
506
507// # Safety
508//
509// We promise to never implement interior mutability for Tensor.
510unsafe impl<T, const D: usize> TensorMut<T, D> for Tensor<T, D> {
511 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
512 let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
513 self.data.get_mut(i)
514 }
515
516 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
517 unsafe {
518 // The point of get_reference_unchecked_mut is no bounds checking, and therefore
519 // it does not make any sense to just use `unwrap` here. The trait documents that
520 // it's undefind behaviour to call this method with an out of bounds index, so we
521 // can assume the None case will never happen.
522 let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
523 self.data.get_unchecked_mut(i)
524 }
525 }
526}
527
528/**
529 * Any tensor of a Cloneable type implements Clone.
530 */
531impl<T: Clone, const D: usize> Clone for Tensor<T, D> {
532 fn clone(&self) -> Self {
533 self.map(|element| element)
534 }
535}
536
537/**
538 * Any tensor of a Displayable type implements Display
539 *
540 * You can control the precision of the formatting using format arguments, i.e.
541 * `format!("{:.3}", tensor)`
542 */
543impl<T: std::fmt::Display, const D: usize> std::fmt::Display for Tensor<T, D> {
544 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
545 crate::tensors::display::format_view(self, f)
546 }
547}
548
549/**
550 * Any 2 dimensional tensor can be converted to a matrix with rows equal to the length of the
551 * first dimension in the tensor, and columns equal to the length of the second.
552 */
553impl<T> From<Tensor<T, 2>> for crate::matrices::Matrix<T> {
554 fn from(tensor: Tensor<T, 2>) -> Self {
555 crate::matrices::Matrix::from_flat_row_major(
556 (tensor.shape[0].1, tensor.shape[1].1),
557 tensor.data,
558 )
559 }
560}
561
562pub(crate) fn compute_strides<const D: usize>(shape: &[(Dimension, usize); D]) -> [usize; D] {
563 std::array::from_fn(|d| shape.iter().skip(d + 1).map(|d| d.1).product())
564}
565
566/// returns the 1 dimensional index to use to get the requested index into some tensor
567#[inline]
568pub(crate) fn get_index_direct<const D: usize>(
569 // indexes to use
570 indexes: &[usize; D],
571 // strides for indexing into the tensor
572 strides: &[usize; D],
573 // shape of the tensor to index into
574 shape: &[(Dimension, usize); D],
575) -> Option<usize> {
576 let mut index = 0;
577 for d in 0..D {
578 let n = indexes[d];
579 if n >= shape[d].1 {
580 return None;
581 }
582 index += n * strides[d];
583 }
584 Some(index)
585}
586
/// Converts an N dimensional index into the flat offset into the tensor's data by summing each
/// index multiplied by its stride. Does not check the indexes against the shape.
#[inline]
fn get_index_direct_unchecked<const D: usize>(
    // indexes to use
    indexes: &[usize; D],
    // strides for indexing into the tensor
    strides: &[usize; D],
) -> usize {
    indexes
        .iter()
        .zip(strides)
        .map(|(index, stride)| index * stride)
        .sum()
}
603
604impl<T, const D: usize> Tensor<T, D> {
605 pub fn view(&self) -> TensorView<T, &Tensor<T, D>, D> {
606 TensorView::from(self)
607 }
608
609 pub fn view_mut(&mut self) -> TensorView<T, &mut Tensor<T, D>, D> {
610 TensorView::from(self)
611 }
612
613 pub fn view_owned(self) -> TensorView<T, Tensor<T, D>, D> {
614 TensorView::from(self)
615 }
616
617 /**
618 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
619 * to read values from this tensor.
620 *
621 * # Panics
622 *
623 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
624 */
625 #[track_caller]
626 pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &Tensor<T, D>, D> {
627 TensorAccess::from(self, dimensions)
628 }
629
630 /**
631 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
632 * to read or write values from this tensor.
633 *
634 * # Panics
635 *
636 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
637 */
638 #[track_caller]
639 pub fn index_by_mut(
640 &mut self,
641 dimensions: [Dimension; D],
642 ) -> TensorAccess<T, &mut Tensor<T, D>, D> {
643 TensorAccess::from(self, dimensions)
644 }
645
646 /**
647 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
648 * to read or write values from this tensor.
649 *
650 * # Panics
651 *
652 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
653 */
654 #[track_caller]
655 pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, Tensor<T, D>, D> {
656 TensorAccess::from(self, dimensions)
657 }
658
659 /**
660 * Creates a TensorAccess which will index into the dimensions this Tensor was created with
661 * in the same order as they were provided. See [TensorAccess::from_source_order].
662 */
663 pub fn index(&self) -> TensorAccess<T, &Tensor<T, D>, D> {
664 TensorAccess::from_source_order(self)
665 }
666
667 /**
668 * Creates a TensorAccess which will index into the dimensions this Tensor was
669 * created with in the same order as they were provided. The TensorAccess mutably borrows
670 * the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
671 */
672 pub fn index_mut(&mut self) -> TensorAccess<T, &mut Tensor<T, D>, D> {
673 TensorAccess::from_source_order(self)
674 }
675
676 /**
677 * Creates a TensorAccess which will index into the dimensions this Tensor was
678 * created with in the same order as they were provided. The TensorAccess takes ownership
679 * of the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
680 */
681 pub fn index_owned(self) -> TensorAccess<T, Tensor<T, D>, D> {
682 TensorAccess::from_source_order(self)
683 }
684
685 /**
686 * Returns an iterator over references to the data in this Tensor.
687 */
688 pub fn iter_reference(&self) -> TensorReferenceIterator<T, Tensor<T, D>, D> {
689 TensorReferenceIterator::from(self)
690 }
691
692 /**
693 * Returns an iterator over mutable references to the data in this Tensor.
694 */
695 pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<T, Tensor<T, D>, D> {
696 TensorReferenceMutIterator::from(self)
697 }
698
699 /**
700 * Creates an iterator over the values in this Tensor.
701 */
702 pub fn iter_owned(self) -> TensorOwnedIterator<T, Tensor<T, D>, D>
703 where
704 T: Default,
705 {
706 TensorOwnedIterator::from(self)
707 }
708
709 // Non public index order reference iterator since we don't want to expose our implementation
710 // details to public API since then we could never change them.
711 pub(crate) fn direct_iter_reference(&self) -> std::slice::Iter<T> {
712 self.data.iter()
713 }
714
715 // Non public index order reference iterator since we don't want to expose our implementation
716 // details to public API since then we could never change them.
717 pub(crate) fn direct_iter_reference_mut(&mut self) -> std::slice::IterMut<T> {
718 self.data.iter_mut()
719 }
720
721 /**
722 * Renames the dimension names of the tensor without changing the lengths of the dimensions
723 * in the tensor or moving any data around.
724 *
725 * ```
726 * use easy_ml::tensors::Tensor;
727 * let mut tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6]);
728 * tensor.rename(["y", "z"]);
729 * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
730 * ```
731 *
732 * # Panics
733 *
734 * - If a dimension name is not unique
735 */
736 #[track_caller]
737 pub fn rename(&mut self, dimensions: [Dimension; D]) {
738 if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
739 panic!("Dimension names must all be unique: {:?}", &dimensions);
740 }
741 #[allow(clippy::needless_range_loop)]
742 for d in 0..D {
743 self.shape[d].0 = dimensions[d];
744 }
745 }
746
747 /**
748 * Renames the dimension names of the tensor and returns it without changing the lengths
749 * of the dimensions in the tensor or moving any data around.
750 *
751 * ```
752 * use easy_ml::tensors::Tensor;
753 * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6])
754 * .rename_owned(["y", "z"]);
755 * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
756 * ```
757 *
758 * # Panics
759 *
760 * - If a dimension name is not unique
761 */
762 #[track_caller]
763 pub fn rename_owned(mut self, dimensions: [Dimension; D]) -> Tensor<T, D> {
764 self.rename(dimensions);
765 self
766 }
767
768 /**
769 * Returns a TensorView with the dimension names of the shape renamed to the provided
770 * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
771 *
772 * This is a shorthand for constructing the TensorView from this Tensor.
773 *
774 * ```
775 * use easy_ml::tensors::Tensor;
776 * use easy_ml::tensors::views::{TensorView, TensorRename};
777 * let abc = Tensor::from([("a", 3), ("b", 3), ("c", 3)], (0..27).collect());
778 * let xyz = abc.rename_view(["x", "y", "z"]);
779 * let also_xyz = TensorView::from(TensorRename::from(&abc, ["x", "y", "z"]));
780 * assert_eq!(xyz, also_xyz);
781 * assert_eq!(xyz, Tensor::from([("x", 3), ("y", 3), ("z", 3)], (0..27).collect()));
782 * ```
783 *
784 * # Panics
785 *
786 * - If a dimension name is not unique
787 */
788 #[track_caller]
789 pub fn rename_view(
790 &self,
791 dimensions: [Dimension; D],
792 ) -> TensorView<T, TensorRename<T, &Tensor<T, D>, D>, D> {
793 TensorView::from(TensorRename::from(self, dimensions))
794 }
795
796 /**
797 * Changes the shape of the tensor without changing the number of dimensions or moving any
798 * data around.
799 *
800 * # Panics
801 *
802 * - If the number of provided elements in the new shape does not match the product of the
803 * dimension lengths in the existing tensor's shape.
804 * - If a dimension name is not unique
805 * - If any dimension has 0 elements
806 *
807 * ```
808 * use easy_ml::tensors::Tensor;
809 * let mut tensor = Tensor::from([("width", 2), ("height", 2)], vec![
810 * 1, 2,
811 * 3, 4
812 * ]);
813 * tensor.reshape_mut([("batch", 1), ("image", 4)]);
814 * assert_eq!(tensor, Tensor::from([("batch", 1), ("image", 4)], vec![ 1, 2, 3, 4 ]));
815 * ```
816 */
817 #[track_caller]
818 pub fn reshape_mut(&mut self, shape: [(Dimension, usize); D]) {
819 InvalidShapeError::validate_dimensions_or_panic(&shape, self.data.len());
820 let strides = compute_strides(&shape);
821 self.shape = shape;
822 self.strides = strides;
823 }
824
825 /**
826 * Consumes the tensor and changes the shape of the tensor without moving any
827 * data around. The new Tensor may also have a different number of dimensions.
828 *
829 * # Panics
830 *
831 * - If the number of provided elements in the new shape does not match the product of the
832 * dimension lengths in the existing tensor's shape.
833 * - If a dimension name is not unique
834 *
835 * ```
836 * use easy_ml::tensors::Tensor;
837 * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
838 * 1, 2,
839 * 3, 4
840 * ]);
841 * let flattened = tensor.reshape_owned([("image", 4)]);
842 * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
843 * ```
844 *
845 * See also [reshape_view_owned](Tensor::reshape_view_owned)
846 */
847 #[track_caller]
848 pub fn reshape_owned<const D2: usize>(self, shape: [(Dimension, usize); D2]) -> Tensor<T, D2> {
849 Tensor::from(shape, self.data)
850 }
851
852 /**
853 * Returns a TensorView with the dimensions changed to the provided shape without moving any
854 * data around. The new Tensor may also have a different number of dimensions.
855 *
856 * This is a shorthand for constructing the TensorView from this Tensor.
857 *
858 * # Panics
859 *
860 * - If the number of provided elements in the new shape does not match the product of the
861 * dimension lengths in the existing tensor's shape.
862 * - If a dimension name is not unique
863 *
864 * ```
865 * use easy_ml::tensors::Tensor;
866 * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
867 * 1, 2,
868 * 3, 4
869 * ]);
870 * let flattened = tensor.reshape_view([("image", 4)]);
871 * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
872 * ```
873 */
874 pub fn reshape_view<const D2: usize>(
875 &self,
876 shape: [(Dimension, usize); D2],
877 ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, D2>, D2> {
878 TensorView::from(TensorReshape::from(self, shape))
879 }
880
881 /**
882 * Returns a TensorView with the dimensions changed to the provided shape without moving any
883 * data around. The new Tensor may also have a different number of dimensions.
884 *
885 * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
886 * mutably borrows this Tensor, and can therefore mutate it
887 *
888 * # Panics
889 *
890 * - If the number of provided elements in the new shape does not match the product of the
891 * dimension lengths in the existing tensor's shape.
892 * - If a dimension name is not unique
893 */
894 pub fn reshape_view_mut<const D2: usize>(
895 &mut self,
896 shape: [(Dimension, usize); D2],
897 ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, D2>, D2> {
898 TensorView::from(TensorReshape::from(self, shape))
899 }
900
901 /**
902 * Returns a TensorView with the dimensions changed to the provided shape without moving any
903 * data around. The new Tensor may also have a different number of dimensions.
904 *
905 * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
906 * takes ownership of this Tensor, and can therefore mutate it
907 *
908 * # Panics
909 *
910 * - If the number of provided elements in the new shape does not match the product of the
911 * dimension lengths in the existing tensor's shape.
912 * - If a dimension name is not unique
913 *
914 * See also [reshape_owned](Tensor::reshape_owned)
915 */
916 pub fn reshape_view_owned<const D2: usize>(
917 self,
918 shape: [(Dimension, usize); D2],
919 ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, D2>, D2> {
920 TensorView::from(TensorReshape::from(self, shape))
921 }
922
923 /**
924 * Given the dimension name, returns a view of this tensor reshaped to one dimension
925 * with a length equal to the number of elements in this tensor.
926 */
927 pub fn flatten_view(
928 &self,
929 dimension: Dimension,
930 ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, 1>, 1> {
931 self.reshape_view([(dimension, dimensions::elements(&self.shape))])
932 }
933
934 /**
935 * Given the dimension name, returns a view of this tensor reshaped to one dimension
936 * with a length equal to the number of elements in this tensor.
937 */
938 pub fn flatten_view_mut(
939 &mut self,
940 dimension: Dimension,
941 ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, 1>, 1> {
942 self.reshape_view_mut([(dimension, dimensions::elements(&self.shape))])
943 }
944
945 /**
946 * Given the dimension name, returns a view of this tensor reshaped to one dimension
947 * with a length equal to the number of elements in this tensor.
948 *
949 * If you intend to query the tensor a lot after creating the view, consider
950 * using [flatten](Tensor::flatten) instead as it will have less overhead to
951 * index after creation.
952 */
953 pub fn flatten_view_owned(
954 self,
955 dimension: Dimension,
956 ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, 1>, 1> {
957 let length = dimensions::elements(&self.shape);
958 self.reshape_view_owned([(dimension, length)])
959 }
960
961 /**
962 * Given the dimension name, returns a new tensor reshaped to one dimension
963 * with a length equal to the number of elements in this tensor.
964 */
965 pub fn flatten(self, dimension: Dimension) -> Tensor<T, 1> {
966 let length = dimensions::elements(&self.shape);
967 self.reshape_owned([(dimension, length)])
968 }
969
970 /**
971 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
972 * range from view. Error cases are documented on [TensorRange].
973 *
974 * This is a shorthand for constructing the TensorView from this Tensor.
975 *
976 * ```
977 * use easy_ml::tensors::Tensor;
978 * use easy_ml::tensors::views::{TensorView, TensorRange, IndexRange};
979 * # use easy_ml::tensors::views::IndexRangeValidationError;
980 * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
981 * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
982 * let cropped = samples.range([("x", IndexRange::new(1, 5)), ("y", IndexRange::new(1, 5))])?;
983 * let also_cropped = TensorView::from(
984 * TensorRange::from(&samples, [("x", 1..6), ("y", 1..6)])?
985 * );
986 * assert_eq!(cropped, also_cropped);
987 * assert_eq!(
988 * cropped.select([("batch", 0)]),
989 * Tensor::from([("x", 5), ("y", 5)], vec![
990 * 8, 9, 10, 11, 12,
991 * 15, 16, 17, 18, 19,
992 * 22, 23, 24, 25, 26,
993 * 29, 30, 31, 32, 33,
994 * 36, 37, 38, 39, 40
995 * ])
996 * );
997 * # Ok(())
998 * # }
999 * ```
1000 */
1001 pub fn range<R, const P: usize>(
1002 &self,
1003 ranges: [(Dimension, R); P],
1004 ) -> Result<TensorView<T, TensorRange<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1005 where
1006 R: Into<IndexRange>,
1007 {
1008 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1009 }
1010
1011 /**
1012 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1013 * range from view. Error cases are documented on [TensorRange]. The TensorRange
1014 * mutably borrows this Tensor, and can therefore mutate it
1015 *
1016 * This is a shorthand for constructing the TensorView from this Tensor.
1017 */
1018 pub fn range_mut<R, const P: usize>(
1019 &mut self,
1020 ranges: [(Dimension, R); P],
1021 ) -> Result<
1022 TensorView<T, TensorRange<T, &mut Tensor<T, D>, D>, D>,
1023 IndexRangeValidationError<D, P>,
1024 >
1025 where
1026 R: Into<IndexRange>,
1027 {
1028 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1029 }
1030
1031 /**
1032 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1033 * range from view. Error cases are documented on [TensorRange]. The TensorRange
1034 * takes ownership of this Tensor, and can therefore mutate it
1035 *
1036 * This is a shorthand for constructing the TensorView from this Tensor.
1037 */
1038 pub fn range_owned<R, const P: usize>(
1039 self,
1040 ranges: [(Dimension, R); P],
1041 ) -> Result<TensorView<T, TensorRange<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1042 where
1043 R: Into<IndexRange>,
1044 {
1045 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1046 }
1047
1048 /**
1049 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1050 * range from view. Error cases are documented on [TensorMask].
1051 *
1052 * This is a shorthand for constructing the TensorView from this Tensor.
1053 *
1054 * ```
1055 * use easy_ml::tensors::Tensor;
1056 * use easy_ml::tensors::views::{TensorView, TensorMask, IndexRange};
1057 * # use easy_ml::tensors::views::IndexRangeValidationError;
1058 * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
1059 * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
1060 * let corners = samples.mask([("x", IndexRange::new(3, 2)), ("y", IndexRange::new(3, 2))])?;
1061 * let also_corners = TensorView::from(
1062 * TensorMask::from(&samples, [("x", 3..5), ("y", 3..5)])?
1063 * );
1064 * assert_eq!(corners, also_corners);
1065 * assert_eq!(
1066 * corners.select([("batch", 0)]),
1067 * Tensor::from([("x", 5), ("y", 5)], vec![
1068 * 0, 1, 2, 5, 6,
1069 * 7, 8, 9, 12, 13,
1070 * 14, 15, 16, 19, 20,
1071 *
1072 * 35, 36, 37, 40, 41,
1073 * 42, 43, 44, 47, 48
1074 * ])
1075 * );
1076 * # Ok(())
1077 * # }
1078 * ```
1079 */
1080 pub fn mask<R, const P: usize>(
1081 &self,
1082 masks: [(Dimension, R); P],
1083 ) -> Result<TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1084 where
1085 R: Into<IndexRange>,
1086 {
1087 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1088 }
1089
1090 /**
1091 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1092 * range from view. Error cases are documented on [TensorMask]. The TensorMask
1093 * mutably borrows this Tensor, and can therefore mutate it
1094 *
1095 * This is a shorthand for constructing the TensorView from this Tensor.
1096 */
1097 pub fn mask_mut<R, const P: usize>(
1098 &mut self,
1099 masks: [(Dimension, R); P],
1100 ) -> Result<
1101 TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D>,
1102 IndexRangeValidationError<D, P>,
1103 >
1104 where
1105 R: Into<IndexRange>,
1106 {
1107 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1108 }
1109
1110 /**
1111 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1112 * range from view. Error cases are documented on [TensorMask]. The TensorMask
1113 * takes ownership of this Tensor, and can therefore mutate it
1114 *
1115 * This is a shorthand for constructing the TensorView from this Tensor.
1116 */
1117 pub fn mask_owned<R, const P: usize>(
1118 self,
1119 masks: [(Dimension, R); P],
1120 ) -> Result<TensorView<T, TensorMask<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1121 where
1122 R: Into<IndexRange>,
1123 {
1124 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1125 }
1126
1127 /**
1128 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1129 * order. The data of this tensor and the dimension lengths remain unchanged.
1130 *
1131 * This is a shorthand for constructing the TensorView from this Tensor.
1132 *
1133 * ```
1134 * use easy_ml::tensors::Tensor;
1135 * use easy_ml::tensors::views::{TensorView, TensorReverse};
1136 * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
1137 * let reversed = ab.reverse(&["a"]);
1138 * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
1139 * assert_eq!(reversed, also_reversed);
1140 * assert_eq!(
1141 * reversed,
1142 * Tensor::from(
1143 * [("a", 2), ("b", 3)],
1144 * vec![
1145 * 3, 4, 5,
1146 * 0, 1, 2,
1147 * ]
1148 * )
1149 * );
1150 * ```
1151 *
1152 * # Panics
1153 *
1154 * - If a dimension name is not in the tensor's shape or is repeated.
1155 */
1156 #[track_caller]
1157 pub fn reverse(
1158 &self,
1159 dimensions: &[Dimension],
1160 ) -> TensorView<T, TensorReverse<T, &Tensor<T, D>, D>, D> {
1161 TensorView::from(TensorReverse::from(self, dimensions))
1162 }
1163
1164 /**
1165 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1166 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1167 * mutably borrows this Tensor, and can therefore mutate it
1168 *
1169 * This is a shorthand for constructing the TensorView from this Tensor.
1170 *
1171 * # Panics
1172 *
1173 * - If a dimension name is not in the tensor's shape or is repeated.
1174 */
1175 #[track_caller]
1176 pub fn reverse_mut(
1177 &mut self,
1178 dimensions: &[Dimension],
1179 ) -> TensorView<T, TensorReverse<T, &mut Tensor<T, D>, D>, D> {
1180 TensorView::from(TensorReverse::from(self, dimensions))
1181 }
1182
1183 /**
1184 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1185 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1186 * takes ownership of this Tensor, and can therefore mutate it
1187 *
1188 * This is a shorthand for constructing the TensorView from this Tensor.
1189 *
1190 * # Panics
1191 *
1192 * - If a dimension name is not in the tensor's shape or is repeated.
1193 */
1194 #[track_caller]
1195 pub fn reverse_owned(
1196 self,
1197 dimensions: &[Dimension],
1198 ) -> TensorView<T, TensorReverse<T, Tensor<T, D>, D>, D> {
1199 TensorView::from(TensorReverse::from(self, dimensions))
1200 }
1201
1202 /**
1203 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1204 * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
1205 * or need to clone the values anyway, you can use
1206 * [`Tensor::elementwise`](Tensor::elementwise) instead.
1207 *
1208 * # Generics
1209 *
1210 * This method can be called with any right hand side that can be converted to a TensorView,
1211 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1212 *
1213 * # Panics
1214 *
1215 * If the two tensors have different shapes.
1216 */
1217 #[track_caller]
1218 pub fn elementwise_reference<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1219 where
1220 I: Into<TensorView<T, S, D>>,
1221 S: TensorRef<T, D>,
1222 M: Fn(&T, &T) -> T,
1223 {
1224 self.elementwise_reference_less_generic(rhs.into(), mapping_function)
1225 }
1226
1227 /**
1228 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1229 * mapped by a function. The mapping function also receives each index corresponding to the
1230 * value pairs. The value pairs are not copied for you, if you're using `Copy` types
1231 * or need to clone the values anyway, you can use
1232 * [`Tensor::elementwise_with_index`](Tensor::elementwise_with_index) instead.
1233 *
1234 * # Generics
1235 *
1236 * This method can be called with any right hand side that can be converted to a TensorView,
1237 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1238 *
1239 * # Panics
1240 *
1241 * If the two tensors have different shapes.
1242 */
1243 #[track_caller]
1244 pub fn elementwise_reference_with_index<S, I, M>(
1245 &self,
1246 rhs: I,
1247 mapping_function: M,
1248 ) -> Tensor<T, D>
1249 where
1250 I: Into<TensorView<T, S, D>>,
1251 S: TensorRef<T, D>,
1252 M: Fn([usize; D], &T, &T) -> T,
1253 {
1254 self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
1255 }
1256
1257 #[track_caller]
1258 fn elementwise_reference_less_generic<S, M>(
1259 &self,
1260 rhs: TensorView<T, S, D>,
1261 mapping_function: M,
1262 ) -> Tensor<T, D>
1263 where
1264 S: TensorRef<T, D>,
1265 M: Fn(&T, &T) -> T,
1266 {
1267 let left_shape = self.shape();
1268 let right_shape = rhs.shape();
1269 if left_shape != right_shape {
1270 panic!(
1271 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1272 left_shape, right_shape
1273 );
1274 }
1275 let mapped = self
1276 .direct_iter_reference()
1277 .zip(rhs.iter_reference())
1278 .map(|(x, y)| mapping_function(x, y))
1279 .collect();
1280 // We're not changing the shape of the Tensor, so don't need to revalidate
1281 Tensor::direct_from(mapped, self.shape, self.strides)
1282 }
1283
1284 #[track_caller]
1285 fn elementwise_reference_less_generic_with_index<S, M>(
1286 &self,
1287 rhs: TensorView<T, S, D>,
1288 mapping_function: M,
1289 ) -> Tensor<T, D>
1290 where
1291 S: TensorRef<T, D>,
1292 M: Fn([usize; D], &T, &T) -> T,
1293 {
1294 let left_shape = self.shape();
1295 let right_shape = rhs.shape();
1296 if left_shape != right_shape {
1297 panic!(
1298 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1299 left_shape, right_shape
1300 );
1301 }
1302 // we just checked both shapes were the same, so we don't need to propagate indexes
1303 // for both tensors because they'll be identical
1304 let mapped = self
1305 .direct_iter_reference()
1306 .zip(rhs.iter_reference().with_index())
1307 .map(|(x, (i, y))| mapping_function(i, x, y))
1308 .collect();
1309 // We're not changing the shape of the Tensor, so don't need to revalidate
1310 Tensor::direct_from(mapped, self.shape, self.strides)
1311 }
1312
1313 /**
1314 * Returns a TensorView which makes the order of the data in this tensor appear to be in
1315 * a different order. The order of the dimension names is unchanged, although their lengths
1316 * may swap.
1317 *
1318 * This is a shorthand for constructing the TensorView from this Tensor.
1319 *
1320 * See also: [transpose](Tensor::transpose), [TensorTranspose]
1321 *
1322 * # Panics
1323 *
1324 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1325 * order need not match.
1326 */
1327 pub fn transpose_view(
1328 &self,
1329 dimensions: [Dimension; D],
1330 ) -> TensorView<T, TensorTranspose<T, &Tensor<T, D>, D>, D> {
1331 TensorView::from(TensorTranspose::from(self, dimensions))
1332 }
1333}
1334
1335impl<T, const D: usize> Tensor<T, D>
1336where
1337 T: Clone,
1338{
1339 /**
1340 * Creates a tensor with a particular number of dimensions and length in each dimension
1341 * with all elements initialised to the provided value.
1342 *
1343 * # Panics
1344 *
1345 * - If a dimension name is not unique
1346 * - If any dimension has 0 elements
1347 */
1348 #[track_caller]
1349 pub fn empty(shape: [(Dimension, usize); D], value: T) -> Self {
1350 let elements = crate::tensors::dimensions::elements(&shape);
1351 Tensor::from(shape, vec![value; elements])
1352 }
1353
1354 /**
1355 * Gets a copy of the first value in this tensor.
1356 * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
1357 * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
1358 */
1359 pub fn first(&self) -> T {
1360 self.data
1361 .first()
1362 .expect("Tensors always have at least 1 element")
1363 .clone()
1364 }
1365
1366 /**
1367 * Returns a new Tensor which has the same data as this tensor, but with the order of data
1368 * changed. The order of the dimension names is unchanged, although their lengths may swap.
1369 *
1370 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1371 * `transpose(["y", "x"])` which would return a new tensor with a shape of
1372 * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
1373 *
1374 * This method need not shift *all* the dimensions though, you could also swap the width
1375 * and height of images in a tensor with a shape of
1376 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
1377 * which would return a new tensor where all the images have been swapped over the diagonal.
1378 *
1379 * See also: [TensorAccess], [reorder](Tensor::reorder)
1380 *
1381 * # Panics
1382 *
1383 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1384 * order need not match (and if the order does match, this function is just an expensive
1385 * clone).
1386 */
1387 #[track_caller]
1388 pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1389 let shape = self.shape;
1390 let mut reordered = self.reorder(dimensions);
1391 // Transposition is essentially reordering, but we retain the dimension name ordering
1392 // of the original order, this means we may swap dimension lengths, but the dimensions
1393 // will not change order.
1394 #[allow(clippy::needless_range_loop)]
1395 for d in 0..D {
1396 reordered.shape[d].0 = shape[d].0;
1397 }
1398 reordered
1399 }
1400
1401 /**
1402 * Modifies this tensor to have the same data as before, but with the order of data changed.
1403 * The order of the dimension names is unchanged, although their lengths may swap.
1404 *
1405 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1406 * `transpose_mut(["y", "x"])` which would edit the tensor, updating its shape to
1407 * `[("x", y), ("y", x)]`, so every (x,y) of its data corresponds to (y,x) before the
1408 * transposition.
1409 *
1410 * The order swapping will try to be in place, but this is currently only supported for
1411 * square tensors with 2 dimensions. Other types of tensors will not be transposed in place.
1412 *
1413 * # Panics
1414 *
1415 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1416 * order need not match (and if the order does match, this function is just an expensive
1417 * clone).
1418 */
1419 #[track_caller]
1420 pub fn transpose_mut(&mut self, dimensions: [Dimension; D]) {
1421 let shape = self.shape;
1422 self.reorder_mut(dimensions);
1423 // Transposition is essentially reordering, but we retain the dimension name ordering
1424 // we had before, this means we may swap dimension lengths, but the dimensions
1425 // will not change order.
1426 #[allow(clippy::needless_range_loop)]
1427 for d in 0..D {
1428 self.shape[d].0 = shape[d].0;
1429 }
1430 }
1431
1432 /**
1433 * Returns a new Tensor which has the same data as this tensor, but with the order of the
1434 * dimensions and corresponding order of data changed.
1435 *
1436 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1437 * `reorder(["y", "x"])` which would return a new tensor with a shape of
1438 * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
1439 *
1440 * This method need not shift *all* the dimensions though, you could also swap the width
1441 * and height of images in a tensor with a shape of
1442 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
1443 * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
1444 * in the original.
1445 *
1446 * See also: [TensorAccess], [transpose](Tensor::transpose)
1447 *
1448 * # Panics
1449 *
1450 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1451 * order need not match (and if the order does match, this function is just an expensive
1452 * clone).
1453 */
1454 #[track_caller]
1455 pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1456 let reorderd = match TensorAccess::try_from(&self, dimensions) {
1457 Ok(reordered) => reordered,
1458 Err(_error) => panic!(
1459 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1460 dimensions, &self.shape,
1461 ),
1462 };
1463 let reorderd_shape = reorderd.shape();
1464 Tensor::from(reorderd_shape, reorderd.iter().collect())
1465 }
1466
1467 /**
1468 * Modifies this tensor to have the same data as before, but with the order of the
1469 * dimensions and corresponding order of data changed.
1470 *
1471 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1472 * `reorder_mut(["y", "x"])` which would edit the tensor, updating its shape to
1473 * `[("y", y), ("x", x)]`, so every (y,x) of its data corresponds to (x,y) before the
1474 * transposition.
1475 *
1476 * The order swapping will try to be in place, but this is currently only supported for
1477 * square tensors with 2 dimensions. Other types of tensors will not be reordered in place.
1478 *
1479 * # Panics
1480 *
1481 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1482 * order need not match (and if the order does match, this function is just an expensive
1483 * clone).
1484 */
1485 #[track_caller]
1486 pub fn reorder_mut(&mut self, dimensions: [Dimension; D]) {
1487 use crate::tensors::dimensions::DimensionMappings;
1488 if D == 2 && crate::tensors::dimensions::is_square(&self.shape) {
1489 let dimension_mapping = match DimensionMappings::new(&self.shape, &dimensions) {
1490 Some(dimension_mapping) => dimension_mapping,
1491 None => panic!(
1492 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1493 dimensions, &self.shape,
1494 ),
1495 };
1496
1497 let shape = dimension_mapping.map_shape_to_requested(&self.shape);
1498 let shape_iterator = ShapeIterator::from(shape);
1499
1500 for index in shape_iterator {
1501 let i = index[0];
1502 let j = index[1];
1503 if j >= i {
1504 let mapped_index = dimension_mapping.map_dimensions_to_source(&index);
1505 // Swap elements from the upper triangle (using index order of the actual tensor's
1506 // shape)
1507 let temp = self.get_reference(index).unwrap().clone();
1508 // tensor[i,j] becomes tensor[mapping(i,j)]
1509 *self.get_reference_mut(index).unwrap() =
1510 self.get_reference(mapped_index).unwrap().clone();
1511 // tensor[mapping(i,j)] becomes tensor[i,j]
1512 *self.get_reference_mut(mapped_index).unwrap() = temp;
1513 // If the mapping is a noop we've assigned i,j to i,j
1514 // If the mapping is i,j -> j,i we've assigned i,j to j,i and j,i to i,j
1515 }
1516 }
1517
1518 // now update our shape and strides to match
1519 self.shape = shape;
1520 self.strides = compute_strides(&shape);
1521 } else {
1522 // fallback to allocating a new reordered tensor
1523 let reordered = self.reorder(dimensions);
1524 self.data = reordered.data;
1525 self.shape = reordered.shape;
1526 self.strides = reordered.strides;
1527 }
1528 }
1529
1530 /**
1531 * Returns an iterator over copies of the data in this Tensor.
1532 */
1533 pub fn iter(&self) -> TensorIterator<T, Tensor<T, D>, D> {
1534 TensorIterator::from(self)
1535 }
1536
1537 /**
1538 * Creates and returns a new tensor with all values from the original with the
1539 * function applied to each. This can be used to change the type of the tensor
1540 * such as creating a mask:
1541 * ```
1542 * use easy_ml::tensors::Tensor;
1543 * let x = Tensor::from([("a", 2), ("b", 2)], vec![
1544 * 0.0, 1.2,
1545 * 5.8, 6.9
1546 * ]);
1547 * let y = x.map(|element| element > 2.0);
1548 * let result = Tensor::from([("a", 2), ("b", 2)], vec![
1549 * false, false,
1550 * true, true
1551 * ]);
1552 * assert_eq!(&y, &result);
1553 * ```
1554 */
1555 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
1556 let mapped = self
1557 .data
1558 .iter()
1559 .map(|x| mapping_function(x.clone()))
1560 .collect();
1561 // We're not changing the shape of the Tensor, so don't need to revalidate
1562 Tensor::direct_from(mapped, self.shape, self.strides)
1563 }
1564
1565 /**
1566 * Creates and returns a new tensor with all values from the original and
1567 * the index of each value mapped by a function.
1568 */
1569 pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1570 let mapped = self
1571 .iter()
1572 .with_index()
1573 .map(|(i, x)| mapping_function(i, x))
1574 .collect();
1575 // We're not changing the shape of the Tensor, so don't need to revalidate
1576 Tensor::direct_from(mapped, self.shape, self.strides)
1577 }
1578
1579 /**
1580 * Applies a function to all values in the tensor, modifying
1581 * the tensor in place.
1582 */
1583 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1584 for value in self.data.iter_mut() {
1585 *value = mapping_function(value.clone());
1586 }
1587 }
1588
1589 /**
1590 * Applies a function to all values and each value's index in the tensor, modifying
1591 * the tensor in place.
1592 */
1593 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1594 self.iter_reference_mut()
1595 .with_index()
1596 .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1597 }
1598
1599 /**
1600 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1601 * mapped by a function.
1602 *
1603 * ```
1604 * use easy_ml::tensors::Tensor;
1605 * let lhs = Tensor::from([("a", 4)], vec![1, 2, 3, 4]);
1606 * let rhs = Tensor::from([("a", 4)], vec![0, 1, 2, 3]);
1607 * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1608 * assert_eq!(
1609 * multiplied,
1610 * Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1611 * );
1612 * ```
1613 *
1614 * # Generics
1615 *
1616 * This method can be called with any right hand side that can be converted to a TensorView,
1617 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1618 *
1619 * # Panics
1620 *
1621 * If the two tensors have different shapes.
1622 */
1623 #[track_caller]
1624 pub fn elementwise<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1625 where
1626 I: Into<TensorView<T, S, D>>,
1627 S: TensorRef<T, D>,
1628 M: Fn(T, T) -> T,
1629 {
1630 self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1631 mapping_function(lhs.clone(), rhs.clone())
1632 })
1633 }
1634
1635 /**
1636 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1637 * mapped by a function. The mapping function also receives each index corresponding to the
1638 * value pairs.
1639 *
1640 * # Generics
1641 *
1642 * This method can be called with any right hand side that can be converted to a TensorView,
1643 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1644 *
1645 * # Panics
1646 *
1647 * If the two tensors have different shapes.
1648 */
1649 #[track_caller]
1650 pub fn elementwise_with_index<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1651 where
1652 I: Into<TensorView<T, S, D>>,
1653 S: TensorRef<T, D>,
1654 M: Fn([usize; D], T, T) -> T,
1655 {
1656 self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1657 mapping_function(i, lhs.clone(), rhs.clone())
1658 })
1659 }
1660}
1661
1662impl<T> Tensor<T, 1>
1663where
1664 T: Numeric,
1665 for<'a> &'a T: NumericRef<T>,
1666{
1667 /**
1668 * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1669 * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1670 *
1671 * ```
1672 * use easy_ml::tensors::Tensor;
1673 * let tensor = Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]);
1674 * assert_eq!(tensor.scalar_product(&tensor), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1675 * ```
1676 *
1677 * # Generics
1678 *
1679 * This method can be called with any right hand side that can be converted to a TensorView,
1680 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1681 *
1682 * # Panics
1683 *
1684 * If the two vectors are not of equal length or their dimension names do not match.
1685 */
1686 // Would like this impl block to be in operations.rs too but then it would show first in the
1687 // Tensor docs which isn't ideal
1688 pub fn scalar_product<S, I>(&self, rhs: I) -> T
1689 where
1690 I: Into<TensorView<T, S, 1>>,
1691 S: TensorRef<T, 1>,
1692 {
1693 self.scalar_product_less_generic(rhs.into())
1694 }
1695}
1696
1697impl<T> Tensor<T, 2>
1698where
1699 T: Numeric,
1700 for<'a> &'a T: NumericRef<T>,
1701{
1702 /**
1703 * Returns the determinant of this square matrix, or None if the matrix
1704 * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
1705 */
1706 pub fn determinant(&self) -> Option<T> {
1707 linear_algebra::determinant_tensor::<T, _, _>(self)
1708 }
1709
1710 /**
1711 * Computes the inverse of a matrix provided that it exists. To have an inverse a
1712 * matrix must be square (same number of rows and columns) and it must also have a
1713 * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
1714 */
1715 pub fn inverse(&self) -> Option<Tensor<T, 2>> {
1716 linear_algebra::inverse_tensor::<T, _, _>(self)
1717 }
1718
1719 /**
1720 * Computes the covariance matrix for this feature matrix along the specified feature
1721 * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
1722 */
1723 pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
1724 linear_algebra::covariance::<T, _, _>(self, feature_dimension)
1725 }
1726}
1727
1728// FIXME: want this to be callable in the main numeric impl block
1729impl<T> Tensor<T, 2>
1730where
1731 T: Numeric,
1732{
1733 /**
1734 * Creates a diagonal matrix of the provided size with the diagonal elements
1735 * set to the provided value and all other elements in the tensor set to 0.
1736 * A diagonal matrix is always square.
1737 *
1738 * The size is still taken as a shape to facilitate creating a diagonal matrix
1739 * from the dimensionality of an existing one. If the provided value is 1 then
1740 * this will create an identity matrix.
1741 *
1742 * A 3 x 3 identity matrix:
1743 * ```ignore
1744 * [
1745 * 1, 0, 0
1746 * 0, 1, 0
1747 * 0, 0, 1
1748 * ]
1749 * ```
1750 *
1751 * # Panics
1752 *
1753 * - If the shape is not square.
1754 * - If a dimension name is not unique
1755 * - If any dimension has 0 elements
1756 */
1757 #[track_caller]
1758 pub fn diagonal(shape: [(Dimension, usize); 2], value: T) -> Tensor<T, 2> {
1759 if !crate::tensors::dimensions::is_square(&shape) {
1760 panic!("Shape must be square: {:?}", shape);
1761 }
1762 let mut tensor = Tensor::empty(shape, T::zero());
1763 for ([r, c], x) in tensor.iter_reference_mut().with_index() {
1764 if r == c {
1765 *x = value.clone();
1766 }
1767 }
1768 tensor
1769 }
1770}
1771
1772impl<T> Tensor<T, 2> {
1773 /**
1774 * Converts this 2 dimensional Tensor into a Matrix.
1775 *
1776 * This is a wrapper around the `From<Tensor<T, 2>>` implementation.
1777 *
1778 * The Matrix will have the data in the same order, with rows equal to the length of
1779 * the first dimension in the tensor, and columns equal to the length of the second.
1780 */
1781 pub fn into_matrix(self) -> crate::matrices::Matrix<T> {
1782 self.into()
1783 }
1784}
1785
1786/**
1787 * Methods for tensors with numerical real valued types, such as f32 or f64.
1788 *
1789 * This excludes signed and unsigned integers as they do not support decimal
1790 * precision and hence can't be used for operations like square roots.
1791 *
1792 * Third party fixed precision and infinite precision decimal types should
1793 * be able to implement all of the methods for [Real] and then utilise these functions.
1794 */
1795impl<T: Real> Tensor<T, 1>
1796where
1797 for<'a> &'a T: RealRef<T>,
1798{
1799 /**
1800 * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1801 * of this vector, also referred to as the length or magnitude,
1802 * and written as ||x||, or sometimes |x|.
1803 *
1804 * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1805 *
1806 * This is a shorthand for `(x.iter().map(|x| x * x).sum().sqrt()`, ie
1807 * the square root of the dot product of a vector with itself.
1808 *
1809 * The euclidean length can be used to compute a
1810 * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1811 * vector with length of 1. This should not be confused with a unit matrix,
1812 * which is another name for an identity matrix.
1813 *
1814 * ```
1815 * use easy_ml::tensors::Tensor;
1816 * let a = Tensor::from([("data", 3)], vec![ 1.0, 2.0, 3.0 ]);
1817 * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1818 * let unit = a.map(|x| x / length);
1819 * assert_eq!(unit.euclidean_length(), 1.0);
1820 * ```
1821 */
1822 // TODO: Scalar ops for tensors
1823 pub fn euclidean_length(&self) -> T {
1824 self.direct_iter_reference()
1825 .map(|x| x * x)
1826 .sum::<T>()
1827 .sqrt()
1828 }
1829}
1830
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::tensors::{Dimension, InvalidShapeError, Tensor};
    use serde::Deserialize;
    use std::convert::TryFrom;

    /**
     * Deserialised data for a Tensor. Can be converted into a Tensor by providing `&'static str`
     * dimension names.
     *
     * Only the data and the shape (dimension name and length pairs) are deserialised. The
     * dimension names are borrowed from the deserialiser's input (`#[serde(borrow)]`), hence
     * the `'a` lifetime.
     */
    #[derive(Deserialize)]
    #[serde(rename = "Tensor")]
    pub struct TensorDeserialize<'a, T, const D: usize> {
        data: Vec<T>,
        #[serde(with = "serde_arrays")]
        #[serde(borrow)]
        shape: [(&'a str, usize); D],
    }

    impl<'a, T, const D: usize> TensorDeserialize<'a, T, D> {
        /**
         * Converts this deserialised Tensor data to a Tensor, using the provided `&'static str`
         * dimension names in place of what was serialised (which wouldn't necessarily live
         * long enough). The dimension lengths are taken from the serialised shape.
         *
         * Returns an error if the serialised shape and data do not form a valid Tensor, for
         * example if the number of elements in the data does not match the shape.
         */
        pub fn into_tensor(
            self,
            dimensions: [Dimension; D],
        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
            let shape = std::array::from_fn(|d| (dimensions[d], self.shape[d].1));
            // Validation: use the normal constructor that performs validation to prevent
            // invalid serialised data being created as a Tensor, which would then break all
            // the code that's relying on these invariants.
            // By never serialising the strides in the first place, we reduce the possibility
            // of creating invalid serialised representations at the slight increase in
            // serialisation work.
            Tensor::try_from(shape, self.data)
        }
    }

    /**
     * Converts this deserialised Tensor data, which has a static lifetime for the dimension
     * names, to a Tensor using the serialised dimension names directly.
     *
     * Performs the same validation as `TensorDeserialize::into_tensor`.
     */
    impl<T, const D: usize> TryFrom<TensorDeserialize<'static, T, D>> for Tensor<T, D> {
        type Error = InvalidShapeError<D>;

        fn try_from(value: TensorDeserialize<'static, T, D>) -> Result<Self, Self::Error> {
            Tensor::try_from(value.shape, value.data)
        }
    }
}
1883
#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    // Compiles only if every listed tensor rank implements `Serialize`.
    fn requires_serialize<S: Serialize>() {}
    requires_serialize::<Tensor<f64, 0>>();
    requires_serialize::<Tensor<f64, 1>>();
    requires_serialize::<Tensor<f64, 2>>();
    requires_serialize::<Tensor<f64, 3>>();
}
1893
#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    // Compiles only if the intermediate deserialisation type implements `Deserialize`
    // for every listed tensor rank.
    fn requires_deserialize<'de, S: Deserialize<'de>>() {}
    requires_deserialize::<TensorDeserialize<f64, 0>>();
    requires_deserialize::<TensorDeserialize<f64, 1>>();
    requires_deserialize::<TensorDeserialize<f64, 2>>();
    requires_deserialize::<TensorDeserialize<f64, 3>>();
}
1904
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop() {
    #[rustfmt::skip]
    let original = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            1, 2, 3, 4,
            5, 6, 7, 8,
            9, 10, 11, 12
        ],
    );
    // Only the data and the shape are written out, never the strides.
    let serialized = toml::to_string(&original).unwrap();
    assert_eq!(
        serialized,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 3], ["columns", 4]]
"#,
    );
    let deserialized = toml::from_str::<TensorDeserialize<i32, 2>>(&serialized);
    assert!(deserialized.is_ok());
    // Rebuilding with the same dimension names must give back an equal tensor.
    let rebuilt = deserialized.unwrap().into_tensor(["rows", "columns"]);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap(), original);
}
1930
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    // Twelve values cannot fill a 4x4 shape. Parsing succeeds, but conversion to a
    // Tensor must be rejected.
    let mismatched = r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 4], ["columns", 4]]
"#;
    let deserialized = toml::from_str::<TensorDeserialize<i32, 2>>(mismatched);
    assert!(deserialized.is_ok());
    let converted = deserialized.unwrap().into_tensor(["rows", "columns"]);
    assert!(converted.is_err());
}
1943
1944macro_rules! tensor_select_impl {
1945 (impl Tensor $d:literal 1) => {
1946 impl<T> Tensor<T, $d> {
1947 /**
1948 * Selects the provided dimension name and index pairs in this Tensor, returning a
1949 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1950 * always indexed as the provided values.
1951 *
1952 * This is a shorthand for manually constructing the TensorView and
1953 * [TensorIndex]
1954 *
1955 * Note: due to limitations in Rust's const generics support, this method is only
1956 * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
1957 * back to manual construction to create `TensorIndex`es with multiple provided
1958 * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
1959 */
1960 #[track_caller]
1961 pub fn select(
1962 &self,
1963 provided_indexes: [(Dimension, usize); 1],
1964 ) -> TensorView<T, TensorIndex<T, &Tensor<T, $d>, $d, 1>, { $d - 1 }> {
1965 TensorView::from(TensorIndex::from(self, provided_indexes))
1966 }
1967
1968 /**
1969 * Selects the provided dimension name and index pairs in this Tensor, returning a
1970 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1971 * always indexed as the provided values. The TensorIndex mutably borrows this
1972 * Tensor, and can therefore mutate it
1973 *
1974 * See [select](Tensor::select)
1975 */
1976 #[track_caller]
1977 pub fn select_mut(
1978 &mut self,
1979 provided_indexes: [(Dimension, usize); 1],
1980 ) -> TensorView<T, TensorIndex<T, &mut Tensor<T, $d>, $d, 1>, { $d - 1 }> {
1981 TensorView::from(TensorIndex::from(self, provided_indexes))
1982 }
1983
1984 /**
1985 * Selects the provided dimension name and index pairs in this Tensor, returning a
1986 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1987 * always indexed as the provided values. The TensorIndex takes ownership of this
1988 * Tensor, and can therefore mutate it
1989 *
1990 * See [select](Tensor::select)
1991 */
1992 #[track_caller]
1993 pub fn select_owned(
1994 self,
1995 provided_indexes: [(Dimension, usize); 1],
1996 ) -> TensorView<T, TensorIndex<T, Tensor<T, $d>, $d, 1>, { $d - 1 }> {
1997 TensorView::from(TensorIndex::from(self, provided_indexes))
1998 }
1999 }
2000 };
2001}
2002
2003tensor_select_impl!(impl Tensor 6 1);
2004tensor_select_impl!(impl Tensor 5 1);
2005tensor_select_impl!(impl Tensor 4 1);
2006tensor_select_impl!(impl Tensor 3 1);
2007tensor_select_impl!(impl Tensor 2 1);
2008tensor_select_impl!(impl Tensor 1 1);
2009
2010macro_rules! tensor_expand_impl {
2011 (impl Tensor $d:literal 1) => {
2012 impl<T> Tensor<T, $d> {
2013 /**
2014 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2015 * a particular position within the shape, returning a TensorView which has more
2016 * dimensions than this Tensor.
2017 *
2018 * This is a shorthand for manually constructing the TensorView and
2019 * [TensorExpansion]
2020 *
2021 * Note: due to limitations in Rust's const generics support, this method is only
2022 * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
2023 * fall back to manual construction to create `TensorExpansion`s with multiple provided
2024 * indexes if you need to increase dimensionality by more than 1 dimension at a time.
2025 */
2026 #[track_caller]
2027 pub fn expand(
2028 &self,
2029 extra_dimension_names: [(usize, Dimension); 1],
2030 ) -> TensorView<T, TensorExpansion<T, &Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2031 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2032 }
2033
2034 /**
2035 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2036 * a particular position within the shape, returning a TensorView which has more
2037 * dimensions than this Tensor. The TensorIndex mutably borrows this
2038 * Tensor, and can therefore mutate it
2039 *
2040 * See [expand](Tensor::expand)
2041 */
2042 #[track_caller]
2043 pub fn expand_mut(
2044 &mut self,
2045 extra_dimension_names: [(usize, Dimension); 1],
2046 ) -> TensorView<T, TensorExpansion<T, &mut Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2047 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2048 }
2049
2050 /**
2051 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2052 * a particular position within the shape, returning a TensorView which has more
2053 * dimensions than this Tensor. The TensorIndex takes ownership of this
2054 * Tensor, and can therefore mutate it
2055 *
2056 * See [expand](Tensor::expand)
2057 */
2058 #[track_caller]
2059 pub fn expand_owned(
2060 self,
2061 extra_dimension_names: [(usize, Dimension); 1],
2062 ) -> TensorView<T, TensorExpansion<T, Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2063 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2064 }
2065 }
2066 };
2067}
2068
2069tensor_expand_impl!(impl Tensor 0 1);
2070tensor_expand_impl!(impl Tensor 1 1);
2071tensor_expand_impl!(impl Tensor 2 1);
2072tensor_expand_impl!(impl Tensor 3 1);
2073tensor_expand_impl!(impl Tensor 4 1);
2074tensor_expand_impl!(impl Tensor 5 1);