easy_ml/tensors/mod.rs
1/*!
2 * Generic N dimensional [named tensors](http://nlp.seas.harvard.edu/NamedTensor).
3 *
4 * Tensors are generic over some type `T` and some usize `D`. If `T` is [Numeric](super::numeric)
5 * then the tensor can be used in a mathematical way. `D` is the number of dimensions in the tensor
6 * and a compile time constant. Each tensor also carries `D` dimension name and length pairs.
7 */
8use crate::linear_algebra;
9use crate::numeric::extra::{Real, RealRef};
10use crate::numeric::{Numeric, NumericRef};
11use crate::tensors::indexing::{
12 ShapeIterator, TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
13 TensorReferenceMutIterator, TensorTranspose,
14};
15use crate::tensors::views::{
16 DataLayout, IndexRange, IndexRangeValidationError, TensorExpansion, TensorIndex, TensorMask,
17 TensorMut, TensorRange, TensorRef, TensorRename, TensorReshape, TensorReverse, TensorView,
18};
19
20use std::error::Error;
21use std::fmt;
22
23#[cfg(feature = "serde")]
24use serde::Serialize;
25
26pub mod dimensions;
27mod display;
28pub mod einsum;
29pub mod indexing;
30pub mod operations;
31pub mod views;
32
33#[cfg(feature = "serde")]
34pub use serde_impls::{TensorDeserialize, TensorDeserializeOwned};
35
/**
 * Dimension names are represented as static string references.
 *
 * This allows you to use string literals to refer to named dimensions, for example you might want
 * to construct a tensor with a shape of
 * `[("batch", 1000), ("height", 100), ("width", 100), ("rgba", 4)]`.
 *
 * Alternatively you can define the strings once as constants and refer to your dimension
 * names by the constant identifiers.
 *
 * ```
 * const BATCH: &'static str = "batch";
 * const HEIGHT: &'static str = "height";
 * const WIDTH: &'static str = "width";
 * const RGBA: &'static str = "rgba";
 * ```
 *
 * Although `Dimension` is interchangeable with `&'static str` as it is just a type alias, Easy ML
 * uses `Dimension` whenever dimension names are expected to distinguish the types from just
 * strings.
 */
