easy_ml/tensors/mod.rs
1/*!
2 * Generic N dimensional [named tensors](http://nlp.seas.harvard.edu/NamedTensor).
3 *
4 * Tensors are generic over some type `T` and some usize `D`. If `T` is [Numeric](super::numeric)
5 * then the tensor can be used in a mathematical way. `D` is the number of dimensions in the tensor
6 * and a compile time constant. Each tensor also carries `D` dimension name and length pairs.
7 */
8use crate::linear_algebra;
9use crate::numeric::extra::{Real, RealRef};
10use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::indexing::{
12 ShapeIterator, TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
13 TensorReferenceMutIterator, TensorTranspose,
14};
15use crate::tensors::views::{
16 DataLayout, IndexRange, IndexRangeValidationError, TensorExpansion, TensorIndex, TensorMask,
17 TensorMut, TensorRange, TensorRef, TensorRename, TensorReshape, TensorReverse, TensorView,
18};
19
20use std::error::Error;
21use std::fmt;
22
23#[cfg(feature = "serde")]
24use serde::Serialize;
25
26pub mod dimensions;
27mod display;
28pub mod indexing;
29pub mod operations;
30pub mod views;
31
32#[cfg(feature = "serde")]
33pub use serde_impls::TensorDeserialize;
34
/**
 * Dimension names are represented as static string references.
 *
 * This allows you to use string literals to refer to named dimensions, for example you might want
 * to construct a tensor with a shape of
 * `[("batch", 1000), ("height", 100), ("width", 100), ("rgba", 4)]`.
 *
 * Alternatively you can define the strings once as constants and refer to your dimension
 * names by the constant identifiers.
 *
 * ```
 * const BATCH: &'static str = "batch";
 * const HEIGHT: &'static str = "height";
 * const WIDTH: &'static str = "width";
 * const RGBA: &'static str = "rgba";
 * ```
 *
 * Although `Dimension` is interchangeable with `&'static str` as it is just a type alias, Easy ML
 * uses `Dimension` whenever dimension names are expected to distinguish the types from just
 * strings.
 */