pub type Dimension = &'static str;
58
59/**
60 * An error indicating failure to do something with a Tensor because the requested shape
61 * is not valid.
62 */
63#[derive(Clone, Debug, Eq, PartialEq)]
64pub struct InvalidShapeError<const D: usize> {
65 shape: [(Dimension, usize); D],
66}
67
68impl<const D: usize> InvalidShapeError<D> {
69 /**
70 * Checks if this shape is valid. This is mainly for internal library use but may also be
71 * useful for unit testing.
72 *
73 * Note: in some functions and methods, an InvalidShapeError may be returned which is a valid
74 * shape, but not the right size for the quantity of data provided.
75 */
76 pub fn is_valid(&self) -> bool {
77 !crate::tensors::dimensions::has_duplicates(&self.shape)
78 && !self.shape.iter().any(|d| d.1 == 0)
79 }
80
81 /**
82 * Constructs an InvalidShapeError for assistance with unit testing. Note that you can
83 * construct an InvalidShapeError that *is* a valid shape in this way.
84 */
85 pub fn new(shape: [(Dimension, usize); D]) -> InvalidShapeError<D> {
86 InvalidShapeError { shape }
87 }
88
89 pub fn shape(&self) -> [(Dimension, usize); D] {
90 self.shape
91 }
92
93 pub fn shape_ref(&self) -> &[(Dimension, usize); D] {
94 &self.shape
95 }
96
97 // Panics if the shape is invalid for any reason with the appropriate error message.
98 #[track_caller]
99 #[inline]
100 fn validate_dimensions_or_panic(shape: &[(Dimension, usize); D], data_len: usize) {
101 let elements = crate::tensors::dimensions::elements(shape);
102 if data_len != elements {
103 panic!(
104 "Product of dimension lengths must match size of data. {} != {}",
105 elements, data_len
106 );
107 }
108 if crate::tensors::dimensions::has_duplicates(shape) {
109 panic!("Dimension names must all be unique: {:?}", &shape);
110 }
111 if shape.iter().any(|d| d.1 == 0) {
112 panic!("No dimension can have 0 elements: {:?}", &shape);
113 }
114 }
115
116 // Returns true if the shape is valid and matches the data length
117 fn validate_dimensions(shape: &[(Dimension, usize); D], data_len: usize) -> bool {
118 let elements = crate::tensors::dimensions::elements(shape);
119 data_len == elements
120 && !crate::tensors::dimensions::has_duplicates(shape)
121 && !shape.iter().any(|d| d.1 == 0)
122 }
123}
124
125impl<const D: usize> fmt::Display for InvalidShapeError<D> {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 write!(
128 f,
129 "Dimensions must all be at least length 1 with unique names: {:?}",
130 self.shape
131 )
132 }
133}
134
135impl<const D: usize> Error for InvalidShapeError<D> {}
136
137/**
138 * An error indicating failure to do something with a Tensor because the dimension names that
139 * were provided did not match with the dimension names that were valid.
140 *
141 * Typically this would be due to the same dimension name being provided multiple times, or a
142 * dimension name being provided that is not present in the shape of the Tensor in use.
143 */
144#[derive(Clone, Debug, Eq, PartialEq)]
145pub struct InvalidDimensionsError<const D: usize, const P: usize> {
146 valid: [Dimension; D],
147 provided: [Dimension; P],
148}
149
150impl<const D: usize, const P: usize> InvalidDimensionsError<D, P> {
151 /**
152 * Checks if the provided dimensions have duplicate names. This is mainly for internal library
153 * use but may also be useful for unit testing.
154 */
155 pub fn has_duplicates(&self) -> bool {
156 crate::tensors::dimensions::has_duplicates_names(&self.provided)
157 }
158
159 // TODO: method to check provided is a subset of valid
160
161 /**
162 * Constructs an InvalidDimensions for assistance with unit testing.
163 */
164 pub fn new(provided: [Dimension; P], valid: [Dimension; D]) -> InvalidDimensionsError<D, P> {
165 InvalidDimensionsError { valid, provided }
166 }
167
168 pub fn provided_names(&self) -> [Dimension; P] {
169 self.provided
170 }
171
172 pub fn provided_names_ref(&self) -> &[Dimension; P] {
173 &self.provided
174 }
175
176 pub fn valid_names(&self) -> [Dimension; D] {
177 self.valid
178 }
179
180 pub fn valid_names_ref(&self) -> &[Dimension; D] {
181 &self.valid
182 }
183}
184
185impl<const D: usize, const P: usize> fmt::Display for InvalidDimensionsError<D, P> {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 if P > 0 {
188 write!(
189 f,
190 "Dimensions names {:?} were incorrect, valid dimensions in this context are: {:?}",
191 self.provided, self.valid
192 )
193 } else {
194 write!(f, "Dimensions names {:?} were incorrect", self.provided)
195 }
196 }
197}
198
199impl<const D: usize, const P: usize> Error for InvalidDimensionsError<D, P> {}
200
#[test]
fn test_sync() {
    // This only compiles if every listed type is Sync, so the check happens at type check time.
    fn requires_sync<S: Sync>() {}
    requires_sync::<InvalidShapeError<2>>();
    requires_sync::<InvalidDimensionsError<2, 2>>();
}
207
#[test]
fn test_send() {
    // This only compiles if every listed type is Send, so the check happens at type check time.
    fn requires_send<S: Send>() {}
    requires_send::<InvalidShapeError<2>>();
    requires_send::<InvalidDimensionsError<2, 2>>();
}
214
215/**
216 * A [named tensor](http://nlp.seas.harvard.edu/NamedTensor) of some type `T` and number of
217 * dimensions `D`.
218 *
219 * Tensors are a generalisation of matrices; whereas [Matrix](crate::matrices::Matrix) only
220 * supports 2 dimensions, and vectors are represented in Matrix by making either the rows or
221 * columns have a length of one, [Tensor] supports an arbitary number of dimensions,
222 * with 0 through 6 having full API support. A `Tensor<T, 2>` is very similar to a `Matrix<T>`
223 * except that this type associates each dimension with a name, and favors names to refer to
224 * dimensions instead of index order.
225 *
226 * Like Matrix, the type of the data in this Tensor may implement no traits, in which case the
227 * tensor will be rather useless. If the type implements Clone most storage and accessor methods
228 * are defined and if the type implements Numeric then the tensor can be used in a mathematical
229 * way.
230 *
231 * Like Matrix, a Tensor must always contain at least one element, and it may not not have more
232 * elements than `std::isize::MAX`. Concerned readers should note that on a 64 bit computer this
233 * maximum value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
234 *
235 * When doing numeric operations with Tensors you should be careful to not consume a tensor by
236 * accidentally using it by value. All the operations are also defined on references to tensors
237 * so you should favor &x + &y style notation for tensors you intend to continue using.
238 *
239 * See also:
240 * - [indexing]
241 */
242#[derive(Debug)]
243#[cfg_attr(feature = "serde", derive(Serialize))]
244pub struct Tensor<T, const D: usize> {
245 data: Vec<T>,
246 #[cfg_attr(feature = "serde", serde(with = "serde_arrays"))]
247 shape: [(Dimension, usize); D],
248 #[cfg_attr(feature = "serde", serde(skip))]
249 strides: [usize; D],
250}
251
252impl<T, const D: usize> Tensor<T, D> {
253 /**
254 * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
255 *
256 * The product of the dimension lengths corresponds to the number of elements the Tensor
257 * will store. Elements are stored in what would be row major order for a Matrix.
258 * Each step in memory through the N dimensions corresponds to incrementing the rightmost
259 * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
260 * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
261 * for the 25th and final element.
262 *
263 * # Panics
264 *
265 * - If the number of provided elements does not match the product of the dimension lengths.
266 * - If a dimension name is not unique
267 * - If any dimension has 0 elements
268 *
269 * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
270 * a single element (since the product of an empty list is 1).
271 */
272 #[track_caller]
273 pub fn from(shape: [(Dimension, usize); D], data: Vec<T>) -> Self {
274 InvalidShapeError::validate_dimensions_or_panic(&shape, data.len());
275 let strides = compute_strides(&shape);
276 Tensor {
277 data,
278 shape,
279 strides,
280 }
281 }
282
283 /**
284 * Creates a Tensor with a particular shape initialised from a function.
285 *
286 * The product of the dimension lengths corresponds to the number of elements the Tensor
287 * will store. Elements are stored in what would be row major order for a Matrix.
288 * Each step in memory through the N dimensions corresponds to incrementing the rightmost
289 * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
290 * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
291 * for the 25th and final element. These same indexes will be passed to the producer function
292 * to initialised the values for the Tensor.
293 *
294 * ```
295 * use easy_ml::tensors::Tensor;
296 * let tensor = Tensor::from_fn([("rows", 4), ("columns", 4)], |[r, c]| r * c);
297 * assert_eq!(
298 * tensor,
299 * Tensor::from([("rows", 4), ("columns", 4)], vec![
300 * 0, 0, 0, 0,
301 * 0, 1, 2, 3,
302 * 0, 2, 4, 6,
303 * 0, 3, 6, 9,
304 * ])
305 * );
306 * ```
307 *
308 * # Panics
309 *
310 * - If a dimension name is not unique
311 * - If any dimension has 0 elements
312 *
313 * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
314 * a single element (since the product of an empty list is 1).
315 */
316 #[track_caller]
317 pub fn from_fn<F>(shape: [(Dimension, usize); D], mut producer: F) -> Self
318 where
319 F: FnMut([usize; D]) -> T,
320 {
321 let length = dimensions::elements(&shape);
322 let mut data = Vec::with_capacity(length);
323 let iterator = ShapeIterator::from(shape);
324 for index in iterator {
325 data.push(producer(index));
326 }
327 Tensor::from(shape, data)
328 }
329
330 /**
331 * The shape of this tensor. Since Tensors are named Tensors, their shape is not just a
332 * list of lengths along each dimension, but instead a list of pairs of names and lengths.
333 *
334 * See also
335 * - [dimensions]
336 * - [indexing]
337 */
338 pub fn shape(&self) -> [(Dimension, usize); D] {
339 self.shape
340 }
341
342 /**
343 * Returns the length of the dimension name provided, if one is present in the Tensor.
344 *
345 * See also
346 * - [dimensions]
347 * - [indexing]
348 */
349 pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
350 dimensions::length_of(&self.shape, dimension)
351 }
352
353 /**
354 * Returns the last index of the dimension name provided, if one is present in the Tensor.
355 *
356 * This is always 1 less than the length, the 'index' in this sense is based on what the
357 * Tensor's shape is, not any implementation index.
358 *
359 * See also
360 * - [dimensions]
361 * - [indexing]
362 */
363 pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
364 dimensions::last_index_of(&self.shape, dimension)
365 }
366
367 /**
368 * A non panicking version of [from](Tensor::from) which returns `Result::Err` if the input
369 * is invalid.
370 *
371 * Creates a Tensor with a particular number of dimensions and lengths in each dimension.
372 *
373 * The product of the dimension lengths corresponds to the number of elements the Tensor
374 * will store. Elements are stored in what would be row major order for a Matrix.
375 * Each step in memory through the N dimensions corresponds to incrementing the rightmost
376 * index, hence a shape of `[("row", 5), ("column", 5)]` would mean the first 6 elements
377 * passed in the Vec would be for (0,0), (0,1), (0,2), (0,3), (0,4), (1,0) and so on to (4,4)
378 * for the 25th and final element.
379 *
380 * Returns the Err variant if
381 * - If the number of provided elements does not match the product of the dimension lengths.
382 * - If a dimension name is not unique
383 * - If any dimension has 0 elements
384 *
385 * Note that an empty list for dimensions is valid, and constructs a 0 dimensional tensor with
386 * a single element (since the product of an empty list is 1).
387 */
388 pub fn try_from(
389 shape: [(Dimension, usize); D],
390 data: Vec<T>,
391 ) -> Result<Self, InvalidShapeError<D>> {
392 let valid = InvalidShapeError::validate_dimensions(&shape, data.len());
393 if !valid {
394 return Err(InvalidShapeError::new(shape));
395 }
396 let strides = compute_strides(&shape);
397 Ok(Tensor {
398 data,
399 shape,
400 strides,
401 })
402 }
403
404 /// Unverified constructor for interal use when we know the dimensions/data/strides are
405 /// unchanged and don't need reverification
406 pub(crate) fn direct_from(
407 data: Vec<T>,
408 shape: [(Dimension, usize); D],
409 strides: [usize; D],
410 ) -> Self {
411 Tensor {
412 data,
413 shape,
414 strides,
415 }
416 }
417
418 /// Unverified constructor for interal use when we know the dimensions/data/strides are
419 /// the same as the existing instance and don't need reverification
420 #[allow(dead_code)] // pretty sure something else will want this in the future
421 pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> Self {
422 Tensor {
423 data,
424 shape: self.shape,
425 strides: self.strides,
426 }
427 }
428}
429
430impl<T> Tensor<T, 0> {
431 /**
432 * Creates a 0 dimensional tensor from some scalar
433 */
434 pub fn from_scalar(value: T) -> Tensor<T, 0> {
435 Tensor {
436 data: vec![value],
437 shape: [],
438 strides: [],
439 }
440 }
441
442 /**
443 * Returns the sole element of the 0 dimensional tensor.
444 */
445 pub fn into_scalar(self) -> T {
446 self.data
447 .into_iter()
448 .next()
449 .expect("Tensors always have at least 1 element")
450 }
451}
452
453impl<T> Tensor<T, 0>
454where
455 T: Clone,
456{
457 /**
458 * Returns a copy of the sole element in the 0 dimensional tensor.
459 */
460 pub fn scalar(&self) -> T {
461 self.data
462 .first()
463 .expect("Tensors always have at least 1 element")
464 .clone()
465 }
466}
467
468impl<T> From<T> for Tensor<T, 0> {
469 fn from(scalar: T) -> Tensor<T, 0> {
470 Tensor::from_scalar(scalar)
471 }
472}
473// TODO: See if we can find a way to write the reverse Tensor<T, 0> -> T conversion using From or Into (doesn't seem like we can?)
474
475// # Safety
476//
477// We promise to never implement interior mutability for Tensor.
478/**
479 * A Tensor implements TensorRef.
480 */
481unsafe impl<T, const D: usize> TensorRef<T, D> for Tensor<T, D> {
482 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
483 let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
484 self.data.get(i)
485 }
486
487 fn view_shape(&self) -> [(Dimension, usize); D] {
488 Tensor::shape(self)
489 }
490
491 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
492 unsafe {
493 // The point of get_reference_unchecked is no bounds checking, and therefore
494 // it does not make any sense to just use `unwrap` here. The trait documents that
495 // it's undefind behaviour to call this method with an out of bounds index, so we
496 // can assume the None case will never happen.
497 let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
498 self.data.get_unchecked(i)
499 }
500 }
501
502 fn data_layout(&self) -> DataLayout<D> {
503 // We always have our memory in most significant to least
504 DataLayout::Linear(std::array::from_fn(|i| self.shape[i].0))
505 }
506}
507
508// # Safety
509//
510// We promise to never implement interior mutability for Tensor.
511unsafe impl<T, const D: usize> TensorMut<T, D> for Tensor<T, D> {
512 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
513 let i = get_index_direct(&indexes, &self.strides, &self.shape)?;
514 self.data.get_mut(i)
515 }
516
517 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
518 unsafe {
519 // The point of get_reference_unchecked_mut is no bounds checking, and therefore
520 // it does not make any sense to just use `unwrap` here. The trait documents that
521 // it's undefind behaviour to call this method with an out of bounds index, so we
522 // can assume the None case will never happen.
523 let i = get_index_direct(&indexes, &self.strides, &self.shape).unwrap_unchecked();
524 self.data.get_unchecked_mut(i)
525 }
526 }
527}
528
529/**
530 * Any tensor of a Cloneable type implements Clone.
531 */
532impl<T: Clone, const D: usize> Clone for Tensor<T, D> {
533 fn clone(&self) -> Self {
534 self.map(|element| element)
535 }
536}
537
538/**
539 * Any tensor of a Displayable type implements Display
540 *
541 * You can control the precision of the formatting using format arguments, i.e.
542 * `format!("{:.3}", tensor)`
543 */
544impl<T: std::fmt::Display, const D: usize> std::fmt::Display for Tensor<T, D> {
545 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
546 crate::tensors::display::format_view(self, f)
547 }
548}
549
550/**
551 * Any 2 dimensional tensor can be converted to a matrix with rows equal to the length of the
552 * first dimension in the tensor, and columns equal to the length of the second.
553 */
554impl<T> From<Tensor<T, 2>> for crate::matrices::Matrix<T> {
555 fn from(tensor: Tensor<T, 2>) -> Self {
556 crate::matrices::Matrix::from_flat_row_major(
557 (tensor.shape[0].1, tensor.shape[1].1),
558 tensor.data,
559 )
560 }
561}
562
563pub(crate) fn compute_strides<const D: usize>(shape: &[(Dimension, usize); D]) -> [usize; D] {
564 std::array::from_fn(|d| shape.iter().skip(d + 1).map(|d| d.1).product())
565}
566
567/// returns the 1 dimensional index to use to get the requested index into some tensor
568#[inline]
569pub(crate) fn get_index_direct<const D: usize>(
570 // indexes to use
571 indexes: &[usize; D],
572 // strides for indexing into the tensor
573 strides: &[usize; D],
574 // shape of the tensor to index into
575 shape: &[(Dimension, usize); D],
576) -> Option<usize> {
577 let mut index = 0;
578 for d in 0..D {
579 let n = indexes[d];
580 if n >= shape[d].1 {
581 return None;
582 }
583 index += n * strides[d];
584 }
585 Some(index)
586}
587
/// Converts a D dimensional index into the offset into the flat data of some tensor, without
/// checking that the indexes are within bounds for the shape.
///
/// - `indexes`: the index along each dimension
/// - `strides`: the strides for indexing into the tensor
#[inline]
fn get_index_direct_unchecked<const D: usize>(indexes: &[usize; D], strides: &[usize; D]) -> usize {
    indexes
        .iter()
        .zip(strides.iter())
        .fold(0, |offset, (&index, &stride)| offset + index * stride)
}
604
605impl<T, const D: usize> Tensor<T, D> {
606 pub fn view(&self) -> TensorView<T, &Tensor<T, D>, D> {
607 TensorView::from(self)
608 }
609
610 pub fn view_mut(&mut self) -> TensorView<T, &mut Tensor<T, D>, D> {
611 TensorView::from(self)
612 }
613
614 pub fn view_owned(self) -> TensorView<T, Tensor<T, D>, D> {
615 TensorView::from(self)
616 }
617
618 /**
619 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
620 * to read values from this tensor.
621 *
622 * # Panics
623 *
624 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
625 */
626 #[track_caller]
627 pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &Tensor<T, D>, D> {
628 TensorAccess::from(self, dimensions)
629 }
630
631 /**
632 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
633 * to read or write values from this tensor.
634 *
635 * # Panics
636 *
637 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
638 */
639 #[track_caller]
640 pub fn index_by_mut(
641 &mut self,
642 dimensions: [Dimension; D],
643 ) -> TensorAccess<T, &mut Tensor<T, D>, D> {
644 TensorAccess::from(self, dimensions)
645 }
646
647 /**
648 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
649 * to read or write values from this tensor.
650 *
651 * # Panics
652 *
653 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
654 */
655 #[track_caller]
656 pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, Tensor<T, D>, D> {
657 TensorAccess::from(self, dimensions)
658 }
659
660 /**
661 * Creates a TensorAccess which will index into the dimensions this Tensor was created with
662 * in the same order as they were provided. See [TensorAccess::from_source_order].
663 */
664 pub fn index(&self) -> TensorAccess<T, &Tensor<T, D>, D> {
665 TensorAccess::from_source_order(self)
666 }
667
668 /**
669 * Creates a TensorAccess which will index into the dimensions this Tensor was
670 * created with in the same order as they were provided. The TensorAccess mutably borrows
671 * the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
672 */
673 pub fn index_mut(&mut self) -> TensorAccess<T, &mut Tensor<T, D>, D> {
674 TensorAccess::from_source_order(self)
675 }
676
677 /**
678 * Creates a TensorAccess which will index into the dimensions this Tensor was
679 * created with in the same order as they were provided. The TensorAccess takes ownership
680 * of the Tensor, and can therefore mutate it. See [TensorAccess::from_source_order].
681 */
682 pub fn index_owned(self) -> TensorAccess<T, Tensor<T, D>, D> {
683 TensorAccess::from_source_order(self)
684 }
685
686 /**
687 * Returns an iterator over references to the data in this Tensor.
688 */
689 pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, Tensor<T, D>, D> {
690 TensorReferenceIterator::from(self)
691 }
692
693 /**
694 * Returns an iterator over mutable references to the data in this Tensor.
695 */
696 pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, Tensor<T, D>, D> {
697 TensorReferenceMutIterator::from(self)
698 }
699
700 /**
701 * Creates an iterator over the values in this Tensor.
702 */
703 pub fn iter_owned(self) -> TensorOwnedIterator<T, Tensor<T, D>, D>
704 where
705 T: Default,
706 {
707 TensorOwnedIterator::from(self)
708 }
709
710 // Non public index order reference iterator since we don't want to expose our implementation
711 // details to public API since then we could never change them.
712 pub(crate) fn direct_iter_reference(&self) -> std::slice::Iter<'_, T> {
713 self.data.iter()
714 }
715
716 // Non public index order reference iterator since we don't want to expose our implementation
717 // details to public API since then we could never change them.
718 pub(crate) fn direct_iter_reference_mut(&mut self) -> std::slice::IterMut<'_, T> {
719 self.data.iter_mut()
720 }
721
722 /**
723 * Renames the dimension names of the tensor without changing the lengths of the dimensions
724 * in the tensor or moving any data around.
725 *
726 * ```
727 * use easy_ml::tensors::Tensor;
728 * let mut tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6]);
729 * tensor.rename(["y", "z"]);
730 * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
731 * ```
732 *
733 * # Panics
734 *
735 * - If a dimension name is not unique
736 */
737 #[track_caller]
738 pub fn rename(&mut self, dimensions: [Dimension; D]) {
739 if crate::tensors::dimensions::has_duplicates_names(&dimensions) {
740 panic!("Dimension names must all be unique: {:?}", &dimensions);
741 }
742 #[allow(clippy::needless_range_loop)]
743 for d in 0..D {
744 self.shape[d].0 = dimensions[d];
745 }
746 }
747
748 /**
749 * Renames the dimension names of the tensor and returns it without changing the lengths
750 * of the dimensions in the tensor or moving any data around.
751 *
752 * ```
753 * use easy_ml::tensors::Tensor;
754 * let tensor = Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4, 5, 6])
755 * .rename_owned(["y", "z"]);
756 * assert_eq!([("y", 2), ("z", 3)], tensor.shape());
757 * ```
758 *
759 * # Panics
760 *
761 * - If a dimension name is not unique
762 */
763 #[track_caller]
764 pub fn rename_owned(mut self, dimensions: [Dimension; D]) -> Tensor<T, D> {
765 self.rename(dimensions);
766 self
767 }
768
769 /**
770 * Returns a TensorView with the dimension names of the shape renamed to the provided
771 * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
772 *
773 * This is a shorthand for constructing the TensorView from this Tensor.
774 *
775 * ```
776 * use easy_ml::tensors::Tensor;
777 * use easy_ml::tensors::views::{TensorView, TensorRename};
778 * let abc = Tensor::from([("a", 3), ("b", 3), ("c", 3)], (0..27).collect());
779 * let xyz = abc.rename_view(["x", "y", "z"]);
780 * let also_xyz = TensorView::from(TensorRename::from(&abc, ["x", "y", "z"]));
781 * assert_eq!(xyz, also_xyz);
782 * assert_eq!(xyz, Tensor::from([("x", 3), ("y", 3), ("z", 3)], (0..27).collect()));
783 * ```
784 *
785 * # Panics
786 *
787 * - If a dimension name is not unique
788 */
789 #[track_caller]
790 pub fn rename_view(
791 &self,
792 dimensions: [Dimension; D],
793 ) -> TensorView<T, TensorRename<T, &Tensor<T, D>, D>, D> {
794 TensorView::from(TensorRename::from(self, dimensions))
795 }
796
797 /**
798 * Changes the shape of the tensor without changing the number of dimensions or moving any
799 * data around.
800 *
801 * # Panics
802 *
803 * - If the number of provided elements in the new shape does not match the product of the
804 * dimension lengths in the existing tensor's shape.
805 * - If a dimension name is not unique
806 * - If any dimension has 0 elements
807 *
808 * ```
809 * use easy_ml::tensors::Tensor;
810 * let mut tensor = Tensor::from([("width", 2), ("height", 2)], vec![
811 * 1, 2,
812 * 3, 4
813 * ]);
814 * tensor.reshape_mut([("batch", 1), ("image", 4)]);
815 * assert_eq!(tensor, Tensor::from([("batch", 1), ("image", 4)], vec![ 1, 2, 3, 4 ]));
816 * ```
817 */
818 #[track_caller]
819 pub fn reshape_mut(&mut self, shape: [(Dimension, usize); D]) {
820 InvalidShapeError::validate_dimensions_or_panic(&shape, self.data.len());
821 let strides = compute_strides(&shape);
822 self.shape = shape;
823 self.strides = strides;
824 }
825
826 /**
827 * Consumes the tensor and changes the shape of the tensor without moving any
828 * data around. The new Tensor may also have a different number of dimensions.
829 *
830 * # Panics
831 *
832 * - If the number of provided elements in the new shape does not match the product of the
833 * dimension lengths in the existing tensor's shape.
834 * - If a dimension name is not unique
835 *
836 * ```
837 * use easy_ml::tensors::Tensor;
838 * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
839 * 1, 2,
840 * 3, 4
841 * ]);
842 * let flattened = tensor.reshape_owned([("image", 4)]);
843 * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
844 * ```
845 *
846 * See also [reshape_view_owned](Tensor::reshape_view_owned)
847 */
848 #[track_caller]
849 pub fn reshape_owned<const D2: usize>(self, shape: [(Dimension, usize); D2]) -> Tensor<T, D2> {
850 Tensor::from(shape, self.data)
851 }
852
853 /**
854 * Returns a TensorView with the dimensions changed to the provided shape without moving any
855 * data around. The new Tensor may also have a different number of dimensions.
856 *
857 * This is a shorthand for constructing the TensorView from this Tensor.
858 *
859 * # Panics
860 *
861 * - If the number of provided elements in the new shape does not match the product of the
862 * dimension lengths in the existing tensor's shape.
863 * - If a dimension name is not unique
864 *
865 * ```
866 * use easy_ml::tensors::Tensor;
867 * let tensor = Tensor::from([("width", 2), ("height", 2)], vec![
868 * 1, 2,
869 * 3, 4
870 * ]);
871 * let flattened = tensor.reshape_view([("image", 4)]);
872 * assert_eq!(flattened, Tensor::from([("image", 4)], vec![ 1, 2, 3, 4 ]));
873 * ```
874 */
875 pub fn reshape_view<const D2: usize>(
876 &self,
877 shape: [(Dimension, usize); D2],
878 ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, D2>, D2> {
879 TensorView::from(TensorReshape::from(self, shape))
880 }
881
882 /**
883 * Returns a TensorView with the dimensions changed to the provided shape without moving any
884 * data around. The new Tensor may also have a different number of dimensions.
885 *
886 * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
887 * mutably borrows this Tensor, and can therefore mutate it
888 *
889 * # Panics
890 *
891 * - If the number of provided elements in the new shape does not match the product of the
892 * dimension lengths in the existing tensor's shape.
893 * - If a dimension name is not unique
894 */
895 pub fn reshape_view_mut<const D2: usize>(
896 &mut self,
897 shape: [(Dimension, usize); D2],
898 ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, D2>, D2> {
899 TensorView::from(TensorReshape::from(self, shape))
900 }
901
902 /**
903 * Returns a TensorView with the dimensions changed to the provided shape without moving any
904 * data around. The new Tensor may also have a different number of dimensions.
905 *
906 * This is a shorthand for constructing the TensorView from this Tensor. The TensorReshape
907 * takes ownership of this Tensor, and can therefore mutate it
908 *
909 * # Panics
910 *
911 * - If the number of provided elements in the new shape does not match the product of the
912 * dimension lengths in the existing tensor's shape.
913 * - If a dimension name is not unique
914 *
915 * See also [reshape_owned](Tensor::reshape_owned)
916 */
917 pub fn reshape_view_owned<const D2: usize>(
918 self,
919 shape: [(Dimension, usize); D2],
920 ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, D2>, D2> {
921 TensorView::from(TensorReshape::from(self, shape))
922 }
923
924 /**
925 * Given the dimension name, returns a view of this tensor reshaped to one dimension
926 * with a length equal to the number of elements in this tensor.
927 */
928 pub fn flatten_view(
929 &self,
930 dimension: Dimension,
931 ) -> TensorView<T, TensorReshape<T, &Tensor<T, D>, D, 1>, 1> {
932 self.reshape_view([(dimension, dimensions::elements(&self.shape))])
933 }
934
935 /**
936 * Given the dimension name, returns a view of this tensor reshaped to one dimension
937 * with a length equal to the number of elements in this tensor.
938 */
939 pub fn flatten_view_mut(
940 &mut self,
941 dimension: Dimension,
942 ) -> TensorView<T, TensorReshape<T, &mut Tensor<T, D>, D, 1>, 1> {
943 self.reshape_view_mut([(dimension, dimensions::elements(&self.shape))])
944 }
945
946 /**
947 * Given the dimension name, returns a view of this tensor reshaped to one dimension
948 * with a length equal to the number of elements in this tensor.
949 *
950 * If you intend to query the tensor a lot after creating the view, consider
951 * using [flatten](Tensor::flatten) instead as it will have less overhead to
952 * index after creation.
953 */
954 pub fn flatten_view_owned(
955 self,
956 dimension: Dimension,
957 ) -> TensorView<T, TensorReshape<T, Tensor<T, D>, D, 1>, 1> {
958 let length = dimensions::elements(&self.shape);
959 self.reshape_view_owned([(dimension, length)])
960 }
961
962 /**
963 * Given the dimension name, returns a new tensor reshaped to one dimension
964 * with a length equal to the number of elements in this tensor.
965 */
966 pub fn flatten(self, dimension: Dimension) -> Tensor<T, 1> {
967 let length = dimensions::elements(&self.shape);
968 self.reshape_owned([(dimension, length)])
969 }
970
971 /**
972 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
973 * range from view. Error cases are documented on [TensorRange].
974 *
975 * This is a shorthand for constructing the TensorView from this Tensor.
976 *
977 * ```
978 * use easy_ml::tensors::Tensor;
979 * use easy_ml::tensors::views::{TensorView, TensorRange, IndexRange};
980 * # use easy_ml::tensors::views::IndexRangeValidationError;
981 * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
982 * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
983 * let cropped = samples.range([("x", IndexRange::new(1, 5)), ("y", IndexRange::new(1, 5))])?;
984 * let also_cropped = TensorView::from(
985 * TensorRange::from(&samples, [("x", 1..6), ("y", 1..6)])?
986 * );
987 * assert_eq!(cropped, also_cropped);
988 * assert_eq!(
989 * cropped.select([("batch", 0)]),
990 * Tensor::from([("x", 5), ("y", 5)], vec![
991 * 8, 9, 10, 11, 12,
992 * 15, 16, 17, 18, 19,
993 * 22, 23, 24, 25, 26,
994 * 29, 30, 31, 32, 33,
995 * 36, 37, 38, 39, 40
996 * ])
997 * );
998 * # Ok(())
999 * # }
1000 * ```
1001 */
1002 pub fn range<R, const P: usize>(
1003 &self,
1004 ranges: [(Dimension, R); P],
1005 ) -> Result<TensorView<T, TensorRange<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1006 where
1007 R: Into<IndexRange>,
1008 {
1009 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1010 }
1011
1012 /**
1013 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1014 * range from view. Error cases are documented on [TensorRange]. The TensorRange
1015 * mutably borrows this Tensor, and can therefore mutate it
1016 *
1017 * This is a shorthand for constructing the TensorView from this Tensor.
1018 */
1019 pub fn range_mut<R, const P: usize>(
1020 &mut self,
1021 ranges: [(Dimension, R); P],
1022 ) -> Result<
1023 TensorView<T, TensorRange<T, &mut Tensor<T, D>, D>, D>,
1024 IndexRangeValidationError<D, P>,
1025 >
1026 where
1027 R: Into<IndexRange>,
1028 {
1029 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1030 }
1031
1032 /**
1033 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
1034 * range from view. Error cases are documented on [TensorRange]. The TensorRange
1035 * takes ownership of this Tensor, and can therefore mutate it
1036 *
1037 * This is a shorthand for constructing the TensorView from this Tensor.
1038 */
1039 pub fn range_owned<R, const P: usize>(
1040 self,
1041 ranges: [(Dimension, R); P],
1042 ) -> Result<TensorView<T, TensorRange<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1043 where
1044 R: Into<IndexRange>,
1045 {
1046 TensorRange::from(self, ranges).map(|range| TensorView::from(range))
1047 }
1048
1049 /**
1050 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1051 * range from view. Error cases are documented on [TensorMask].
1052 *
1053 * This is a shorthand for constructing the TensorView from this Tensor.
1054 *
1055 * ```
1056 * use easy_ml::tensors::Tensor;
1057 * use easy_ml::tensors::views::{TensorView, TensorMask, IndexRange};
1058 * # use easy_ml::tensors::views::IndexRangeValidationError;
1059 * # fn main() -> Result<(), IndexRangeValidationError<3, 2>> {
1060 * let samples = Tensor::from([("batch", 5), ("x", 7), ("y", 7)], (0..(5 * 7 * 7)).collect());
1061 * let corners = samples.mask([("x", IndexRange::new(3, 2)), ("y", IndexRange::new(3, 2))])?;
1062 * let also_corners = TensorView::from(
1063 * TensorMask::from(&samples, [("x", 3..5), ("y", 3..5)])?
1064 * );
1065 * assert_eq!(corners, also_corners);
1066 * assert_eq!(
1067 * corners.select([("batch", 0)]),
1068 * Tensor::from([("x", 5), ("y", 5)], vec![
1069 * 0, 1, 2, 5, 6,
1070 * 7, 8, 9, 12, 13,
1071 * 14, 15, 16, 19, 20,
1072 *
1073 * 35, 36, 37, 40, 41,
1074 * 42, 43, 44, 47, 48
1075 * ])
1076 * );
1077 * # Ok(())
1078 * # }
1079 * ```
1080 */
1081 pub fn mask<R, const P: usize>(
1082 &self,
1083 masks: [(Dimension, R); P],
1084 ) -> Result<TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1085 where
1086 R: Into<IndexRange>,
1087 {
1088 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1089 }
1090
1091 /**
1092 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1093 * range from view. Error cases are documented on [TensorMask]. The TensorMask
1094 * mutably borrows this Tensor, and can therefore mutate it
1095 *
1096 * This is a shorthand for constructing the TensorView from this Tensor.
1097 */
1098 pub fn mask_mut<R, const P: usize>(
1099 &mut self,
1100 masks: [(Dimension, R); P],
1101 ) -> Result<
1102 TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D>,
1103 IndexRangeValidationError<D, P>,
1104 >
1105 where
1106 R: Into<IndexRange>,
1107 {
1108 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1109 }
1110
1111 /**
1112 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
1113 * range from view. Error cases are documented on [TensorMask]. The TensorMask
1114 * takes ownership of this Tensor, and can therefore mutate it
1115 *
1116 * This is a shorthand for constructing the TensorView from this Tensor.
1117 */
1118 pub fn mask_owned<R, const P: usize>(
1119 self,
1120 masks: [(Dimension, R); P],
1121 ) -> Result<TensorView<T, TensorMask<T, Tensor<T, D>, D>, D>, IndexRangeValidationError<D, P>>
1122 where
1123 R: Into<IndexRange>,
1124 {
1125 TensorMask::from(self, masks).map(|mask| TensorView::from(mask))
1126 }
1127
1128 /**
1129 * Returns a TensorView with a mask taken in the provided dimension, hiding
1130 * all but the start_and_end number of values at the start and end of the
1131 * dimension from view.
1132 *
1133 * This is a shorthand for constructing the TensorView from this Tensor.
1134 *
1135 * # Panics
1136 *
1137 * - If the start_and_end value is 0 - this is not a valid mask as it would
1138 * hide all elements
1139 * - If the dimension is not in the tensor's shape.
1140 *
1141 * ```
1142 * use easy_ml::tensors::Tensor;
1143 * let samples = Tensor::from([("batch", 5), ("x", 2), ("y", 2)], (0..20).collect());
1144 * let shortlist = samples.start_and_end_of("batch", 1);
1145 * assert_eq!(
1146 * shortlist,
1147 * Tensor::from([("batch", 2), ("x", 2), ("y", 2)], vec![
1148 * 0, 1,
1149 * 2, 3,
1150 *
1151 * 16, 17,
1152 * 18, 19,
1153 * ])
1154 * )
1155 * ```
1156 */
1157 #[track_caller]
1158 pub fn start_and_end_of(
1159 &self,
1160 dimension: Dimension,
1161 start_and_end: usize,
1162 ) -> TensorView<T, TensorMask<T, &Tensor<T, D>, D>, D> {
1163 TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1164 }
1165
1166 /**
1167 * Returns a TensorView with a mask taken in the provided dimension, hiding
1168 * all but the start_and_end number of values at the start and end of the
1169 * dimension from view. The TensorMask mutably borrows this Tensor, and can
1170 * therefore mutate it
1171 *
1172 * This is a shorthand for constructing the TensorView from this Tensor.
1173 *
1174 * # Panics
1175 *
1176 * - If the start_and_end value is 0 - this is not a valid mask as it would
1177 * hide all elements
1178 * - If the dimension is not in the tensor's shape.
1179 */
1180 #[track_caller]
1181 pub fn start_and_end_of_mut(
1182 &mut self,
1183 dimension: Dimension,
1184 start_and_end: usize,
1185 ) -> TensorView<T, TensorMask<T, &mut Tensor<T, D>, D>, D> {
1186 TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1187 }
1188
1189 /**
1190 * Returns a TensorView with a mask taken in the provided dimension, hiding
1191 * all but the start_and_end number of values at the start and end of the
1192 * dimension from view. The TensorMask takes ownership of this Tensor, and
1193 * can therefore mutate it
1194 *
1195 * This is a shorthand for constructing the TensorView from this Tensor.
1196 *
1197 * # Panics
1198 *
1199 * - If the start_and_end value is 0 - this is not a valid mask as it would
1200 * hide all elements
1201 * - If the dimension is not in the tensor's shape.
1202 */
1203 #[track_caller]
1204 pub fn start_and_end_of_owned(
1205 self,
1206 dimension: Dimension,
1207 start_and_end: usize,
1208 ) -> TensorView<T, TensorMask<T, Tensor<T, D>, D>, D> {
1209 TensorMask::panicking_start_and_end_of(self, dimension, start_and_end)
1210 }
1211
1212 /**
1213 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1214 * order. The data of this tensor and the dimension lengths remain unchanged.
1215 *
1216 * This is a shorthand for constructing the TensorView from this Tensor.
1217 *
1218 * ```
1219 * use easy_ml::tensors::Tensor;
1220 * use easy_ml::tensors::views::{TensorView, TensorReverse};
1221 * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
1222 * let reversed = ab.reverse(&["a"]);
1223 * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
1224 * assert_eq!(reversed, also_reversed);
1225 * assert_eq!(
1226 * reversed,
1227 * Tensor::from(
1228 * [("a", 2), ("b", 3)],
1229 * vec![
1230 * 3, 4, 5,
1231 * 0, 1, 2,
1232 * ]
1233 * )
1234 * );
1235 * ```
1236 *
1237 * # Panics
1238 *
1239 * - If a dimension name is not in the tensor's shape or is repeated.
1240 */
1241 #[track_caller]
1242 pub fn reverse(
1243 &self,
1244 dimensions: &[Dimension],
1245 ) -> TensorView<T, TensorReverse<T, &Tensor<T, D>, D>, D> {
1246 TensorView::from(TensorReverse::from(self, dimensions))
1247 }
1248
1249 /**
1250 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1251 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1252 * mutably borrows this Tensor, and can therefore mutate it
1253 *
1254 * This is a shorthand for constructing the TensorView from this Tensor.
1255 *
1256 * # Panics
1257 *
1258 * - If a dimension name is not in the tensor's shape or is repeated.
1259 */
1260 #[track_caller]
1261 pub fn reverse_mut(
1262 &mut self,
1263 dimensions: &[Dimension],
1264 ) -> TensorView<T, TensorReverse<T, &mut Tensor<T, D>, D>, D> {
1265 TensorView::from(TensorReverse::from(self, dimensions))
1266 }
1267
1268 /**
1269 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
1270 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
1271 * takes ownership of this Tensor, and can therefore mutate it
1272 *
1273 * This is a shorthand for constructing the TensorView from this Tensor.
1274 *
1275 * # Panics
1276 *
1277 * - If a dimension name is not in the tensor's shape or is repeated.
1278 */
1279 #[track_caller]
1280 pub fn reverse_owned(
1281 self,
1282 dimensions: &[Dimension],
1283 ) -> TensorView<T, TensorReverse<T, Tensor<T, D>, D>, D> {
1284 TensorView::from(TensorReverse::from(self, dimensions))
1285 }
1286
1287 /**
1288 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1289 * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
1290 * or need to clone the values anyway, you can use
1291 * [`Tensor::elementwise`](Tensor::elementwise) instead.
1292 *
1293 * # Generics
1294 *
1295 * This method can be called with any right hand side that can be converted to a TensorView,
1296 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1297 *
1298 * # Panics
1299 *
1300 * If the two tensors have different shapes.
1301 */
1302 #[track_caller]
1303 pub fn elementwise_reference<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1304 where
1305 I: Into<TensorView<T, S, D>>,
1306 S: TensorRef<T, D>,
1307 M: Fn(&T, &T) -> T,
1308 {
1309 self.elementwise_reference_less_generic(rhs.into(), mapping_function)
1310 }
1311
1312 /**
1313 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1314 * mapped by a function. The mapping function also receives each index corresponding to the
1315 * value pairs. The value pairs are not copied for you, if you're using `Copy` types
1316 * or need to clone the values anyway, you can use
1317 * [`Tensor::elementwise_with_index`](Tensor::elementwise_with_index) instead.
1318 *
1319 * # Generics
1320 *
1321 * This method can be called with any right hand side that can be converted to a TensorView,
1322 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1323 *
1324 * # Panics
1325 *
1326 * If the two tensors have different shapes.
1327 */
1328 #[track_caller]
1329 pub fn elementwise_reference_with_index<S, I, M>(
1330 &self,
1331 rhs: I,
1332 mapping_function: M,
1333 ) -> Tensor<T, D>
1334 where
1335 I: Into<TensorView<T, S, D>>,
1336 S: TensorRef<T, D>,
1337 M: Fn([usize; D], &T, &T) -> T,
1338 {
1339 self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
1340 }
1341
1342 #[track_caller]
1343 fn elementwise_reference_less_generic<S, M>(
1344 &self,
1345 rhs: TensorView<T, S, D>,
1346 mapping_function: M,
1347 ) -> Tensor<T, D>
1348 where
1349 S: TensorRef<T, D>,
1350 M: Fn(&T, &T) -> T,
1351 {
1352 let left_shape = self.shape();
1353 let right_shape = rhs.shape();
1354 if left_shape != right_shape {
1355 panic!(
1356 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1357 left_shape, right_shape
1358 );
1359 }
1360 let mapped = self
1361 .direct_iter_reference()
1362 .zip(rhs.iter_reference())
1363 .map(|(x, y)| mapping_function(x, y))
1364 .collect();
1365 // We're not changing the shape of the Tensor, so don't need to revalidate
1366 Tensor::direct_from(mapped, self.shape, self.strides)
1367 }
1368
1369 #[track_caller]
1370 fn elementwise_reference_less_generic_with_index<S, M>(
1371 &self,
1372 rhs: TensorView<T, S, D>,
1373 mapping_function: M,
1374 ) -> Tensor<T, D>
1375 where
1376 S: TensorRef<T, D>,
1377 M: Fn([usize; D], &T, &T) -> T,
1378 {
1379 let left_shape = self.shape();
1380 let right_shape = rhs.shape();
1381 if left_shape != right_shape {
1382 panic!(
1383 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
1384 left_shape, right_shape
1385 );
1386 }
1387 // we just checked both shapes were the same, so we don't need to propagate indexes
1388 // for both tensors because they'll be identical
1389 let mapped = self
1390 .direct_iter_reference()
1391 .zip(rhs.iter_reference().with_index())
1392 .map(|(x, (i, y))| mapping_function(i, x, y))
1393 .collect();
1394 // We're not changing the shape of the Tensor, so don't need to revalidate
1395 Tensor::direct_from(mapped, self.shape, self.strides)
1396 }
1397
1398 /**
1399 * Returns a TensorView which makes the order of the data in this tensor appear to be in
1400 * a different order. The order of the dimension names is unchanged, although their lengths
1401 * may swap.
1402 *
1403 * This is a shorthand for constructing the TensorView from this Tensor.
1404 *
1405 * See also: [transpose](Tensor::transpose), [TensorTranspose]
1406 *
1407 * # Panics
1408 *
1409 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1410 * order need not match.
1411 */
1412 pub fn transpose_view(
1413 &self,
1414 dimensions: [Dimension; D],
1415 ) -> TensorView<T, TensorTranspose<T, &Tensor<T, D>, D>, D> {
1416 TensorView::from(TensorTranspose::from(self, dimensions))
1417 }
1418}
1419
1420impl<T, const D: usize> Tensor<T, D>
1421where
1422 T: Clone,
1423{
1424 /**
1425 * Creates a tensor with a particular number of dimensions and length in each dimension
1426 * with all elements initialised to the provided value.
1427 *
1428 * # Panics
1429 *
1430 * - If a dimension name is not unique
1431 * - If any dimension has 0 elements
1432 */
1433 #[track_caller]
1434 pub fn empty(shape: [(Dimension, usize); D], value: T) -> Self {
1435 let elements = crate::tensors::dimensions::elements(&shape);
1436 Tensor::from(shape, vec![value; elements])
1437 }
1438
1439 /**
1440 * Gets a copy of the first value in this tensor.
1441 * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
1442 * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
1443 */
1444 pub fn first(&self) -> T {
1445 self.data
1446 .first()
1447 .expect("Tensors always have at least 1 element")
1448 .clone()
1449 }
1450
1451 /**
1452 * Returns a new Tensor which has the same data as this tensor, but with the order of data
1453 * changed. The order of the dimension names is unchanged, although their lengths may swap.
1454 *
1455 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1456 * `transpose(["y", "x"])` which would return a new tensor with a shape of
1457 * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
1458 *
1459 * This method need not shift *all* the dimensions though, you could also swap the width
1460 * and height of images in a tensor with a shape of
1461 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
1462 * which would return a new tensor where all the images have been swapped over the diagonal.
1463 *
1464 * See also: [TensorAccess], [reorder](Tensor::reorder)
1465 *
1466 * # Panics
1467 *
1468 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1469 * order need not match (and if the order does match, this function is just an expensive
1470 * clone).
1471 */
1472 #[track_caller]
1473 pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1474 let shape = self.shape;
1475 let mut reordered = self.reorder(dimensions);
1476 // Transposition is essentially reordering, but we retain the dimension name ordering
1477 // of the original order, this means we may swap dimension lengths, but the dimensions
1478 // will not change order.
1479 #[allow(clippy::needless_range_loop)]
1480 for d in 0..D {
1481 reordered.shape[d].0 = shape[d].0;
1482 }
1483 reordered
1484 }
1485
1486 /**
1487 * Modifies this tensor to have the same data as before, but with the order of data changed.
1488 * The order of the dimension names is unchanged, although their lengths may swap.
1489 *
1490 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1491 * `transpose_mut(["y", "x"])` which would edit the tensor, updating its shape to
1492 * `[("x", y), ("y", x)]`, so every (x,y) of its data corresponds to (y,x) before the
1493 * transposition.
1494 *
1495 * The order swapping will try to be in place, but this is currently only supported for
1496 * square tensors with 2 dimensions. Other types of tensors will not be transposed in place.
1497 *
1498 * # Panics
1499 *
1500 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1501 * order need not match (and if the order does match, this function is just an expensive
1502 * clone).
1503 */
1504 #[track_caller]
1505 pub fn transpose_mut(&mut self, dimensions: [Dimension; D]) {
1506 let shape = self.shape;
1507 self.reorder_mut(dimensions);
1508 // Transposition is essentially reordering, but we retain the dimension name ordering
1509 // we had before, this means we may swap dimension lengths, but the dimensions
1510 // will not change order.
1511 #[allow(clippy::needless_range_loop)]
1512 for d in 0..D {
1513 self.shape[d].0 = shape[d].0;
1514 }
1515 }
1516
1517 /**
1518 * Returns a new Tensor which has the same data as this tensor, but with the order of the
1519 * dimensions and corresponding order of data changed.
1520 *
1521 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1522 * `reorder(["y", "x"])` which would return a new tensor with a shape of
1523 * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
1524 *
1525 * This method need not shift *all* the dimensions though, you could also swap the width
1526 * and height of images in a tensor with a shape of
1527 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
1528 * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
1529 * in the original.
1530 *
1531 * See also: [TensorAccess], [transpose](Tensor::transpose)
1532 *
1533 * # Panics
1534 *
1535 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1536 * order need not match (and if the order does match, this function is just an expensive
1537 * clone).
1538 */
1539 #[track_caller]
1540 pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1541 let reorderd = match TensorAccess::try_from(&self, dimensions) {
1542 Ok(reordered) => reordered,
1543 Err(_error) => panic!(
1544 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1545 dimensions, &self.shape,
1546 ),
1547 };
1548 let reorderd_shape = reorderd.shape();
1549 Tensor::from(reorderd_shape, reorderd.iter().collect())
1550 }
1551
1552 /**
1553 * Modifies this tensor to have the same data as before, but with the order of the
1554 * dimensions and corresponding order of data changed.
1555 *
1556 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1557 * `reorder_mut(["y", "x"])` which would edit the tensor, updating its shape to
1558 * `[("y", y), ("x", x)]`, so every (y,x) of its data corresponds to (x,y) before the
1559 * transposition.
1560 *
1561 * The order swapping will try to be in place, but this is currently only supported for
1562 * square tensors with 2 dimensions. Other types of tensors will not be reordered in place.
1563 *
1564 * # Panics
1565 *
1566 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1567 * order need not match (and if the order does match, this function is just an expensive
1568 * clone).
1569 */
1570 #[track_caller]
1571 pub fn reorder_mut(&mut self, dimensions: [Dimension; D]) {
1572 use crate::tensors::dimensions::DimensionMappings;
1573 if D == 2 && crate::tensors::dimensions::is_square(&self.shape) {
1574 let dimension_mapping = match DimensionMappings::new(&self.shape, &dimensions) {
1575 Some(dimension_mapping) => dimension_mapping,
1576 None => panic!(
1577 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1578 dimensions, &self.shape,
1579 ),
1580 };
1581
1582 let shape = dimension_mapping.map_shape_to_requested(&self.shape);
1583 let shape_iterator = ShapeIterator::from(shape);
1584
1585 for index in shape_iterator {
1586 let i = index[0];
1587 let j = index[1];
1588 if j >= i {
1589 let mapped_index = dimension_mapping.map_dimensions_to_source(&index);
1590 // Swap elements from the upper triangle (using index order of the actual tensor's
1591 // shape)
1592 let temp = self.get_reference(index).unwrap().clone();
1593 // tensor[i,j] becomes tensor[mapping(i,j)]
1594 *self.get_reference_mut(index).unwrap() =
1595 self.get_reference(mapped_index).unwrap().clone();
1596 // tensor[mapping(i,j)] becomes tensor[i,j]
1597 *self.get_reference_mut(mapped_index).unwrap() = temp;
1598 // If the mapping is a noop we've assigned i,j to i,j
1599 // If the mapping is i,j -> j,i we've assigned i,j to j,i and j,i to i,j
1600 }
1601 }
1602
1603 // now update our shape and strides to match
1604 self.shape = shape;
1605 self.strides = compute_strides(&shape);
1606 } else {
1607 // fallback to allocating a new reordered tensor
1608 let reordered = self.reorder(dimensions);
1609 self.data = reordered.data;
1610 self.shape = reordered.shape;
1611 self.strides = reordered.strides;
1612 }
1613 }
1614
1615 /**
1616 * Returns an iterator over copies of the data in this Tensor.
1617 */
1618 pub fn iter(&self) -> TensorIterator<'_, T, Tensor<T, D>, D> {
1619 TensorIterator::from(self)
1620 }
1621
1622 /**
1623 * Creates and returns a new tensor with all values from the original with the
1624 * function applied to each. This can be used to change the type of the tensor
1625 * such as creating a mask:
1626 * ```
1627 * use easy_ml::tensors::Tensor;
1628 * let x = Tensor::from([("a", 2), ("b", 2)], vec![
1629 * 0.0, 1.2,
1630 * 5.8, 6.9
1631 * ]);
1632 * let y = x.map(|element| element > 2.0);
1633 * let result = Tensor::from([("a", 2), ("b", 2)], vec![
1634 * false, false,
1635 * true, true
1636 * ]);
1637 * assert_eq!(&y, &result);
1638 * ```
1639 */
1640 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
1641 let mapped = self
1642 .data
1643 .iter()
1644 .map(|x| mapping_function(x.clone()))
1645 .collect();
1646 // We're not changing the shape of the Tensor, so don't need to revalidate
1647 Tensor::direct_from(mapped, self.shape, self.strides)
1648 }
1649
1650 /**
1651 * Creates and returns a new tensor with all values from the original and
1652 * the index of each value mapped by a function.
1653 */
1654 pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1655 let mapped = self
1656 .iter()
1657 .with_index()
1658 .map(|(i, x)| mapping_function(i, x))
1659 .collect();
1660 // We're not changing the shape of the Tensor, so don't need to revalidate
1661 Tensor::direct_from(mapped, self.shape, self.strides)
1662 }
1663
1664 /**
1665 * Applies a function to all values in the tensor, modifying
1666 * the tensor in place.
1667 */
1668 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1669 for value in self.data.iter_mut() {
1670 *value = mapping_function(value.clone());
1671 }
1672 }
1673
1674 /**
1675 * Applies a function to all values and each value's index in the tensor, modifying
1676 * the tensor in place.
1677 */
1678 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1679 self.iter_reference_mut()
1680 .with_index()
1681 .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1682 }
1683
1684 /**
1685 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1686 * mapped by a function.
1687 *
1688 * ```
1689 * use easy_ml::tensors::Tensor;
1690 * let lhs = Tensor::from([("a", 4)], vec![1, 2, 3, 4]);
1691 * let rhs = Tensor::from([("a", 4)], vec![0, 1, 2, 3]);
1692 * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1693 * assert_eq!(
1694 * multiplied,
1695 * Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1696 * );
1697 * ```
1698 *
1699 * # Generics
1700 *
1701 * This method can be called with any right hand side that can be converted to a TensorView,
1702 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1703 *
1704 * # Panics
1705 *
1706 * If the two tensors have different shapes.
1707 */
1708 #[track_caller]
1709 pub fn elementwise<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1710 where
1711 I: Into<TensorView<T, S, D>>,
1712 S: TensorRef<T, D>,
1713 M: Fn(T, T) -> T,
1714 {
1715 self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1716 mapping_function(lhs.clone(), rhs.clone())
1717 })
1718 }
1719
1720 /**
1721 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1722 * mapped by a function. The mapping function also receives each index corresponding to the
1723 * value pairs.
1724 *
1725 * # Generics
1726 *
1727 * This method can be called with any right hand side that can be converted to a TensorView,
1728 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1729 *
1730 * # Panics
1731 *
1732 * If the two tensors have different shapes.
1733 */
1734 #[track_caller]
1735 pub fn elementwise_with_index<S, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1736 where
1737 I: Into<TensorView<T, S, D>>,
1738 S: TensorRef<T, D>,
1739 M: Fn([usize; D], T, T) -> T,
1740 {
1741 self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1742 mapping_function(i, lhs.clone(), rhs.clone())
1743 })
1744 }
1745}
1746
1747impl<T> Tensor<T, 1>
1748where
1749 T: Numeric,
1750 for<'a> &'a T: NumericRef<T>,
1751{
1752 /**
1753 * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1754 * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1755 *
1756 * ```
1757 * use easy_ml::tensors::Tensor;
1758 * let tensor = Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]);
1759 * assert_eq!(tensor.scalar_product(&tensor), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1760 * ```
1761 *
1762 * # Generics
1763 *
1764 * This method can be called with any right hand side that can be converted to a TensorView,
1765 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1766 *
1767 * # Panics
1768 *
1769 * If the two vectors are not of equal length or their dimension names do not match.
1770 */
1771 // Would like this impl block to be in operations.rs too but then it would show first in the
1772 // Tensor docs which isn't ideal
1773 pub fn scalar_product<S, I>(&self, rhs: I) -> T
1774 where
1775 I: Into<TensorView<T, S, 1>>,
1776 S: TensorRef<T, 1>,
1777 {
1778 self.scalar_product_less_generic(rhs.into())
1779 }
1780}
1781
1782impl<T> Tensor<T, 2>
1783where
1784 T: Numeric,
1785 for<'a> &'a T: NumericRef<T>,
1786{
1787 /**
1788 * Returns the determinant of this square matrix, or None if the matrix
1789 * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
1790 */
1791 pub fn determinant(&self) -> Option<T> {
1792 linear_algebra::determinant_tensor::<T, _, _>(self)
1793 }
1794
1795 /**
1796 * Computes the inverse of a matrix provided that it exists. To have an inverse a
1797 * matrix must be square (same number of rows and columns) and it must also have a
1798 * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
1799 */
1800 pub fn inverse(&self) -> Option<Tensor<T, 2>> {
1801 linear_algebra::inverse_tensor::<T, _, _>(self)
1802 }
1803
1804 /**
1805 * Computes the covariance matrix for this feature matrix along the specified feature
1806 * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
1807 */
1808 pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
1809 linear_algebra::covariance::<T, _, _>(self, feature_dimension)
1810 }
1811}
1812
1813// FIXME: want this to be callable in the main numeric impl block
1814impl<T> Tensor<T, 2>
1815where
1816 T: Numeric,
1817{
1818 /**
1819 * Creates a diagonal matrix of the provided size with the diagonal elements
1820 * set to the provided value and all other elements in the tensor set to 0.
1821 * A diagonal matrix is always square.
1822 *
1823 * The size is still taken as a shape to facilitate creating a diagonal matrix
1824 * from the dimensionality of an existing one. If the provided value is 1 then
1825 * this will create an identity matrix.
1826 *
1827 * A 3 x 3 identity matrix:
1828 * ```ignore
1829 * [
1830 * 1, 0, 0
1831 * 0, 1, 0
1832 * 0, 0, 1
1833 * ]
1834 * ```
1835 *
1836 * # Panics
1837 *
1838 * - If the shape is not square.
1839 * - If a dimension name is not unique
1840 * - If any dimension has 0 elements
1841 */
1842 #[track_caller]
1843 pub fn diagonal(shape: [(Dimension, usize); 2], value: T) -> Tensor<T, 2> {
1844 if !crate::tensors::dimensions::is_square(&shape) {
1845 panic!("Shape must be square: {:?}", shape);
1846 }
1847 let mut tensor = Tensor::empty(shape, T::zero());
1848 for ([r, c], x) in tensor.iter_reference_mut().with_index() {
1849 if r == c {
1850 *x = value.clone();
1851 }
1852 }
1853 tensor
1854 }
1855}
1856
1857impl<T> Tensor<T, 2> {
1858 /**
1859 * Converts this 2 dimensional Tensor into a Matrix.
1860 *
1861 * This is a wrapper around the `From<Tensor<T, 2>>` implementation.
1862 *
1863 * The Matrix will have the data in the same order, with rows equal to the length of
1864 * the first dimension in the tensor, and columns equal to the length of the second.
1865 */
1866 pub fn into_matrix(self) -> crate::matrices::Matrix<T> {
1867 self.into()
1868 }
1869}
1870
1871/**
1872 * Methods for tensors with numerical real valued types, such as f32 or f64.
1873 *
1874 * This excludes signed and unsigned integers as they do not support decimal
1875 * precision and hence can't be used for operations like square roots.
1876 *
1877 * Third party fixed precision and infinite precision decimal types should
1878 * be able to implement all of the methods for [Real] and then utilise these functions.
1879 */
1880impl<T: Real> Tensor<T, 1>
1881where
1882 for<'a> &'a T: RealRef<T>,
1883{
1884 /**
1885 * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1886 * of this vector, also referred to as the length or magnitude,
1887 * and written as ||x||, or sometimes |x|.
1888 *
1889 * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1890 *
1891 * This is a shorthand for `(x.iter().map(|x| x * x).sum().sqrt()`, ie
1892 * the square root of the dot product of a vector with itself.
1893 *
1894 * The euclidean length can be used to compute a
1895 * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1896 * vector with length of 1. This should not be confused with a unit matrix,
1897 * which is another name for an identity matrix.
1898 *
1899 * ```
1900 * use easy_ml::tensors::Tensor;
1901 * let a = Tensor::from([("data", 3)], vec![ 1.0, 2.0, 3.0 ]);
1902 * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1903 * let unit = a.map(|x| x / length);
1904 * assert_eq!(unit.euclidean_length(), 1.0);
1905 * ```
1906 */
1907 // TODO: Scalar ops for tensors
1908 pub fn euclidean_length(&self) -> T {
1909 self.direct_iter_reference()
1910 .map(|x| x * x)
1911 .sum::<T>()
1912 .sqrt()
1913 }
1914}
1915
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::tensors::{Dimension, InvalidShapeError, Tensor};
    use serde::Deserialize;
    use std::convert::TryFrom;

    /**
     * Deserialized data for a Tensor. Provide `&'static str` dimension names to
     * convert it into a Tensor.
     *
     * The dimension names are borrowed from the input, so this struct only works
     * with serde libraries that can deserialize to borrowed types. With such a
     * library, deserializing from a static input produces a TensorDeserialize with
     * a static lifetime. That value can be converted to a Tensor through the
     * [TryFrom](TryFrom) implementation, with no need to supply dimension names via
     * [into_tensor](TensorDeserialize::into_tensor).
     */
    #[derive(Deserialize, Debug)]
    #[serde(rename = "Tensor")]
    pub struct TensorDeserialize<'a, T, const D: usize> {
        data: Vec<T>,
        #[serde(with = "serde_arrays")]
        #[serde(borrow)]
        shape: [(&'a str, usize); D],
    }

    /**
     * Deserialized data for a Tensor. Provide `&'static str` dimension names to
     * convert it into a Tensor.
     *
     * The dimension names are owned `String`s, so this struct can be parsed by serde
     * libraries that cannot deserialize to borrowed types. A `String` cannot become a
     * `&'static str` without leaking, so there is no TryFrom implementation to convert
     * to a Tensor. New static dimension names must always be supplied through
     * [into_tensor](TensorDeserializeOwned::into_tensor).
     */
    #[derive(Deserialize, Debug)]
    #[serde(rename = "Tensor")]
    pub struct TensorDeserializeOwned<T, const D: usize> {
        data: Vec<T>,
        #[serde(with = "serde_arrays")]
        shape: [(String, usize); D],
    }

    /// Shared by both `into_tensor` methods: keeps each serialised dimension length, names
    /// it with the matching caller provided static name, and builds a validated Tensor.
    fn rename_and_validate<T, N, const D: usize>(
        data: Vec<T>,
        serialized_shape: &[(N, usize); D],
        dimensions: [Dimension; D],
    ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
        let shape = std::array::from_fn(|d| (dimensions[d], serialized_shape[d].1));
        // Build through the normal validating constructor so that invalid serialized data
        // can never become a Tensor and break the invariants the rest of the code relies on.
        // Strides are never serialised, which reduces the ways a serialised representation
        // can be invalid, at the cost of slightly more work when deserialising.
        Tensor::try_from(shape, data)
    }

    impl<'a, T, const D: usize> TensorDeserialize<'a, T, D> {
        /**
         * Converts this deserialised Tensor data into a Tensor. The provided
         * `&'static str` dimension names replace the serialised ones, which would not
         * necessarily live long enough.
         */
        pub fn into_tensor(
            self,
            dimensions: [Dimension; D],
        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
            let TensorDeserialize { data, shape } = self;
            rename_and_validate(data, &shape, dimensions)
        }
    }

    impl<T, const D: usize> TensorDeserializeOwned<T, D> {
        /**
         * Converts this deserialised Tensor data into a Tensor. The provided
         * `&'static str` dimension names replace the serialised ones, which would not
         * live long enough.
         */
        pub fn into_tensor(
            self,
            dimensions: [Dimension; D],
        ) -> Result<Tensor<T, D>, InvalidShapeError<D>> {
            let TensorDeserializeOwned { data, shape } = self;
            rename_and_validate(data, &shape, dimensions)
        }
    }

    /**
     * Converts deserialised Tensor data whose dimension names have a static lifetime
     * into a Tensor, using the serialised dimension names directly.
     */
    impl<T, const D: usize> TryFrom<TensorDeserialize<'static, T, D>> for Tensor<T, D> {
        type Error = InvalidShapeError<D>;

        fn try_from(value: TensorDeserialize<'static, T, D>) -> Result<Self, Self::Error> {
            let TensorDeserialize { data, shape } = value;
            Tensor::try_from(shape, data)
        }
    }
}
2016
#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    // Compile time check only: `requires_serialize` can only be instantiated for types
    // that implement Serialize, so this test passes as soon as it builds.
    fn requires_serialize<S>()
    where
        S: Serialize,
    {
    }
    requires_serialize::<Tensor<f64, 0>>();
    requires_serialize::<Tensor<f64, 1>>();
    requires_serialize::<Tensor<f64, 2>>();
    requires_serialize::<Tensor<f64, 3>>();
}
2026
#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    // Compile time check only: instantiating this function proves the type implements
    // Deserialize.
    fn assert_deserialize<'de, T: Deserialize<'de>>() {}
    assert_deserialize::<TensorDeserialize<f64, 3>>();
    assert_deserialize::<TensorDeserialize<f64, 2>>();
    assert_deserialize::<TensorDeserialize<f64, 1>>();
    assert_deserialize::<TensorDeserialize<f64, 0>>();
    // The owned variant is the one that serde formats unable to deserialize borrowed data
    // must use, so check it covers the same range of dimensions.
    assert_deserialize::<TensorDeserializeOwned<f64, 3>>();
    assert_deserialize::<TensorDeserializeOwned<f64, 2>>();
    assert_deserialize::<TensorDeserializeOwned<f64, 1>>();
    assert_deserialize::<TensorDeserializeOwned<f64, 0>>();
}
2037
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop_toml() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            1, 2, 3, 4,
            5, 6, 7, 8,
            9, 10, 11, 12
        ],
    );
    let encoded = toml::to_string(&tensor).unwrap();
    assert_eq!(
        encoded,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 3], ["columns", 4]]