pub type Dimension = &'static str;
57
/**
 * An error indicating failure to do something with a Tensor because the requested shape
 * is not valid.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidShapeError<const D: usize> {
    // The dimension name and length pairs that were rejected (or, for instances built via
    // `new` for unit testing, whatever shape the caller supplied).
    shape: [(Dimension, usize); D],
}
66
67impl<const D: usize> InvalidShapeError<D> {
68 /**
69 * Checks if this shape is valid. This is mainly for internal library use but may also be
70 * useful for unit testing.
71 *
72 * Note: in some functions and methods, an InvalidShapeError may be returned which is a valid
73 * shape, but not the right size for the quantity of data provided.
74 */
75 pub fn is_valid(&self) -> bool {
76 !crate::tensors::dimensions::has_duplicates(&self.shape)
77 && !self.shape.iter().any(|d| d.1 == 0)
78 }
79
80 /**
81 * Constructs an InvalidShapeError for assistance with unit testing. Note that you can
82 * construct an InvalidShapeError that *is* a valid shape in this way.
83 */
84 pub fn new(shape: [(Dimension, usize); D]) -> InvalidShapeError<D> {
85 InvalidShapeError { shape }
86 }
87
88 pub fn shape(&self) -> [(Dimension, usize); D] {
89 self.shape
90 }
91
92 pub fn shape_ref(&self) -> &[(Dimension, usize); D] {
93 &self.shape
94 }
95
96 // Panics if the shape is invalid for any reason with the appropriate error message.
97 #[track_caller]
98 #[inline]
99 fn validate_dimensions_or_panic(shape: &[(Dimension, usize); D], data_len: usize) {
100 let elements = crate::tensors::dimensions::elements(shape);
101 if data_len != elements {
102 panic!(
103 "Product of dimension lengths must match size of data. {} != {}",
104 elements, data_len
105 );
106 }
107 if crate::tensors::dimensions::has_duplicates(shape) {
108 panic!("Dimension names must all be unique: {:?}", &shape);
109 }
110 if shape.iter().any(|d| d.1 == 0) {
111 panic!("No dimension can have 0 elements: {:?}", &shape);
112 }
113 }
114
115 // Returns true if the shape is valid and matches the data length
116 fn validate_dimensions(shape: &[(Dimension, usize); D], data_len: usize) -> bool {
117 let elements = crate::tensors::dimensions::elements(shape);
118 data_len == elements
119 && !crate::tensors::dimensions::has_duplicates(shape)
120 && !shape.iter().any(|d| d.1 == 0)
121 }
122}
123
124impl<const D: usize> fmt::Display for InvalidShapeError<D> {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(
127 f,
128 "Dimensions must all be at least length 1 with unique names: {:?}",
129 self.shape
130 )
131 }
132}
133
134impl<const D: usize> Error for InvalidShapeError<D> {}
135
/**
 * An error indicating failure to do something with a Tensor because the dimension names that
 * were provided did not match with the dimension names that were valid.
 *
 * Typically this would be due to the same dimension name being provided multiple times, or a
 * dimension name being provided that is not present in the shape of the Tensor in use.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InvalidDimensionsError<const D: usize, const P: usize> {
    // The dimension names that would have been accepted in this context.
    valid: [Dimension; D],
    // The dimension names the caller actually supplied.
    provided: [Dimension; P],
}
148
149impl<const D: usize, const P: usize> InvalidDimensionsError<D, P> {
150 /**
151 * Checks if the provided dimensions have duplicate names. This is mainly for internal library
152 * use but may also be useful for unit testing.
153 */
154 pub fn has_duplicates(&self) -> bool {
155 crate::tensors::dimensions::has_duplicates_names(&self.provided)
156 }
157
158 // TODO: method to check provided is a subset of valid
159
160 /**
161 * Constructs an InvalidDimensions for assistance with unit testing.
162 */
163 pub fn new(provided: [Dimension; P], valid: [Dimension; D]) -> InvalidDimensionsError<D, P> {
164 InvalidDimensionsError { valid, provided }
165 }
166
167 pub fn provided_names(&self) -> [Dimension; P] {
168 self.provided
169 }
170
171 pub fn provided_names_ref(&self) -> &[Dimension; P] {
172 &self.provided
173 }
174
175 pub fn valid_names(&self) -> [Dimension; D] {
176 self.valid
177 }
178
179 pub fn valid_names_ref(&self) -> &[Dimension; D] {
180 &self.valid
181 }
182}
183
184impl<const D: usize, const P: usize> fmt::Display for InvalidDimensionsError<D, P> {
185 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186 if P > 0 {
187 write!(
188 f,
189 "Dimensions names {:?} were incorrect, valid dimensions in this context are: {:?}",
190 self.provided, self.valid
191 )
192 } else {
193 write!(f, "Dimensions names {:?} were incorrect", self.provided)
194 }
195 }
196}
197
198impl<const D: usize, const P: usize> Error for InvalidDimensionsError<D, P> {}
199
#[test]
fn test_sync() {
    // Compile-time check that the error types can be shared between threads.
    fn require_sync<T: Sync>() {}
    require_sync::<InvalidShapeError<2>>();
    require_sync::<InvalidDimensionsError<2, 2>>();
}
206
#[test]
fn test_send() {
    // Compile-time check that the error types can be moved between threads.
    fn require_send<T: Send>() {}
    require_send::<InvalidShapeError<2>>();
    require_send::<InvalidDimensionsError<2, 2>>();
}
213
/**
 * A [named tensor](http://nlp.seas.harvard.edu/NamedTensor) of some type `T` and number of
 * dimensions `D`.
 *
 * Tensors are a generalisation of matrices; whereas [Matrix](crate::matrices::Matrix) only
 * supports 2 dimensions, and vectors are represented in Matrix by making either the rows or
 * columns have a length of one, [Tensor] supports an arbitrary number of dimensions,
 * with 0 through 6 having full API support. A `Tensor<T, 2>` is very similar to a `Matrix<T>`
 * except that this type associates each dimension with a name, and favors names to refer to
 * dimensions instead of index order.
 *
 * Like Matrix, the type of the data in this Tensor may implement no traits, in which case the
 * tensor will be rather useless. If the type implements Clone most storage and accessor methods
 * are defined and if the type implements Numeric then the tensor can be used in a mathematical
 * way.
 *
 * Like Matrix, a Tensor must always contain at least one element, and it may not have more
 * elements than `std::isize::MAX`. Concerned readers should note that on a 64 bit computer this
 * maximum value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
 *
 * When doing numeric operations with Tensors you should be careful to not consume a tensor by
 * accidentally using it by value. All the operations are also defined on references to tensors
 * so you should favor &x + &y style notation for tensors you intend to continue using.
 *
 * See also:
 * - [indexing]
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Tensor<T, const D: usize> {
    // Elements in row major order; the length always equals the product of the dimension
    // lengths in `shape` (validated on construction).
    data: Vec<T>,
    // Dimension name and length pairs; validated on construction to have unique names and
    // no zero lengths.
    #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
    shape: [(Dimension, usize); D],
    // Cached strides derived from `shape` via `compute_strides`; not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    strides: [usize; D],
}
250
impl<T, const D: usize> Tensor<T, D> {
    /**
     * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
     *
     * The product of the dimension lengths corresponds to the number of elements the Tensor
     * will store. Elements are stored in what would be row major order for a Matrix.
     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
     * for the 25th and final element.
     *
     * # Panics
     *
     * - If the number of provided elements does not match the product of the dimension lengths.
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
     * a single element (since the product of an empty list is 1).
     */
    #[track_caller]
    pub fn from(shape: [(Dimension, usize); D], data: Vec<T>) -> Self {
        InvalidShapeError::validate_dimensions_or_panic(&shape, data.len());
        let strides = compute_strides(&shape);
        Tensor {
            data,
            shape,
            strides,
        }
    }

    /**
     * Creates a Tensor with a particular shape initialised from a function.
     *
     * The product of the dimension lengths corresponds to the number of elements the Tensor
     * will store. Elements are stored in what would be row major order for a Matrix.
     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
     * for the 25th and final element. These same indexes will be passed to the producer function
     * to initialise the values for the Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from_fn([("rows", 4), ("columns", 4)], |[r, c]| r * c);
     * assert_eq!(
     *     tensor,
     *     Tensor::from([("rows", 4), ("columns", 4)], vec![
     *         0, 0, 0, 0,
     *         0, 1, 2, 3,
     *         0, 2, 4, 6,
     *         0, 3, 6, 9,
     *     ])
     * );
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
     * a single element (since the product of an empty list is 1).
     */
    #[track_caller]
    pub fn from_fn<F>(shape: [(Dimension, usize); D], mut producer: F) -> Self
    where
        F: FnMut([usize; D]) -> T,
    {
        let length = dimensions::elements(&shape);
        let mut data = Vec::with_capacity(length);
        // ShapeIterator yields every index for this shape in the same order Tensor::from
        // expects the data in, so we can push the produced values directly.
        let iterator = ShapeIterator::from(shape);
        for index in iterator {
            data.push(producer(index));
        }
        Tensor::from(shape, data)
    }

    /**
     * The shape of this tensor. Since Tensors are named Tensors, their shape is not just a
     * list of lengths along each dimension, but instead a list of pairs of names and lengths.
     *
     * See also
     * - [dimensions]
     * - [indexing]
     */
    pub fn shape(&self) -> [(Dimension, usize); D] {
        self.shape
    }

    /**
     * Returns the length of the dimension name provided, if one is present in the Tensor.
     *
     * See also
     * - [dimensions]
     * - [indexing]
     */
    pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::length_of(&self.shape, dimension)
    }

    /**
     * Returns the last index of the dimension name provided, if one is present in the Tensor.
     *
     * This is always 1 less than the length, the 'index' in this sense is based on what the
     * Tensor's shape is, not any implementation index.
     *
     * See also
     * - [dimensions]
     * - [indexing]
     */
    pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
        dimensions::last_index_of(&self.shape, dimension)
    }

    /**
     * A non panicking version of [from](Tensor::from) which returns `Result::Err` if the input
     * is invalid.
     *
     * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
     *
     * The product of the dimension lengths corresponds to the number of elements the Tensor
     * will store. Elements are stored in what would be row major order for a Matrix.
     * Each step in memory through the N dimensions corresponds to incrementing the rightmost
     * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
     * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
     * for the 25th and final element.
     *
     * Returns the Err variant if
     * - If the number of provided elements does not match the product of the dimension lengths.
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
     * a single element (since the product of an empty list is 1).
     */
    pub fn try_from(
        shape: [(Dimension, usize); D],
        data: Vec<T>,
    ) -> Result<Self, InvalidShapeError<D>> {
        let valid = InvalidShapeError::validate_dimensions(&shape, data.len());
        if !valid {
            return Err(InvalidShapeError::new(shape));
        }
        let strides = compute_strides(&shape);
        Ok(Tensor {
            data,
            shape,
            strides,
        })
    }

    /// Unverified constructor for internal use when we know the dimensions/data/strides are
    /// unchanged and don't need reverification
    pub(crate) fn direct_from(
        data: Vec<T>,
        shape: [(Dimension, usize); D],
        strides: [usize; D],
    ) -> Self {
        Tensor {
            data,
            shape,
            strides,
        }
    }

    /// Unverified constructor for internal use when we know the dimensions/data/strides are
    /// the same as the existing instance and don't need reverification
    #[allow(dead_code)] // pretty sure something else will want this in the future
    pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> Self {
        Tensor {
            data,
            shape: self.shape,
            strides: self.strides,
        }
    }
}
428
429impl<T> Tensor<T, 0> {
430 /**
431 * Creates a 0 dimensional tensor from some scalar
432 */
433 pub fn from_scalar(value: T) -> Tensor<T, 0> {
434 Tensor {
435 data: vec![value],
436 shape: [],
437 strides: [],
438 }
439 }
440
441 /**
442 * Returns the sole element of the 0 dimensional tensor.
443 */
444 pub fn into_scalar(self) -> T {
445 self.data
446 .into_iter()
447 .next()
448 .expect("Tensors always have at least 1 element")
449 }
450}
451
452impl<T> Tensor<T, 0>
453where
454 T: Clone,
455{
456 /**
457 * Returns a copy of the sole element in the 0 dimensional tensor.
458 */
459 pub fn scalar(&self) -> T {
460 self.data
461 .first()
462 .expect("Tensors always have at least 1 element")
463 .clone()
464 }
465}
466
467impl<T> From<T> for Tensor<T, 0> {
468 fn from(scalar: T) -> Tensor<T, 0> {
469 Tensor::from_scalar(scalar)
470 }
471}
472// TODO: See if we can find a way to write the reverse Tensor<T, 0> -> T conversion using From or Into (doesn't seem like we can?)
473
// # Safety
//
// We promise to never implement interior mutability for Tensor.
/**
 * A Tensor implements TensorRef.
 */
unsafe impl<T, const D: usize> TensorRef<T, D> for Tensor<T, D> {
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        // get_index_direct returns None for any out of bounds index, so the data access
        // here can never be out of range.
        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
        self.data.get(i)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        Tensor::shape(self)
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // The point of get_reference_unchecked is no bounds checking, and therefore
            // it does not make any sense to just use `unwrap` here. The trait documents that
            // it's undefined behaviour to call this method with an out of bounds index, so we
            // can assume the None case will never happen.
            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
            self.data.get_unchecked(i)
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // We always have our memory in most significant to least
        DataLayout::Linear(std::array::from_fn(|i| self.shape[i].0))
    }
}
506
// # Safety
//
// We promise to never implement interior mutability for Tensor.
/**
 * A Tensor implements TensorMut.
 */
unsafe impl<T, const D: usize> TensorMut<T, D> for Tensor<T, D> {
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        // get_index_direct returns None for any out of bounds index, so the data access
        // here can never be out of range.
        let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
        self.data.get_mut(i)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // The point of get_reference_unchecked_mut is no bounds checking, and therefore
            // it does not make any sense to just use `unwrap` here. The trait documents that
            // it's undefined behaviour to call this method with an out of bounds index, so we
            // can assume the None case will never happen.
            let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
            self.data.get_unchecked_mut(i)
        }
    }
}
527
/**
 * Any tensor of a Cloneable type implements Clone.
 */
impl<T: Clone, const D: usize> Clone for Tensor<T, D> {
    fn clone(&self) -> Self {
        // Mapping with the identity function clones every element into a new Tensor
        // of the same shape.
        self.map(|element| element)
    }
}
536
537/**
538 * Any tensor of a Displayable type implements Display
539 *
540 * You can control the precision of the formatting using format arguments, i.e.
541 * `format!("{:.3}", tensor)`
542 */
543impl<T: std::fmt::Display, const D: usize> std::fmt::Display for Tensor<T, D> {
544 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
545 crate::tensors::display::format_view(self, f)
546 }
547}
548
549/**
550 * Any 2 dimensional tensor can be converted to a matrix with rows equal to the length of the
551 * first dimension in the tensor, and columns equal to the length of the second.
552 */
553impl<T> From<Tensor<T, 2>> for crate::matrices::Matrix<T> {
554 fn from(tensor: Tensor<T, 2>) -> Self {
555 crate::matrices::Matrix::from_flat_row_major(
556 (tensor.shape[0].1, tensor.shape[1].1),
557 tensor.data,
558 )
559 }
560}
561
562pub(crate) fn compute_strides<const D: usize>(shape: &[(Dimension, usize); D]) -> [usize; D] {
563 std::array::from_fn(|d| shape.iter().skip(d + 1).map(|d| d.1).product())
564}
565
566/// returns the 1 dimensional index to use to get the requested index into some tensor
567#[inline]
568pub(crate) fn get_index_direct<const D: usize>(
569 // indexes to use
570 indexes: &[usize; D],
571 // strides for indexing into the tensor
572 strides: &[usize; D],
573 // shape of the tensor to index into
574 shape: &[(Dimension, usize); D],
575) -> Option<usize> {
576 let mut index = 0;
577 for d in 0..D {
578 let n = indexes[d];
579 if n >= shape[d].1 {
580 return None;
581 }
582 index += n * strides[d];
583 }
584 Some(index)
585}
586
/// returns the 1 dimensional index to use to get the requested index into some tensor, without
/// checking the indexes are within bounds for the shape.
#[inline]
fn get_index_direct_unchecked<const D: usize>(
    // indexes to use
    indexes: &[usize; D],
    // strides for indexing into the tensor
    strides: &[usize; D],
) -> usize {
    // The flat position is the dot product of the indexes and the strides.
    indexes
        .iter()
        .zip(strides.iter())
        .map(|(&index, &stride)| index * stride)
        .sum()
}
603
604impl<T, const D: usize> Tensor<T, D> {
    /**
     * Returns a TensorView of this Tensor, which borrows it immutably.
     */
    pub fn view(&self) -> TensorView<T, &Tensor<T, D>, D> {
        TensorView::from(self)
    }

    /**
     * Returns a TensorView of this Tensor, which borrows it mutably and can therefore
     * mutate it.
     */
    pub fn view_mut(&mut self) -> TensorView<T, &mut Tensor<T, D>, D> {
        TensorView::from(self)
    }

    /**
     * Returns a TensorView which takes ownership of this Tensor.
     */
    pub fn view_owned(self) -> TensorView<T, Tensor<T, D>, D> {
        TensorView::from(self)
    }
616
    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read values from this tensor.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor. The TensorAccess mutably borrows this
     * Tensor, and can therefore mutate it.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_mut(
        &mut self,
        dimensions: [Dimension; D],
    ) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /**
     * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
     * to read or write values from this tensor. The TensorAccess takes ownership of this
     * Tensor, and can therefore mutate it.
     *
     * # Panics
     *
     * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
     */
    #[track_caller]
    pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from(self, dimensions)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was created with
     * in the same order as they were provided. See [TensorAccess::from_source_order].
     */
    pub fn index(&self) -> TensorAccess<T, &Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was
     * created with in the same order as they were provided. The TensorAccess mutably borrows
     * the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
     */
    pub fn index_mut(&mut self) -> TensorAccess<T, &mut Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /**
     * Creates a TensorAccess which will index into the dimensions this Tensor was
     * created with in the same order as they were provided. The TensorAccess takes ownership
     * of the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
     */
    pub fn index_owned(self) -> TensorAccess<T, Tensor<T, D>, D> {
        TensorAccess::from_source_order(self)
    }

    /**
     * Returns an iterator over references to the data in this Tensor.
     */
    pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, Tensor<T, D>, D> {
        TensorReferenceIterator::from(self)
    }

    /**
     * Returns an iterator over mutable references to the data in this Tensor.
     */
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, Tensor<T, D>, D> {
        TensorReferenceMutIterator::from(self)
    }

    /**
     * Creates an iterator over the values in this Tensor, which takes ownership of
     * the Tensor.
     */
    pub fn iter_owned(self) -> TensorOwnedIterator<T, Tensor<T, D>, D>
    where
        T: Default,
    {
        TensorOwnedIterator::from(self)
    }

    // Non public index order reference iterator since we don't want to expose our implementation
    // details to public API since then we could never change them.
    pub(crate) fn direct_iter_reference(&self) -> std::slice::Iter<'_, T> {
        self.data.iter()
    }

    // Non public index order reference iterator since we don't want to expose our implementation
    // details to public API since then we could never change them.
    pub(crate) fn direct_iter_reference_mut(&mut self) -> std::slice::IterMut<'_, T> {
        self.data.iter_mut()
    }
720
721 /**
722 * Renames the dimension names of the tensor without changing the lengths of the dimensions
723 * in the tensor or moving any data around.
724 *
725 * ```
726 * use easy_ml::tensors::Tensor;
727 * let mut tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6]);
728 * tensor.rename(["y", "z"]);
729 * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
730 * ```
731 *
732 * # Panics
733 *
734 * - If a dimension name is not unique
735 */
736 #[track_caller]
737 pub fn rename(&mut self, dimensions: [Dimension; D]) {
738 if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
739 panic!("Dimension names must all be unique: {:?}", &dimensions);
740 }
741 #[allow(clippy::needless_range_loop)]
742 for d in 0..D {
743 self.shape[d].0 = dimensions[d];
744 }
745 }
746
    /**
     * Renames the dimension names of the tensor and returns it without changing the lengths
     * of the dimensions in the tensor or moving any data around.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6])
     *     .rename_owned(["y", "z"]);
     * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename_owned(mut self, dimensions: [Dimension; D]) -> Tensor<T, D> {
        // Delegates to rename, which performs the uniqueness check.
        self.rename(dimensions);
        self
    }

    /**
     * Returns a TensorView with the dimension names of the shape renamed to the provided
     * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRename};
     * let abc = Tensor::from([("a", 3), ("b", 3), ("c", 3)], (0..27).collect());
     * let xyz = abc.rename_view(["x", "y", "z"]);
     * let also_xyz = TensorView::from(TensorRename::from(&abc, ["x", "y", "z"]));
     * assert_eq!(xyz, also_xyz);
     * assert_eq!(xyz, Tensor::from([("x", 3), ("y", 3), ("z", 3)], (0..27).collect()));
     * ```
     *
     * # Panics
     *
     * - If a dimension name is not unique
     */
    #[track_caller]
    pub fn rename_view(
        &self,
        dimensions: [Dimension; D],
    ) -> TensorView<T, TensorRename<T, &Tensor<T, D>, D>, D> {
        TensorView::from(TensorRename::from(self, dimensions))
    }
795
    /**
     * Changes the shape of the tensor without changing the number of dimensions or moving any
     * data around.
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let mut tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * tensor.reshape_mut([("batch", 1), ("image", 4)]);
     * assert_eq!(tensor, Tensor::from([("batch", 1), ("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     */
    #[track_caller]
    pub fn reshape_mut(&mut self, shape: [(Dimension, usize); D]) {
        InvalidShapeError::validate_dimensions_or_panic(&shape, self.data.len());
        // A new shape needs new strides; the data itself stays in place.
        let strides = compute_strides(&shape);
        self.shape = shape;
        self.strides = strides;
    }

    /**
     * Consumes the tensor and changes the shape of the tensor without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     * - If any dimension has 0 elements
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * let flattened = tensor.reshape_owned([("image", 4)]);
     * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     *
     * See also [reshape_view_owned](Tensor::reshape_view_owned)
     */
    #[track_caller]
    pub fn reshape_owned<const D2: usize>(self, shape: [(Dimension, usize); D2]) -> Tensor<T, D2> {
        // Tensor::from revalidates the new shape against the existing data length.
        Tensor::from(shape, self.data)
    }