"#,
    );
    // Unwrapping directly, rather than asserting `is_ok()` first, makes a failure report
    // the underlying parse or validation error instead of just "assertion failed".
    let parsed: TensorDeserializeOwned<i32, 2> =
        toml::from_str(&encoded).expect("serialised tensor should deserialise from toml");
    let result = parsed
        .into_tensor(["rows", "columns"])
        .expect("round tripped shape should be valid");
    assert_eq!(result, tensor);
}
2063
#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop_json() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            1, 2, 3, 4,
            5, 6, 7, 8,
            9, 10, 11, 12
        ],
    );
    let encoded = serde_json::ser::to_string(&tensor).unwrap();
    assert_eq!(
        encoded,
        r#"{"data":[1,2,3,4,5,6,7,8,9,10,11,12],"shape":[["rows",3],["columns",4]]}"#,
    );
    // Unwrapping directly, rather than asserting `is_ok()` first, makes a failure report
    // the underlying parse or validation error instead of just "assertion failed".
    let parsed: TensorDeserialize<i32, 2> =
        serde_json::de::from_str(&encoded).expect("serialised tensor should deserialise from json");
    let result = parsed
        .into_tensor(["rows", "columns"])
        .expect("round tripped shape should be valid");
    assert_eq!(result, tensor);
}
2087
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    // The shape claims 4 x 4 = 16 elements but only 12 are present. The toml itself is well
    // formed, so parsing must succeed and only the conversion to a Tensor should fail.
    let parsed: TensorDeserializeOwned<i32, 2> = toml::from_str(
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
shape = [["rows", 4], ["columns", 4]]
"#,
    )
    .expect("structurally valid toml should deserialise");
    let result = parsed.into_tensor(["rows", "columns"]);
    assert!(
        result.is_err(),
        "a 4 x 4 shape with 12 elements should be rejected, got {:?}",
        result
    );
}
2100
// Static JSON input for `test_deserialization_static_data`. This string has a `'static`
// lifetime, so a `TensorDeserialize` parsed from it can borrow its dimension names for
// `'static` and be converted with `TryFrom` without supplying new dimension names.
#[cfg(feature = "serde")]
#[cfg(test)]
const TENSOR_DATA: &'static str = r#"
{
    "data": [12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
    "shape": [["rows", 3], ["columns", 4]]
}
"#;
2109
#[cfg(feature = "serde")]
#[test]
fn test_deserialization_static_data() {
    #[rustfmt::skip]
    let tensor = Tensor::from(
        [("rows", 3), ("columns", 4)],
        vec![
            12, 11, 10, 9,
            8, 7, 6, 5,
            4, 3, 2, 1,
        ],
    );
    // To test TensorDeserialize we can't use toml because later
    // versions don't support parsing borrowed data, so we use
    // serde_json instead
    // Unwrapping directly, rather than asserting `is_ok()` first, makes a failure report
    // the underlying error instead of just "assertion failed".
    let parsed: TensorDeserialize<i32, 2> =
        serde_json::de::from_str(TENSOR_DATA).expect("static json should deserialise");
    let result: Tensor<i32, 2> = parsed
        .try_into()
        .expect("static shape should be valid");
    assert_eq!(result, tensor);
}
2131
2132macro_rules! tensor_select_impl {
2133 (impl Tensor $d:literal 1) => {
2134 impl<T> Tensor<T, $d> {
2135 /**
2136 * Selects the provided dimension name and index pairs in this Tensor, returning a
2137 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2138 * always indexed as the provided values.
2139 *
2140 * This is a shorthand for manually constructing the TensorView and
2141 * [TensorIndex]
2142 *
2143 * Note: due to limitations in Rust's const generics support, this method is only
2144 * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
2145 * back to manual construction to create `TensorIndex`es with multiple provided
2146 * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
2147 */
2148 #[track_caller]
2149 pub fn select(
2150 &self,
2151 provided_indexes: [(Dimension, usize); 1],
2152 ) -> TensorView<T, TensorIndex<T, &Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2153 TensorView::from(TensorIndex::from(self, provided_indexes))
2154 }
2155
2156 /**
2157 * Selects the provided dimension name and index pairs in this Tensor, returning a
2158 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2159 * always indexed as the provided values. The TensorIndex mutably borrows this
2160 * Tensor, and can therefore mutate it
2161 *
2162 * See [select](Tensor::select)
2163 */
2164 #[track_caller]
2165 pub fn select_mut(
2166 &mut self,
2167 provided_indexes: [(Dimension, usize); 1],
2168 ) -> TensorView<T, TensorIndex<T, &mut Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2169 TensorView::from(TensorIndex::from(self, provided_indexes))
2170 }
2171
2172 /**
2173 * Selects the provided dimension name and index pairs in this Tensor, returning a
2174 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
2175 * always indexed as the provided values. The TensorIndex takes ownership of this
2176 * Tensor, and can therefore mutate it
2177 *
2178 * See [select](Tensor::select)
2179 */
2180 #[track_caller]
2181 pub fn select_owned(
2182 self,
2183 provided_indexes: [(Dimension, usize); 1],
2184 ) -> TensorView<T, TensorIndex<T, Tensor<T, $d>, $d, 1>, { $d - 1 }> {
2185 TensorView::from(TensorIndex::from(self, provided_indexes))
2186 }
2187 }
2188 };
2189}
2190
2191tensor_select_impl!(impl Tensor 6 1);
2192tensor_select_impl!(impl Tensor 5 1);
2193tensor_select_impl!(impl Tensor 4 1);
2194tensor_select_impl!(impl Tensor 3 1);
2195tensor_select_impl!(impl Tensor 2 1);
2196tensor_select_impl!(impl Tensor 1 1);
2197
2198macro_rules! tensor_expand_impl {
2199 (impl Tensor $d:literal 1) => {
2200 impl<T> Tensor<T, $d> {
2201 /**
2202 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2203 * a particular position within the shape, returning a TensorView which has more
2204 * dimensions than this Tensor.
2205 *
2206 * This is a shorthand for manually constructing the TensorView and
2207 * [TensorExpansion]
2208 *
2209 * Note: due to limitations in Rust's const generics support, this method is only
2210 * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
2211 * fall back to manual construction to create `TensorExpansion`s with multiple provided
2212 * indexes if you need to increase dimensionality by more than 1 dimension at a time.
2213 */
2214 #[track_caller]
2215 pub fn expand(
2216 &self,
2217 extra_dimension_names: [(usize, Dimension); 1],
2218 ) -> TensorView<T, TensorExpansion<T, &Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2219 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2220 }
2221
2222 /**
2223 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2224 * a particular position within the shape, returning a TensorView which has more
2225 * dimensions than this Tensor. The TensorIndex mutably borrows this
2226 * Tensor, and can therefore mutate it
2227 *
2228 * See [expand](Tensor::expand)
2229 */
2230 #[track_caller]
2231 pub fn expand_mut(
2232 &mut self,
2233 extra_dimension_names: [(usize, Dimension); 1],
2234 ) -> TensorView<T, TensorExpansion<T, &mut Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2235 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2236 }
2237
2238 /**
2239 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
2240 * a particular position within the shape, returning a TensorView which has more
2241 * dimensions than this Tensor. The TensorIndex takes ownership of this
2242 * Tensor, and can therefore mutate it
2243 *
2244 * See [expand](Tensor::expand)
2245 */
2246 #[track_caller]
2247 pub fn expand_owned(
2248 self,
2249 extra_dimension_names: [(usize, Dimension); 1],
2250 ) -> TensorView<T, TensorExpansion<T, Tensor<T, $d>, $d, 1>, { $d + 1 }> {
2251 TensorView::from(TensorExpansion::from(self, extra_dimension_names))
2252 }
2253 }
2254 };
2255}
2256
2257tensor_expand_impl!(impl Tensor 0 1);
2258tensor_expand_impl!(impl Tensor 1 1);
2259tensor_expand_impl!(impl Tensor 2 1);
2260tensor_expand_impl!(impl Tensor 3 1);
2261tensor_expand_impl!(impl Tensor 4 1);
2262tensor_expand_impl!(impl Tensor 5 1);