851
    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this Tensor.
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
     *     1, 2,
     *     3, 4
     * ]);
     * let flattened = tensor.reshape_view([("image", 4)]);
     * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
     * ```
     */
    // NOTE(review): shape validation is delegated to TensorReshape::from; presumably a zero
    // length dimension also panics here (as in reshape_mut) — confirm against TensorReshape.
    pub fn reshape_view<const D2: usize>(
        &self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
     * mutably borrows this Tensor, and can therefore mutate it
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     */
    pub fn reshape_view_mut<const D2: usize>(
        &mut self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }

    /**
     * Returns a TensorView with the dimensions changed to the provided shape without moving any
     * data around. The new Tensor may also have a different number of dimensions.
     *
     * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
     * takes ownership of this Tensor, and can therefore mutate it
     *
     * # Panics
     *
     * - If the number of provided elements in the new shape does not match the product of the
     * dimension lengths in the existing tensor's shape.
     * - If a dimension name is not unique
     *
     * See also [reshape_owned](Tensor::reshape_owned)
     */
    pub fn reshape_view_owned<const D2: usize>(
        self,
        shape: [(Dimension, usize); D2],
    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, D2>, D2> {
        TensorView::from(TensorReshape::from(self, shape))
    }
922
    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor.
     */
    pub fn flatten_view(
        &self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, 1>, 1> {
        self.reshape_view([(dimension, dimensions::elements(&self.shape))])
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor. The view mutably
     * borrows this Tensor.
     */
    pub fn flatten_view_mut(
        &mut self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, 1>, 1> {
        self.reshape_view_mut([(dimension, dimensions::elements(&self.shape))])
    }

    /**
     * Given the dimension name, returns a view of this tensor reshaped to one dimension
     * with a length equal to the number of elements in this tensor. The view takes
     * ownership of this Tensor.
     *
     * If you intend to query the tensor a lot after creating the view, consider
     * using [flatten](Tensor::flatten) instead as it will have less overhead to
     * index after creation.
     */
    pub fn flatten_view_owned(
        self,
        dimension: Dimension,
    ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, 1>, 1> {
        // Compute the length before self is moved into reshape_view_owned.
        let length = dimensions::elements(&self.shape);
        self.reshape_view_owned([(dimension, length)])
    }
960
961 /**
962 * Given the dimension name, returns a new tensor reshaped to one dimension
963 * with a length equal to the number of elements in this tensor.
964 */
965 pub fn flatten(self, dimension: Dimension) -> Tensor<T, 1> {
966 let length = dimensions::elements(&self.shape);
967 self.reshape_owned([(dimension, length)])
968 }
969
970 /**
971 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
972 * range from view. Error cases are documented on [TensorRange].
973 *
974 * This is a shorthand for constructing the TensorView from this Tensor.
975 *
976 * ```
977 * use easy_ml::tensors::Tensor;
978 * use easy_ml::tensors::views::{TensorView, TensorRange, IndexRange};
979 * # use easy_ml::tensors::views::IndexRangeValidationError;
980 * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
981 * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
982 * let cropped = samples.range([("x", IndexRange::new(1, 5)), ("y", IndexRange::new(1, 5))])?;
983 * let also_cropped = TensorView::from(
984 * TensorRange::from(&samples, [("x", 1..6), ("y", 1..6)])?
985 * );
986 * assert_eq!(cropped, also_cropped);
987 * assert_eq!(
988 * cropped.select([("batch", 0)]),
989 * Tensor::from([("x", 5), ("y", 5)], vec![
990 * 8, 9, 10, 11, 12,
991 * 15, 16, 17, 18, 19,
992 * 22, 23, 24, 25, 26,
993 * 29, 30, 31, 32, 33,
994 * 36, 37, 38, 39, 40
995 * ])
996 * );
997 * # Ok(())
998 * # }
999 * ```
1000 */
1001 pub fn range<R, const P: usize>(
1002 &self,
1003 ranges: [(Dimension, R); P],
1004 ) -> Result<TensorView<T, TensorRange<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1005 where
1006 R: Into<IndexRange>,
1007 {
1008 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1009 }
1010
1011 /**
1012 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1013 * range from view. Error cases are documented on [TensorRange]. The TensorRange
1014 * mutably borrows this Tensor, and can therefore mutate it
1015 *
1016 * This is a shorthand for constructing the TensorView from this Tensor.
1017 */
1018 pub fn range_mut<R, const P: usize>(
1019 &mut self,
1020 ranges: [(Dimension, R); P],
1021 ) -> Result<
1022 TensorView<T, TensorRange<T, &mut Tensor<T, D>, D>, D>,
1023 IndexRangeValidationError<D, P>,
1024 >
1025 where
1026 R: Into<IndexRange>,
1027 {
1028 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1029 }
1030
1031 /**
1032 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1033 * range from view. Error cases are documented on [TensorRange]. The TensorRange
1034 * takes ownership of this Tensor, and can therefore mutate it
1035 *
1036 * This is a shorthand for constructing the TensorView from this Tensor.
1037 */
1038 pub fn range_owned<R, const P: usize>(
1039 self,
1040 ranges: [(Dimension, R); P],
1041 ) -> Result<TensorView<T, TensorRange<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1042 where
1043 R: Into<IndexRange>,
1044 {
1045 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1046 }
1047
1048 /**
1049 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1050 * range from view. Error cases are documented on [TensorMask].
1051 *
1052 * This is a shorthand for constructing the TensorView from this Tensor.
1053 *
1054 * ```
1055 * use easy_ml::tensors::Tensor;
1056 * use easy_ml::tensors::views::{TensorView, TensorMask, IndexRange};
1057 * # use easy_ml::tensors::views::IndexRangeValidationError;
1058 * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
1059 * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
1060 * let corners = samples.mask([("x", IndexRange::new(3, 2)), ("y", IndexRange::new(3, 2))])?;
1061 * let also_corners = TensorView::from(
1062 * TensorMask::from(&samples, [("x", 3..5), ("y", 3..5)])?
1063 * );
1064 * assert_eq!(corners, also_corners);
1065 * assert_eq!(
1066 * corners.select([("batch", 0)]),
1067 * Tensor::from([("x", 5), ("y", 5)], vec![
1068 * 0, 1, 2, 5, 6,
1069 * 7, 8, 9, 12, 13,
1070 * 14, 15, 16, 19, 20,
1071 *
1072 * 35, 36, 37, 40, 41,
1073 * 42, 43, 44, 47, 48
1074 * ])
1075 * );
1076 * # Ok(())
1077 * # }
1078 * ```
1079 */
1080 pub fn mask<R, const P: usize>(
1081 &self,
1082 masks: [(Dimension, R); P],
1083 ) -> Result<TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1084 where
1085 R: Into<IndexRange>,
1086 {
1087 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1088 }
1089
1090 /**
1091 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1092 * range from view. Error cases are documented on [TensorMask]. The TensorMask
1093 * mutably borrows this Tensor, and can therefore mutate it
1094 *
1095 * This is a shorthand for constructing the TensorView from this Tensor.
1096 */
1097 pub fn mask_mut<R, const P: usize>(
1098 &mut self,
1099 masks: [(Dimension, R); P],
1100 ) -> Result<
1101 TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D>,
1102 IndexRangeValidationError<D, P>,
1103 >
1104 where
1105 R: Into<IndexRange>,
1106 {
1107 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1108 }
1109
1110 /**
1111 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1112 * range from view. Error cases are documented on [TensorMask]. The TensorMask
1113 * takes ownership of this Tensor, and can therefore mutate it
1114 *
1115 * This is a shorthand for constructing the TensorView from this Tensor.
1116 */
1117 pub fn mask_owned<R, const P: usize>(
1118 self,
1119 masks: [(Dimension, R); P],
1120 ) -> Result<TensorView<T, TensorMask<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1121 where
1122 R: Into<IndexRange>,
1123 {
1124 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1125 }
1126
1127 /**
1128 * Returns a TensorView with a mask taken in the provided dimension, hiding
1129 * all but the start_and_end number of values at the start and end of the
1130 * dimension from view.
1131 *
1132 * This is a shorthand for constructing the TensorView from this Tensor.
1133 *
1134 * # Panics
1135 *
1136 * - If the start_and_end value is 0 - this is not a valid mask as it would
1137 * hide all elements
1138 * - If the dimension is not in the tensor's shape.
1139 *
1140 * ```
1141 * use easy_ml::tensors::Tensor;
1142 * let samples = Tensor::from([("batch", 5), ("x", 2), ("y", 2)], (0..20).collect());
1143 * let shortlist = samples.start_and_end_of("batch", 1);
1144 * assert_eq!(
1145 * shortlist,
1146 * Tensor::from([("batch", 2), ("x", 2), ("y", 2)], vec![
1147 * 0, 1,
1148 * 2, 3,
1149 *
1150 * 16, 17,
1151 * 18, 19,
1152 * ])
1153 * )
1154 * ```
1155 */
1156 #[track_caller]
1157 pub fn start_and_end_of(
1158 &self,
1159 dimension: Dimension,
1160 start_and_end: usize,
1161 ) -> TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D> {
1162 TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1163 }
1164
1165 /**
1166 * Returns a TensorView with a mask taken in the provided dimension, hiding
1167 * all but the start_and_end number of values at the start and end of the
1168 * dimension from view. The TensorMask mutably borrows this Tensor, and can
1169 * therefore mutate it
1170 *
1171 * This is a shorthand for constructing the TensorView from this Tensor.
1172 *
1173 * # Panics
1174 *
1175 * - If the start_and_end value is 0 - this is not a valid mask as it would
1176 * hide all elements
1177 * - If the dimension is not in the tensor's shape.
1178 */
1179 #[track_caller]
1180 pub fn start_and_end_of_mut(
1181 &mut self,
1182 dimension: Dimension,
1183 start_and_end: usize,
1184 ) -> TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D> {
1185 TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1186 }
1187
1188 /**
1189 * Returns a TensorView with a mask taken in the provided dimension, hiding
1190 * all but the start_and_end number of values at the start and end of the
1191 * dimension from view. The TensorMask takes ownership of this Tensor, and
1192 * can therefore mutate it
1193 *
1194 * This is a shorthand for constructing the TensorView from this Tensor.
1195 *
1196 * # Panics
1197 *
1198 * - If the start_and_end value is 0 - this is not a valid mask as it would
1199 * hide all elements
1200 * - If the dimension is not in the tensor's shape.
1201 */
1202 #[track_caller]
1203 pub fn start_and_end_of_owned(
1204 self,
1205 dimension: Dimension,
1206 start_and_end: usize,
1207 ) -> TensorView<T, TensorMask<T, Tensor<T, D>, D>, D> {
1208 TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1209 }
1210
1211 /**
1212 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1213 * order. The data of this tensor and the dimension lengths remain unchanged.
1214 *
1215 * This is a shorthand for constructing the TensorView from this Tensor.
1216 *
1217 * ```
1218 * use easy_ml::tensors::Tensor;
1219 * use easy_ml::tensors::views::{TensorView, TensorReverse};
1220 * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
1221 * let reversed = ab.reverse(&["a"]);
1222 * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
1223 * assert_eq!(reversed, also_reversed);
1224 * assert_eq!(
1225 * reversed,
1226 * Tensor::from(
1227 * [("a", 2), ("b", 3)],
1228 * vec![
1229 * 3, 4, 5,
1230 * 0, 1, 2,
1231 * ]
1232 * )
1233 * );
1234 * ```
1235 *
1236 * # Panics
1237 *
1238 * - If a dimension name is not in the tensor's shape or is repeated.
1239 */
1240 #[track_caller]
1241 pub fn reverse(
1242 &self,
1243 dimensions: &[Dimension],
1244 ) -> TensorView<T, TensorReverse<T, &Tensor<T, D>, D>, D> {
1245 TensorView::from(TensorReverse::from(self, dimensions))
1246 }
1247
1248 /**
1249 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1250 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1251 * mutably borrows this Tensor, and can therefore mutate it
1252 *
1253 * This is a shorthand for constructing the TensorView from this Tensor.
1254 *
1255 * # Panics
1256 *
1257 * - If a dimension name is not in the tensor's shape or is repeated.
1258 */
1259 #[track_caller]
1260 pub fn reverse_mut(
1261 &mut self,
1262 dimensions: &[Dimension],
1263 ) -> TensorView<T, TensorReverse<T, &mut Tensor<T, D>, D>, D> {
1264 TensorView::from(TensorReverse::from(self, dimensions))
1265 }
1266
1267 /**
1268 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1269 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1270 * takes ownership of this Tensor, and can therefore mutate it
1271 *
1272 * This is a shorthand for constructing the TensorView from this Tensor.
1273 *
1274 * # Panics
1275 *
1276 * - If a dimension name is not in the tensor's shape or is repeated.
1277 */
1278 #[track_caller]
1279 pub fn reverse_owned(
1280 self,
1281 dimensions: &[Dimension],
1282 ) -> TensorView<T, TensorReverse<T, Tensor<T, D>, D>, D> {
1283 TensorView::from(TensorReverse::from(self, dimensions))
1284 }
1285
1286 /**
1287 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1288 * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
1289 * or need to clone the values anyway, you can use
1290 * [`Tensor::elementwise`](Tensor::elementwise) instead.
1291 *
1292 * # Generics
1293 *
1294 * This method can be called with any right hand side that can be converted to a TensorView,
1295 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1296 *
1297 * # Panics
1298 *
1299 * If the two tensors have different shapes.
1300 */
1301 #[track_caller]
1302 pub fn elementwise_reference<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1303 where
1304 I: Into<TensorView<T, S, D>>,
1305 S: TensorRef<T, D>,
1306 M: Fn(&T, &T) -> T,
1307 {
1308 self.elementwise_reference_less_generic(rhs.into(), mapping_function)
1309 }
1310
1311 /**
1312 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1313 * mapped by a function. The mapping function also receives each index corresponding to the
1314 * value pairs. The value pairs are not copied for you, if you're using `Copy` types
1315 * or need to clone the values anyway, you can use
1316 * [`Tensor::elementwise_with_index`](Tensor::elementwise_with_index) instead.
1317 *
1318 * # Generics
1319 *
1320 * This method can be called with any right hand side that can be converted to a TensorView,
1321 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1322 *
1323 * # Panics
1324 *
1325 * If the two tensors have different shapes.
1326 */
1327 #[track_caller]
1328 pub fn elementwise_reference_with_index<S, I, M>(
1329 &self,
1330 rhs: I,
1331 mapping_function: M,
1332 ) -> Tensor<T, D>
1333 where
1334 I: Into<TensorView<T, S, D>>,
1335 S: TensorRef<T, D>,
1336 M: Fn([usize; D], &T, &T) -> T,
1337 {
1338 self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
1339 }
1340
1341 #[track_caller]
1342 fn elementwise_reference_less_generic<S, M>(
1343 &self,
1344 rhs: TensorView<T, S, D>,
1345 mapping_function: M,
1346 ) -> Tensor<T, D>
1347 where
1348 S: TensorRef<T, D>,
1349 M: Fn(&T, &T) -> T,
1350 {
1351 let left_shape = self.shape();
1352 let right_shape = rhs.shape();
1353 if left_shape != right_shape {
1354 panic!(
1355 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1356 left_shape, right_shape
1357 );
1358 }
1359 let mapped = self
1360 .direct_iter_reference()
1361 .zip(rhs.iter_reference())
1362 .map(|(x, y)| mapping_function(x, y))
1363 .collect();
1364 // We're not changing the shape of the Tensor, so don't need to revalidate
1365 Tensor::direct_from(mapped, self.shape, self.strides)
1366 }
1367
1368 #[track_caller]
1369 fn elementwise_reference_less_generic_with_index<S, M>(
1370 &self,
1371 rhs: TensorView<T, S, D>,
1372 mapping_function: M,
1373 ) -> Tensor<T, D>
1374 where
1375 S: TensorRef<T, D>,
1376 M: Fn([usize; D], &T, &T) -> T,
1377 {
1378 let left_shape = self.shape();
1379 let right_shape = rhs.shape();
1380 if left_shape != right_shape {
1381 panic!(
1382 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1383 left_shape, right_shape
1384 );
1385 }
1386 // we just checked both shapes were the same, so we don't need to propagate indexes
1387 // for both tensors because they'll be identical
1388 let mapped = self
1389 .direct_iter_reference()
1390 .zip(rhs.iter_reference().with_index())
1391 .map(|(x, (i, y))| mapping_function(i, x, y))
1392 .collect();
1393 // We're not changing the shape of the Tensor, so don't need to revalidate
1394 Tensor::direct_from(mapped, self.shape, self.strides)
1395 }
1396
1397 /**
1398 * Returns a TensorView which makes the order of the data in this tensor appear to be in
1399 * a different order. The order of the dimension names is unchanged, although their lengths
1400 * may swap.
1401 *
1402 * This is a shorthand for constructing the TensorView from this Tensor.
1403 *
1404 * See also: [transpose](Tensor::transpose), [TensorTranspose]
1405 *
1406 * # Panics
1407 *
1408 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1409 * order need not match.
1410 */
1411 pub fn transpose_view(
1412 &self,
1413 dimensions: [Dimension; D],
1414 ) -> TensorView<T, TensorTranspose<T, &Tensor<T, D>, D>, D> {
1415 TensorView::from(TensorTranspose::from(self, dimensions))
1416 }
1417}
1418
1419impl<T, const D: usize> Tensor<T, D>
1420where
1421 T: Clone,
1422{
1423 /**
1424 * Creates a tensor with a particular number of dimensions and length in each dimension
1425 * with all elements initialised to the provided value.
1426 *
1427 * # Panics
1428 *
1429 * - If a dimension name is not unique
1430 * - If any dimension has 0 elements
1431 */
1432 #[track_caller]
1433 pub fn empty(shape: [(Dimension, usize); D], value: T) -> Self {
1434 let elements = crate::tensors::dimensions::elements(&shape);
1435 Tensor::from(shape, vec![value; elements])
1436 }
1437
1438 /**
1439 * Gets a copy of the first value in this tensor.
1440 * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
1441 * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
1442 */
1443 pub fn first(&self) -> T {
1444 self.data
1445 .first()
1446 .expect("Tensors always have at least 1 element")
1447 .clone()
1448 }
1449
1450 /**
1451 * Returns a new Tensor which has the same data as this tensor, but with the order of data
1452 * changed. The order of the dimension names is unchanged, although their lengths may swap.
1453 *
1454 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1455 * `transpose(["y", "x"])` which would return a new tensor with a shape of
1456 * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
1457 *
1458 * This method need not shift *all* the dimensions though, you could also swap the width
1459 * and height of images in a tensor with a shape of
1460 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
1461 * which would return a new tensor where all the images have been swapped over the diagonal.
1462 *
1463 * See also: [TensorAccess], [reorder](Tensor::reorder)
1464 *
1465 * # Panics
1466 *
1467 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1468 * order need not match (and if the order does match, this function is just an expensive
1469 * clone).
1470 */
1471 #[track_caller]
1472 pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1473 let shape = self.shape;
1474 let mut reordered = self.reorder(dimensions);
1475 // Transposition is essentially reordering, but we retain the dimension name ordering
1476 // of the original order, this means we may swap dimension lengths, but the dimensions
1477 // will not change order.
1478 #[allow(clippy::needless_range_loop)]
1479 for d in 0..D {
1480 reordered.shape[d].0 = shape[d].0;
1481 }
1482 reordered
1483 }
1484
1485 /**
1486 * Modifies this tensor to have the same data as before, but with the order of data changed.
1487 * The order of the dimension names is unchanged, although their lengths may swap.
1488 *
1489 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1490 * `transpose_mut(["y", "x"])` which would edit the tensor, updating its shape to
1491 * `[("x", y), ("y", x)]`, so every (x,y) of its data corresponds to (y,x) before the
1492 * transposition.
1493 *
1494 * The order swapping will try to be in place, but this is currently only supported for
1495 * square tensors with 2 dimensions. Other types of tensors will not be transposed in place.
1496 *
1497 * # Panics
1498 *
1499 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1500 * order need not match (and if the order does match, this function is just an expensive
1501 * clone).
1502 */
1503 #[track_caller]
1504 pub fn transpose_mut(&mut self, dimensions: [Dimension; D]) {
1505 let shape = self.shape;
1506 self.reorder_mut(dimensions);
1507 // Transposition is essentially reordering, but we retain the dimension name ordering
1508 // we had before, this means we may swap dimension lengths, but the dimensions
1509 // will not change order.
1510 #[allow(clippy::needless_range_loop)]
1511 for d in 0..D {
1512 self.shape[d].0 = shape[d].0;
1513 }
1514 }
1515
1516 /**
1517 * Returns a new Tensor which has the same data as this tensor, but with the order of the
1518 * dimensions and corresponding order of data changed.
1519 *
1520 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1521 * `reorder(["y", "x"])` which would return a new tensor with a shape of
1522 * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
1523 *
1524 * This method need not shift *all* the dimensions though, you could also swap the width
1525 * and height of images in a tensor with a shape of
1526 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
1527 * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
1528 * in the original.
1529 *
1530 * See also: [TensorAccess], [transpose](Tensor::transpose)
1531 *
1532 * # Panics
1533 *
1534 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1535 * order need not match (and if the order does match, this function is just an expensive
1536 * clone).
1537 */
1538 #[track_caller]
1539 pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1540 let reorderd = match TensorAccess::try_from(&self, dimensions) {
1541 Ok(reordered) => reordered,
1542 Err(_error) => panic!(
1543 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1544 dimensions, &self.shape,
1545 ),
1546 };
1547 let reorderd_shape = reorderd.shape();
1548 Tensor::from(reorderd_shape, reorderd.iter().collect())
1549 }
1550
1551 /**
1552 * Modifies this tensor to have the same data as before, but with the order of the
1553 * dimensions and corresponding order of data changed.
1554 *
1555 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1556 * `reorder_mut(["y", "x"])` which would edit the tensor, updating its shape to
1557 * `[("y", y), ("x", x)]`, so every (y,x) of its data corresponds to (x,y) before the
1558 * transposition.
1559 *
1560 * The order swapping will try to be in place, but this is currently only supported for
1561 * square tensors with 2 dimensions. Other types of tensors will not be reordered in place.
1562 *
1563 * # Panics
1564 *
1565 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1566 * order need not match (and if the order does match, this function is just an expensive
1567 * clone).
1568 */
1569 #[track_caller]
1570 pub fn reorder_mut(&mut self, dimensions: [Dimension; D]) {
1571 use crate::tensors::dimensions::DimensionMappings;
1572 if D == 2 && crate::tensors::dimensions::is_square(&self.shape) {
1573 let dimension_mapping = match DimensionMappings::new(&self.shape, &dimensions) {
1574 Some(dimension_mapping) => dimension_mapping,
1575 None => panic!(
1576 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1577 dimensions, &self.shape,
1578 ),
1579 };
1580
1581 let shape = dimension_mapping.map_shape_to_requested(&self.shape);
1582 let shape_iterator = ShapeIterator::from(shape);
1583
1584 for index in shape_iterator {
1585 let i = index[0];
1586 let j = index[1];
1587 if j >= i {
1588 let mapped_index = dimension_mapping.map_dimensions_to_source(&index);
1589 // Swap elements from the upper triangle (using index order of the actual tensor's
1590 // shape)
1591 let temp = self.get_reference(index).unwrap().clone();
1592 // tensor[i,j] becomes tensor[mapping(i,j)]
1593 *self.get_reference_mut(index).unwrap() =
1594 self.get_reference(mapped_index).unwrap().clone();
1595 // tensor[mapping(i,j)] becomes tensor[i,j]
1596 *self.get_reference_mut(mapped_index).unwrap() = temp;
1597 // If the mapping is a noop we've assigned i,j to i,j
1598 // If the mapping is i,j -> j,i we've assigned i,j to j,i and j,i to i,j
1599 }
1600 }
1601
1602 // now update our shape and strides to match
1603 self.shape = shape;
1604 self.strides = compute_strides(&shape);
1605 } else {
1606 // fallback to allocating a new reordered tensor
1607 let reordered = self.reorder(dimensions);
1608 self.data = reordered.data;
1609 self.shape = reordered.shape;
1610 self.strides = reordered.strides;
1611 }
1612 }
1613
1614 /**
1615 * Returns an iterator over copies of the data in this Tensor.
1616 */
1617 pub fn iter(&self) -> TensorIterator<'_, T, Tensor<T, D>, D> {
1618 TensorIterator::from(self)
1619 }
1620
1621 /**
1622 * Creates and returns a new tensor with all values from the original with the
1623 * function applied to each. This can be used to change the type of the tensor
1624 * such as creating a mask:
1625 * ```
1626 * use easy_ml::tensors::Tensor;
1627 * let x = Tensor::from([("a", 2), ("b", 2)], vec![
1628 * 0.0, 1.2,
1629 * 5.8, 6.9
1630 * ]);
1631 * let y = x.map(|element| element > 2.0);
1632 * let result = Tensor::from([("a", 2), ("b", 2)], vec![
1633 * false, false,
1634 * true, true
1635 * ]);
1636 * assert_eq!(&y, &result);
1637 * ```
1638 */
1639 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
1640 let mapped = self
1641 .data
1642 .iter()
1643 .map(|x| mapping_function(x.clone()))
1644 .collect();
1645 // We're not changing the shape of the Tensor, so don't need to revalidate
1646 Tensor::direct_from(mapped, self.shape, self.strides)
1647 }
1648
1649 /**
1650 * Creates and returns a new tensor with all values from the original and
1651 * the index of each value mapped by a function.
1652 */
1653 pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1654 let mapped = self
1655 .iter()
1656 .with_index()
1657 .map(|(i, x)| mapping_function(i, x))
1658 .collect();
1659 // We're not changing the shape of the Tensor, so don't need to revalidate
1660 Tensor::direct_from(mapped, self.shape, self.strides)
1661 }
1662
1663 /**
1664 * Applies a function to all values in the tensor, modifying
1665 * the tensor in place.
1666 */
1667 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1668 for value in self.data.iter_mut() {
1669 *value = mapping_function(value.clone());
1670 }
1671 }
1672
1673 /**
1674 * Applies a function to all values and each value's index in the tensor, modifying
1675 * the tensor in place.
1676 */
1677 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1678 self.iter_reference_mut()
1679 .with_index()
1680 .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1681 }
1682
1683 /**
1684 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1685 * mapped by a function.
1686 *
1687 * ```
1688 * use easy_ml::tensors::Tensor;
1689 * let lhs = Tensor::from([("a", 4)], vec![1, 2, 3, 4]);
1690 * let rhs = Tensor::from([("a", 4)], vec![0, 1, 2, 3]);
1691 * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1692 * assert_eq!(
1693 * multiplied,
1694 * Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1695 * );
1696 * ```
1697 *
1698 * # Generics
1699 *
1700 * This method can be called with any right hand side that can be converted to a TensorView,
1701 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1702 *
1703 * # Panics
1704 *
1705 * If the two tensors have different shapes.
1706 */
1707 #[track_caller]
1708 pub fn elementwise<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1709 where
1710 I: Into<TensorView<T, S, D>>,
1711 S: TensorRef<T, D>,
1712 M: Fn(T, T) -> T,
1713 {
1714 self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1715 mapping_function(lhs.clone(), rhs.clone())
1716 })
1717 }
1718
1719 /**
1720 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1721 * mapped by a function. The mapping function also receives each index corresponding to the
1722 * value pairs.
1723 *
1724 * # Generics
1725 *
1726 * This method can be called with any right hand side that can be converted to a TensorView,
1727 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1728 *
1729 * # Panics
1730 *
1731 * If the two tensors have different shapes.
1732 */
1733 #[track_caller]
1734 pub fn elementwise_with_index<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1735 where
1736 I: Into<TensorView<T, S, D>>,
1737 S: TensorRef<T, D>,
1738 M: Fn([usize; D], T, T) -> T,
1739 {
1740 self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1741 mapping_function(i, lhs.clone(), rhs.clone())
1742 })
1743 }
1744}
1745
1746impl<T> Tensor<T, 1>
1747where
1748 T: Numeric,
1749 for<'a> &'a T: NumericRef<T>,
1750{
1751 /**
1752 * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1753 * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1754 *
1755 * ```
1756 * use easy_ml::tensors::Tensor;
1757 * let tensor = Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]);
1758 * assert_eq!(tensor.scalar_product(&tensor), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1759 * ```
1760 *
1761 * # Generics
1762 *
1763 * This method can be called with any right hand side that can be converted to a TensorView,
1764 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1765 *
1766 * # Panics
1767 *
1768 * If the two vectors are not of equal length or their dimension names do not match.
1769 */
1770 // Would like this impl block to be in operations.rs too but then it would show first in the
1771 // Tensor docs which isn't ideal
1772 pub fn scalar_product<S, I>(&self, rhs: I) -> T
1773 where
1774 I: Into<TensorView<T, S, 1>>,
1775 S: TensorRef<T, 1>,
1776 {
1777 self.scalar_product_less_generic(rhs.into())
1778 }
1779}
1780
impl<T> Tensor<T, 2>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
{
    /**
     * Returns the determinant of this square matrix, or None if the matrix
     * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
     */
    pub fn determinant(&self) -> Option<T> {
        // Thin wrapper over the free function; the turbofish pins T while leaving the
        // source and iterator types to inference
        linear_algebra::determinant_tensor::<T, _, _>(self)
    }

    /**
     * Computes the inverse of a matrix provided that it exists. To have an inverse a
     * matrix must be square (same number of rows and columns) and it must also have a
     * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
     */
    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
        // Thin wrapper over the free function, which returns None when no inverse exists
        linear_algebra::inverse_tensor::<T, _, _>(self)
    }

    /**
     * Computes the covariance matrix for this feature matrix along the specified feature
     * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
     */
    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
        // Thin wrapper over the free function; error cases are documented on
        // linear_algebra::covariance
        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
    }
}
1811
1812// FIXME: want this to be callable in the main numeric impl block
1813impl<T> Tensor<T, 2>
1814where
1815 T: Numeric,
1816{
1817 /**
1818 * Creates a diagonal matrix of the provided size with the diagonal elements
1819 * set to the provided value and all other elements in the tensor set to 0.
1820 * A diagonal matrix is always square.
1821 *
1822 * The size is still taken as a shape to facilitate creating a diagonal matrix
1823 * from the dimensionality of an existing one. If the provided value is 1 then
1824 * this will create an identity matrix.
1825 *
1826 * A 3 x 3 identity matrix:
1827 * ```ignore
1828 * [
1829 * 1, 0, 0
1830 * 0, 1, 0
1831 * 0, 0, 1
1832 * ]
1833 * ```
1834 *
1835 * # Panics
1836 *
1837 * - If the shape is not square.
1838 * - If a dimension name is not unique
1839 * - If any dimension has 0 elements
1840 */
1841 #[track_caller]
1842 pub fn diagonal(shape: [(Dimension, usize); 2], value: T) -> Tensor<T, 2> {
1843 if !crate::tensors::dimensions::is_square(&shape) {
1844 panic!("Shape must be square: {:?}", shape);
1845 }
1846 let mut tensor = Tensor::empty(shape, T::zero());
1847 for ([r, c], x) in tensor.iter_reference_mut().with_index() {
1848 if r == c {
1849 *x = value.clone();
1850 }
1851 }
1852 tensor
1853 }
1854}
1855
1856impl<T> Tensor<T, 2> {
1857 /**
1858 * Converts this 2 dimensional Tensor into a Matrix.
1859 *
1860 * This is a wrapper around the `From<Tensor<T, 2>>` implementation.
1861 *
1862 * The Matrix will have the data in the same order, with rows equal to the length of
1863 * the first dimension in the tensor, and columns equal to the length of the second.
1864 */
1865 pub fn into_matrix(self) -> crate::matrices::Matrix<T> {
1866 self.into()
1867 }
1868}
1869
1870/**
1871 * Methods for tensors with numerical real valued types, such as f32 or f64.
1872 *
1873 * This excludes signed and unsigned integers as they do not support decimal
1874 * precision and hence can't be used for operations like square roots.
1875 *
1876 * Third party fixed precision and infinite precision decimal types should
1877 * be able to implement all of the methods for [Real] and then utilise these functions.
1878 */
1879impl<T: Real> Tensor<T, 1>
1880where
1881 for<'a> &'a T: RealRef<T>,
1882{
1883 /**
1884 * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1885 * of this vector, also referred to as the length or magnitude,
1886 * and written as ||x||, or sometimes |x|.
1887 *
1888 * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1889 *
1890 * This is a shorthand for `(x.iter().map(|x| x * x).sum().sqrt()`, ie
1891 * the square root of the dot product of a vector with itself.
1892 *
1893 * The euclidean length can be used to compute a
1894 * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1895 * vector with length of 1. This should not be confused with a unit matrix,
1896 * which is another name for an identity matrix.
1897 *
1898 * ```
1899 * use easy_ml::tensors::Tensor;
1900 * let a = Tensor::from([("data", 3)], vec![ 1.0, 2.0, 3.0 ]);
1901 * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1902 * let unit = a.map(|x| x / length);
1903 * assert_eq!(unit.euclidean_length(), 1.0);
1904 * ```
1905 */
1906 // TODO: Scalar ops for tensors
1907 pub fn euclidean_length(&self) -> T {
1908 self.direct_iter_reference()
1909 .map(|x| x * x)
1910 .sum::<T>()
1911 .sqrt()
1912 }
1913}
1914
1915#[cfg(feature = "serde")]
1916mod serde_impls {
1917 use crate::tensors::{Dimension, InvalidShapeError, Tensor};
1918 use serde::Deserialize;
1919 use std::convert::TryFrom;
1920
1921 /**
1922 * Deserialised data for a Tensor. Can be converted into a Tensor by providing `&'static str`
1923 * dimension names.
1924 */
1925 #[derive(Deserialize)]
1926 #[serde(rename = "Tensor")]
1927 pub struct TensorDeserialize<'a, T, const D: usize> {
1928 data: Vec<T>,
1929 #[serde(with = "serde_arrays")]
1930 #[serde(borrow)]
1931 shape: [(&'a str, usize); D],
1932 }
1933
1934 impl<'a, T, const D: usize> TensorDeserialize<'a, T, D> {
1935 /**
1936 * Converts this deserialised Tensor data to a Tensor, using the provided `&'static str`
1937 * dimension names in place of what was serialised (which wouldn't necessarily live
1938 * long enough).
1939 */
1940 pub fn into_tensor(
1941 self,
1942 dimensions: [Dimension; D],
1943 ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
1944 let shape = std::array::from_fn(|d| (dimensions[d], self.shape[d].1));
1945 // Safety: Use the normal constructor that performs validation to prevent invalid
1946 // serialized data being created as a Tensor, which would then break all the
1947 // code that's relying on these invariants.
1948 // By never serialising the strides in the first place, we reduce the possibility
1949 // of creating invalid serialised represenations at the slight increase in
1950 // serialisation work.
1951 Tensor::try_from(shape, self.data)
1952 }
1953 }
1954
1955 /**
1956 * Converts this deserialised Tensor data which has a static lifetime for the dimension
1957 * names to a Tensor, using the serialised data.
1958 */
1959 impl<T, const D: usize> TryFrom<TensorDeserialize<'static, T, D>> for Tensor<T, D> {
1960 type Error = InvalidShapeError<D>;
1961
1962 fn try_from(value: TensorDeserialize<'static, T, D>) -> Result<Self, Self::Error> {
1963 Tensor::try_from(value.shape, value.data)
1964 }
1965 }
1966}
1967
1968#[cfg(feature = "serde")]
1969#[test]
1970fn test_serialize() {
1971 fn assert_serialize<T: Serialize>() {}
1972 assert_serialize::<Tensor<f64, 3>>();
1973 assert_serialize::<Tensor<f64, 2>>();
1974 assert_serialize::<Tensor<f64, 1>>();
1975 assert_serialize::<Tensor<f64, 0>>();
1976}
1977
1978#[cfg(feature = "serde")]
1979#[test]
1980fn test_deserialize() {
1981 use serde::Deserialize;
1982 fn assert_deserialize<'de, T: Deserialize<'de>>() {}
1983 assert_deserialize::<TensorDeserialize<f64, 3>>();
1984 assert_deserialize::<TensorDeserialize<f64, 2>>();
1985 assert_deserialize::<TensorDeserialize<f64, 1>>();
1986 assert_deserialize::<TensorDeserialize<f64, 0>>();
1987}
1988
1989#[cfg(feature = "serde")]
1990#[test]
1991fn test_serialization_deserialization_loop() {
1992 #[rustfmt::skip]
1993 let tensor = Tensor::from(
1994 [("rows", 3), ("columns", 4)],
1995 vec![
1996 1, 2, 3, 4,
1997 5, 6, 7, 8,
1998 9, 10, 11, 12
1999 ],
2000 );
2001 let encoded = toml::to_string(&tensor).unwrap();
2002 assert_eq!(
2003 encoded,
2004 r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
2005shape = [["rows", 3], ["columns", 4]]
2006"#,
2007 );
2008 let parsed: Result<TensorDeserialize<i32, 2>, _> = toml::from_str(&encoded);
2009 assert!(parsed.is_ok());
2010 let result = parsed.unwrap().into_tensor(["rows", "columns"]);
2011 assert!(result.is_ok());
2012 assert_eq!(result.unwrap(), tensor);
2013}
2014
2015#[cfg(feature = "serde")]
2016#[test]
2017fn test_deserialization_validation() {
2018 let parsed: Result<TensorDeserialize<i32, 2>, _> = toml::from_str(
2019 r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
2020shape = [["rows", 4], ["columns", 4]]
2021"#,
2022 );
2023 assert!(parsed.is_ok());
2024 let result = parsed.unwrap().into_tensor(["rows", "columns"]);
2025 assert!(result.is_err());
2026}
2027
2028macro_rules! tensor_select_impl {
2029 (impl Tensor $d:literal 1) => {
2030 impl<T> Tensor<T, $d> {
2031 /**
2032 * Selects the provided dimension name and index pairs in this Tensor, returning a
2033 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2034 * always indexed as the provided values.
2035 *
2036 * This is a shorthand for manually constructing the TensorView and
2037 * [TensorIndex]
2038 *
2039 * Note: due to limitations in Rust's const generics support, this method is only
2040 * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
2041 * back to manual construction to create `TensorIndex`es with multiple provided
2042 * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
2043 */
2044 #[track_caller]
2045 pub fn select(
2046 &self,
2047 provided_indexes: [(Dimension, usize); 1],
2048 ) -> TensorView<T, TensorIndex<T, &Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2049 TensorView::from(TensorIndex::from(self, provided_indexes))
2050 }
2051
2052 /**
2053 * Selects the provided dimension name and index pairs in this Tensor, returning a
2054 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2055 * always indexed as the provided values. The TensorIndex mutably borrows this
2056 * Tensor, and can therefore mutate it
2057 *
2058 * See [select](Tensor::select)
2059 */
2060 #[track_caller]
2061 pub fn select_mut(
2062 &mut self,
2063 provided_indexes: [(Dimension, usize); 1],
2064 ) -> TensorView<T, TensorIndex<T, &mut Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2065 TensorView::from(TensorIndex::from(self, provided_indexes))
2066 }
2067
2068 /**
2069 * Selects the provided dimension name and index pairs in this Tensor, returning a
2070 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2071 * always indexed as the provided values. The TensorIndex takes ownership of this
2072 * Tensor, and can therefore mutate it
2073 *
2074 * See [select](Tensor::select)
2075 */
2076 #[track_caller]
2077 pub fn select_owned(
2078 self,
2079 provided_indexes: [(Dimension, usize); 1],
2080 ) -> TensorView<T, TensorIndex<T, Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2081 TensorView::from(TensorIndex::from(self, provided_indexes))
2082 }
2083 }
2084 };
2085}
2086
2087tensor_select_impl!(impl Tensor 6 1);
2088tensor_select_impl!(impl Tensor 5 1);
2089tensor_select_impl!(impl Tensor 4 1);
2090tensor_select_impl!(impl Tensor 3 1);
2091tensor_select_impl!(impl Tensor 2 1);
2092tensor_select_impl!(impl Tensor 1 1);
2093
2094macro_rules! tensor_expand_impl {
2095 (impl Tensor $d:literal 1) => {
2096 impl<T> Tensor<T, $d> {
2097 /**
2098 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2099 * a particular position within the shape, returning a TensorView which has more
2100 * dimensions than this Tensor.
2101 *
2102 * This is a shorthand for manually constructing the TensorView and
2103 * [TensorExpansion]
2104 *
2105 * Note: due to limitations in Rust's const generics support, this method is only
2106 * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
2107 * fall back to manual construction to create `TensorExpansion`s with multiple provided
2108 * indexes if you need to increase dimensionality by more than 1 dimension at a time.
2109 */
2110 #[track_caller]
2111 pub fn expand(
2112 &self,
2113 extra_dimension_names: [(usize, Dimension); 1],
2114 ) -> TensorView<T, TensorExpansion<T, &Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2115 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2116 }
2117
2118 /**
2119 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2120 * a particular position within the shape, returning a TensorView which has more
2121 * dimensions than this Tensor. The TensorIndex mutably borrows this
2122 * Tensor, and can therefore mutate it
2123 *
2124 * See [expand](Tensor::expand)
2125 */
2126 #[track_caller]
2127 pub fn expand_mut(
2128 &mut self,
2129 extra_dimension_names: [(usize, Dimension); 1],
2130 ) -> TensorView<T, TensorExpansion<T, &mut Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2131 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2132 }
2133
2134 /**
2135 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2136 * a particular position within the shape, returning a TensorView which has more
2137 * dimensions than this Tensor. The TensorIndex takes ownership of this
2138 * Tensor, and can therefore mutate it
2139 *
2140 * See [expand](Tensor::expand)
2141 */
2142 #[track_caller]
2143 pub fn expand_owned(
2144 self,
2145 extra_dimension_names: [(usize, Dimension); 1],
2146 ) -> TensorView<T, TensorExpansion<T, Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2147 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2148 }
2149 }
2150 };
2151}
2152
2153tensor_expand_impl!(impl Tensor 0 1);
2154tensor_expand_impl!(impl Tensor 1 1);
2155tensor_expand_impl!(impl Tensor 2 1);
2156tensor_expand_impl!(impl Tensor 3 1);
2157tensor_expand_impl!(impl Tensor 4 1);
2158tensor_expand_impl!(impl Tensor 5 1);