easy_ml/tensors/views.rs
1/*!
2 * Generic views into a tensor.
3 *
4 * The concept of a view into a tensor is built from the low level [TensorRef] and
5 * [TensorMut] traits which define having read and read/write access to Tensor data
6 * respectively, and the high level API implemented on the [TensorView] struct.
7 *
8 * Since a Tensor is itself a TensorRef, the APIs for the traits are a little verbose to
9 * avoid name clashes with methods defined on the Tensor and TensorView types. You should
10 * typically use TensorRef and TensorMut implementations via the TensorView struct which provides
11 * an API closely resembling Tensor.
12 */
13
14use std::marker::PhantomData;
15
16use crate::linear_algebra;
17use crate::numeric::{Numeric, NumericRef};
18use crate::tensors::dimensions;
19use crate::tensors::indexing::{
20 TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
21 TensorReferenceMutIterator, TensorTranspose,
22};
23use crate::tensors::{Dimension, Tensor};
24
25mod indexes;
26mod map;
27mod ranges;
28mod renamed;
29mod reshape;
30mod reverse;
31pub mod traits;
32mod zip;
33
34pub use indexes::*;
35pub(crate) use map::*;
36pub use ranges::*;
37pub use renamed::*;
38pub use reshape::*;
39pub use reverse::*;
40pub use zip::*;
41
/**
* A shared/immutable reference to a tensor (or a portion of it) of some type and number of
* dimensions.
*
* # Indexing
*
* A TensorRef has a shape of type `[(Dimension, usize); D]`. This defines the valid indexes along
* each dimension name and length pair from 0 inclusive to the length exclusive. If the shape was
* `[("r", 2), ("c", 2)]` the indexes used would be `[0,0]`, `[0,1]`, `[1,0]` and `[1,1]`.
* Although the dimension name in each pair is used for many high level APIs, for TensorRef the
* order of dimensions is used, and the indexes (`[usize; D]`) these trait methods are called with
* must be in the same order as the shape. In general some `[("a", a), ("b", b), ("c", c)...]`
* shape is indexed from `[0,0,0...]` through to `[a - 1, b - 1, c - 1...]`, regardless of how the
* data is actually laid out in memory.
*
* # Safety
*
* In order to support returning references without bounds checking in a useful way, the
* implementing type is required to uphold several invariants that cannot be checked by
* the compiler.
*
* 1 - Any valid index as described in Indexing will yield a safe reference when calling
* `get_reference_unchecked` and `get_reference_unchecked_mut`.
*
* 2 - The view shape that defines which indexes are valid may not be changed by a shared reference
* to the TensorRef implementation. ie, the tensor may not be resized while a mutable reference is
* held to it, except by that reference.
*
* 3 - All dimension names in the `view_shape` must be unique.
*
* 4 - All dimension lengths in the `view_shape` must be non zero.
*
* 5 - `data_layout` must return values correctly as documented on [`DataLayout`].
*
* Essentially, interior mutability causes problems, since code looping through the range of valid
* indexes in a TensorRef needs to be able to rely on that range of valid indexes not changing.
* This is trivially the case by default since a [Tensor] does not have any form of
* interior mutability, and therefore an iterator holding a shared reference to a Tensor prevents
* that tensor being resized. However, a type *wrongly* implementing TensorRef could introduce
* interior mutability by putting the Tensor in an `Arc<Mutex<>>` which would allow another thread
* to resize a tensor while an iterator was looping through previously valid indexes on a different
* thread. This is the same contract as
* [`NoInteriorMutability`](crate::matrices::views::NoInteriorMutability) used in the matrix APIs.
*
* Note that it is okay to be able to resize any TensorRef implementation if that always requires
* an exclusive reference to the TensorRef/Tensor, since the exclusivity prevents the above
* scenario.
*/
pub unsafe trait TensorRef<T, const D: usize> {
    /**
     * Gets a reference to the value at the index if the index is in range. Otherwise returns None.
     */
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T>;

    /**
     * The shape this tensor has. See [dimensions] for an overview.
     * The product of the lengths in the pairs define how many elements are in the tensor
     * (or the portion of it that is visible).
     */
    fn view_shape(&self) -> [(Dimension, usize); D];

    /**
     * Gets a reference to the value at the index without doing any bounds checking. For a safe
     * alternative see [get_reference](TensorRef::get_reference).
     *
     * # Safety
     *
     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
     * resulting reference is not used. Valid indexes are defined as in [TensorRef].
     *
     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
     * [TensorRef]: TensorRef
     */
    #[allow(clippy::missing_safety_doc)] // it's not missing
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T;

    /**
     * The way the data in this tensor is laid out in memory. In particular,
     * [`Linear`](DataLayout) has several requirements on what is returned that must be upheld
     * by implementations of this trait.
     *
     * For a [Tensor] this would return `DataLayout::Linear` in the same order as the `view_shape`,
     * since the data is in a single line and the `view_shape` is in most significant dimension
     * to least. Many views however, create shapes that do not correspond to linear memory, either
     * by combining non array data with tensors, or hiding dimensions. Similarly, a row major
     * Tensor might be transposed to a column major TensorView, so the view shape could be reversed
     * compared to the order of significance of the dimensions in memory.
     *
     * In general, an array of dimension names matching the view shape order is big endian, and
     * will be iterated through efficiently by [TensorIterator], and an array of dimension names
     * in reverse of the view shape is in little endian order.
     *
     * The implementation of this trait must ensure that if it returns `DataLayout::Linear`
     * the set of dimension names are returned in the order of most significant to least and match
     * the set of the dimension names returned by `view_shape`.
     */
    fn data_layout(&self) -> DataLayout<D>;
}
140
/**
 * How the data in the tensor is laid out in memory.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DataLayout<const D: usize> {
    /**
     * The data is laid out in linear storage in memory, such that we could take a slice over the
     * entire data specified by our `view_shape`.
     *
     * The `D` length array specifies the dimensions in the `view_shape` in the order of most
     * significant dimension (in memory) to least.
     *
     * In general, an array of dimension names in the same order as the view shape is big endian
     * order (implying the order of dimensions in the view shape is most significant to least),
     * and will be iterated through efficiently by [TensorIterator], and an array of
     * dimension names in reverse of the view shape is in little endian order
     * (implying the order of dimensions in the view shape is least significant to most).
     *
     * In memory, the data will have some order such that if we want to repeatedly take 1 step
     * through memory from the first value to the last there will be a most significant dimension
     * that always goes up, through to a least significant dimension with the most rapidly varying
     * index.
     *
     * In most of Easy ML's Tensors, the `view_shape` dimensions would already be in the order of
     * most significant dimension to least (since [Tensor] stores its data in big endian order),
     * so the list of dimension names will just match the order of the view shape.
     *
     * For example, a tensor with a view shape of `[("batch", 2), ("row", 2), ("column", 3)]` that
     * stores its data in most significant to least would be (indexed) like:
     * ```ignore
     * [
     *   (0,0,0), (0,0,1), (0,0,2),
     *   (0,1,0), (0,1,1), (0,1,2),
     *   (1,0,0), (1,0,1), (1,0,2),
     *   (1,1,0), (1,1,1), (1,1,2)
     * ]
     * ```
     *
     * To take one step in memory, we would increment the right most dimension index ("column"),
     * counting our way up through to the left most dimension index ("batch"). If we changed
     * this tensor to `[("column", 3), ("row", 2), ("batch", 2)]` so that the `view_shape` was
     * swapped to least significant dimension to most but the data remained in the same order,
     * our tensor would still have a DataLayout of `Linear(["batch", "row", "column"])`, since
     * the indexes in the transposed `view_shape` correspond to an actual memory layout that's
     * completely reversed. Alternatively, you could say we reversed the view shape but the
     * memory layout never changed:
     * ```ignore
     * [
     *   (0,0,0), (1,0,0), (2,0,0),
     *   (0,1,0), (1,1,0), (2,1,0),
     *   (0,0,1), (1,0,1), (2,0,1),
     *   (0,1,1), (1,1,1), (2,1,1)
     * ]
     * ```
     *
     * To take one step in memory, we now need to increment the left most dimension index (on
     * the view shape) ("column"), counting our way in reverse to the right most dimension
     * index ("batch").
     *
     * That `["batch", "row", "column"]` is also exactly the order you would need to swap your
     * dimensions on the `view_shape` to get back to most significant to least. A [TensorAccess]
     * could reorder the tensor by this array order to get back to most significant to least
     * ordering on the `view_shape` in order to iterate through the data efficiently.
     */
    Linear([Dimension; D]),
    /**
     * The data is not laid out in linear storage in memory.
     */
    NonLinear,
    /**
     * The data is not laid out in a linear or non linear way, or we don't know how it's laid
     * out.
     */
    Other,
}
216
/**
 * A unique/mutable reference to a tensor (or a portion of it) of some type.
 *
 * The methods mirror [TensorRef::get_reference] and
 * [TensorRef::get_reference_unchecked], but return mutable references.
 *
 * # Safety
 *
 * See [TensorRef].
 */
pub unsafe trait TensorMut<T, const D: usize>: TensorRef<T, D> {
    /**
     * Gets a mutable reference to the value at the index, if the index is in range. Otherwise
     * returns None.
     */
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>;

    /**
     * Gets a mutable reference to the value at the index without doing any bounds checking.
     * For a safe alternative see [get_reference_mut](TensorMut::get_reference_mut).
     *
     * # Safety
     *
     * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
     * resulting reference is not used. Valid indexes are defined as in [TensorRef].
     *
     * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
     * [TensorRef]: TensorRef
     */
    #[allow(clippy::missing_safety_doc)] // it's not missing
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T;
}
246
/**
 * A view into some or all of a tensor.
 *
 * A TensorView has a similar relationship to a [`Tensor`] as a `&str` has to a `String`, or an
 * array slice to an array. A TensorView cannot resize its source, and may span only a portion
 * of the source Tensor in each dimension.
 *
 * However a TensorView is generic not only over the type of the data in the Tensor,
 * but also over the way the Tensor is 'sliced' and the two are orthogonal to each other.
 *
 * TensorView closely mirrors the API of Tensor.
 * Methods that create a new tensor do not return a TensorView, they return a Tensor.
 */
#[derive(Clone)]
pub struct TensorView<T, S, const D: usize> {
    // The TensorRef/TensorMut implementation this view delegates all access to.
    source: S,
    // Ties the element type T to this view; no values of T are stored directly.
    _type: PhantomData<T>,
}
265
266/**
267 * TensorView methods which require only read access via a [TensorRef] source.
268 */
269impl<T, S, const D: usize> TensorView<T, S, D>
270where
271 S: TensorRef<T, D>,
272{
273 /**
274 * Creates a TensorView from a source of some type.
275 *
276 * The lifetime of the source determines the lifetime of the TensorView created. If the
277 * TensorView is created from a reference to a Tensor, then the TensorView cannot live
278 * longer than the Tensor referenced.
279 */
280 pub fn from(source: S) -> TensorView<T, S, D> {
281 TensorView {
282 source,
283 _type: PhantomData,
284 }
285 }
286
287 /**
288 * Consumes the tensor view, yielding the source it was created from.
289 */
290 pub fn source(self) -> S {
291 self.source
292 }
293
294 /**
295 * Gives a reference to the tensor view's source.
296 */
297 pub fn source_ref(&self) -> &S {
298 &self.source
299 }
300
301 /**
302 * Gives a mutable reference to the tensor view's source.
303 */
304 pub fn source_ref_mut(&mut self) -> &mut S {
305 &mut self.source
306 }
307
308 /**
309 * The shape of this tensor view. Since Tensors are named Tensors, their shape is not just a
310 * list of length along each dimension, but instead a list of pairs of names and lengths.
311 *
312 * Note that a TensorView may have a shape which is different than the Tensor it is providing
313 * access to the data of. The TensorView might be [masking dimensions](TensorIndex) or
314 * elements from the shape, or exposing [false ones](TensorExpansion).
315 *
316 * See also
317 * - [dimensions]
318 * - [indexing](crate::tensors::indexing)
319 */
320 pub fn shape(&self) -> [(Dimension, usize); D] {
321 self.source.view_shape()
322 }
323
324 /**
325 * Returns the length of the dimension name provided, if one is present in the tensor view.
326 *
327 * See also
328 * - [dimensions]
329 * - [indexing](crate::tensors::indexing)
330 */
331 pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
332 dimensions::length_of(&self.source.view_shape(), dimension)
333 }
334
335 /**
336 * Returns the last index of the dimension name provided, if one is present in the tensor view.
337 *
338 * This is always 1 less than the length, the 'index' in this sense is based on what the
339 * Tensor's shape is, not any implementation index.
340 *
341 * See also
342 * - [dimensions]
343 * - [indexing](crate::tensors::indexing)
344 */
345 pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
346 dimensions::last_index_of(&self.source.view_shape(), dimension)
347 }
348
349 /**
350 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
351 * to read values from this tensor view.
352 *
353 * # Panics
354 *
355 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
356 */
357 #[track_caller]
358 pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &S, D> {
359 TensorAccess::from(&self.source, dimensions)
360 }
361
362 /**
363 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
364 * to read or write values from this tensor view.
365 *
366 * # Panics
367 *
368 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
369 */
370 #[track_caller]
371 pub fn index_by_mut(&mut self, dimensions: [Dimension; D]) -> TensorAccess<T, &mut S, D> {
372 TensorAccess::from(&mut self.source, dimensions)
373 }
374
375 /**
376 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
377 * to read or write values from this tensor view.
378 *
379 * # Panics
380 *
381 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
382 */
383 #[track_caller]
384 pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
385 TensorAccess::from(self.source, dimensions)
386 }
387
388 /**
389 * Creates a TensorAccess which will index into the dimensions of the source this TensorView
390 * was created with in the same order as they were declared.
391 * See [TensorAccess::from_source_order].
392 */
393 pub fn index(&self) -> TensorAccess<T, &S, D> {
394 TensorAccess::from_source_order(&self.source)
395 }
396
397 /**
398 * Creates a TensorAccess which will index into the dimensions of the source this TensorView
399 * was created with in the same order as they were declared. The TensorAccess mutably borrows
400 * the source, and can therefore mutate it if it implements TensorMut.
401 * See [TensorAccess::from_source_order].
402 */
403 pub fn index_mut(&mut self) -> TensorAccess<T, &mut S, D> {
404 TensorAccess::from_source_order(&mut self.source)
405 }
406
407 /**
408 * Creates a TensorAccess which will index into the dimensions this Tensor was
409 * created with in the same order as they were provided. The TensorAccess takes ownership
410 * of the Tensor, and can therefore mutate it. The TensorAccess takes ownership of
411 * the source, and can therefore mutate it if it implements TensorMut.
412 * See [TensorAccess::from_source_order].
413 */
414 pub fn index_owned(self) -> TensorAccess<T, S, D> {
415 TensorAccess::from_source_order(self.source)
416 }
417
418 /**
419 * Returns an iterator over references to the data in this TensorView.
420 */
421 pub fn iter_reference(&self) -> TensorReferenceIterator<T, S, D> {
422 TensorReferenceIterator::from(&self.source)
423 }
424
425 /**
426 * Returns a TensorView with the dimension names of the shape renamed to the provided
427 * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
428 *
429 * This is a shorthand for constructing the TensorView from this TensorView. See
430 * [`Tensor::rename_view`](Tensor::rename_view).
431 *
432 * # Panics
433 *
434 * - If a dimension name is not unique
435 */
436 #[track_caller]
437 pub fn rename_view(
438 &self,
439 dimensions: [Dimension; D],
440 ) -> TensorView<T, TensorRename<T, &S, D>, D> {
441 TensorView::from(TensorRename::from(&self.source, dimensions))
442 }
443
444 /**
445 * Returns a TensorView with the dimensions changed to the provided shape without moving any
446 * data around. The new Tensor may also have a different number of dimensions.
447 *
448 * This is a shorthand for constructing the TensorView from this TensorView. See
449 * [`Tensor::reshape_view`](Tensor::reshape_view).
450 *
451 * # Panics
452 *
453 * - If the number of provided elements in the new shape does not match the product of the
454 * dimension lengths in the existing tensor's shape.
455 * - If a dimension name is not unique
456 */
457 pub fn reshape<const D2: usize>(
458 &self,
459 shape: [(Dimension, usize); D2],
460 ) -> TensorView<T, TensorReshape<T, &S, D, D2>, D2> {
461 TensorView::from(TensorReshape::from(&self.source, shape))
462 }
463
464 /**
465 * Returns a TensorView with the dimensions changed to the provided shape without moving any
466 * data around. The new Tensor may also have a different number of dimensions. The
467 * TensorReshape mutably borrows the source, and can therefore mutate it if it implements
468 * TensorMut.
469 *
470 * This is a shorthand for constructing the TensorView from this TensorView. See
471 * [`Tensor::reshape_view`](Tensor::reshape_view).
472 *
473 * # Panics
474 *
475 * - If the number of provided elements in the new shape does not match the product of the
476 * dimension lengths in the existing tensor's shape.
477 * - If a dimension name is not unique
478 */
479 pub fn reshape_mut<const D2: usize>(
480 &mut self,
481 shape: [(Dimension, usize); D2],
482 ) -> TensorView<T, TensorReshape<T, &mut S, D, D2>, D2> {
483 TensorView::from(TensorReshape::from(&mut self.source, shape))
484 }
485
486 /**
487 * Returns a TensorView with the dimensions changed to the provided shape without moving any
488 * data around. The new Tensor may also have a different number of dimensions. The
489 * TensorReshape takes ownership of the source, and can therefore mutate it if it implements
490 * TensorMut.
491 *
492 * This is a shorthand for constructing the TensorView from this TensorView. See
493 * [`Tensor::reshape_view`](Tensor::reshape_view).
494 *
495 * # Panics
496 *
497 * - If the number of provided elements in the new shape does not match the product of the
498 * dimension lengths in the existing tensor's shape.
499 * - If a dimension name is not unique
500 */
501 pub fn reshape_owned<const D2: usize>(
502 self,
503 shape: [(Dimension, usize); D2],
504 ) -> TensorView<T, TensorReshape<T, S, D, D2>, D2> {
505 TensorView::from(TensorReshape::from(self.source, shape))
506 }
507
508 /**
509 * Given the dimension name, returns a view of this tensor reshaped to one dimension
510 * with a length equal to the number of elements in this tensor.
511 */
512 pub fn flatten(&self, dimension: Dimension) -> TensorView<T, TensorReshape<T, &S, D, 1>, 1> {
513 self.reshape([(dimension, dimensions::elements(&self.shape()))])
514 }
515
516 /**
517 * Given the dimension name, returns a view of this tensor reshaped to one dimension
518 * with a length equal to the number of elements in this tensor.
519 */
520 pub fn flatten_mut(
521 &mut self,
522 dimension: Dimension,
523 ) -> TensorView<T, TensorReshape<T, &mut S, D, 1>, 1> {
524 self.reshape_mut([(dimension, dimensions::elements(&self.shape()))])
525 }
526
527 /**
528 * Given the dimension name, returns a view of this tensor reshaped to one dimension
529 * with a length equal to the number of elements in this tensor.
530 *
531 * If you intend to query the tensor a lot after creating the view, consider
532 * using [flatten_into_tensor](TensorView::flatten_into_tensor) instead as it will have
533 * less overhead to index after creation.
534 */
535 pub fn flatten_owned(
536 self,
537 dimension: Dimension,
538 ) -> TensorView<T, TensorReshape<T, S, D, 1>, 1> {
539 let length = dimensions::elements(&self.shape());
540 self.reshape_owned([(dimension, length)])
541 }
542
543 /**
544 * Given the dimension name, returns a new tensor reshaped to one dimension
545 * with a length equal to the number of elements in this tensor and all elements
546 * copied into the new tensor.
547 */
548 pub fn flatten_into_tensor(self, dimension: Dimension) -> Tensor<T, 1>
549 where
550 T: Clone,
551 {
552 // TODO: Want a specialisation here to use Tensor::reshape_owned when we
553 // know that our source type is Tensor
554 let length = dimensions::elements(&self.shape());
555 Tensor::from([(dimension, length)], self.iter().collect())
556 }
557
558 /**
559 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
560 * range from view. Error cases are documented on [TensorRange].
561 *
562 * This is a shorthand for constructing the TensorView from this TensorView. See
563 * [`Tensor::range`](Tensor::range).
564 */
565 pub fn range<R, const P: usize>(
566 &self,
567 ranges: [(Dimension, R); P],
568 ) -> Result<TensorView<T, TensorRange<T, &S, D>, D>, IndexRangeValidationError<D, P>>
569 where
570 R: Into<IndexRange>,
571 {
572 TensorRange::from(&self.source, ranges).map(|range| TensorView::from(range))
573 }
574
575 /**
576 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
577 * range from view. Error cases are documented on [TensorRange]. The TensorRange
578 * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
579 *
580 * This is a shorthand for constructing the TensorView from this TensorView. See
581 * [`Tensor::range`](Tensor::range).
582 */
583 pub fn range_mut<R, const P: usize>(
584 &mut self,
585 ranges: [(Dimension, R); P],
586 ) -> Result<TensorView<T, TensorRange<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
587 where
588 R: Into<IndexRange>,
589 {
590 TensorRange::from(&mut self.source, ranges).map(|range| TensorView::from(range))
591 }
592
593 /**
594 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
595 * range from view. Error cases are documented on [TensorRange]. The TensorRange
596 * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
597 *
598 * This is a shorthand for constructing the TensorView from this TensorView. See
599 * [`Tensor::range`](Tensor::range).
600 */
601 pub fn range_owned<R, const P: usize>(
602 self,
603 ranges: [(Dimension, R); P],
604 ) -> Result<TensorView<T, TensorRange<T, S, D>, D>, IndexRangeValidationError<D, P>>
605 where
606 R: Into<IndexRange>,
607 {
608 TensorRange::from(self.source, ranges).map(|range| TensorView::from(range))
609 }
610
611 /**
612 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
613 * range from view. Error cases are documented on [TensorMask].
614 *
615 * This is a shorthand for constructing the TensorView from this TensorView. See
616 * [`Tensor::mask`](Tensor::mask).
617 */
618 pub fn mask<R, const P: usize>(
619 &self,
620 masks: [(Dimension, R); P],
621 ) -> Result<TensorView<T, TensorMask<T, &S, D>, D>, IndexRangeValidationError<D, P>>
622 where
623 R: Into<IndexRange>,
624 {
625 TensorMask::from(&self.source, masks).map(|mask| TensorView::from(mask))
626 }
627
628 /**
629 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
630 * range from view. Error cases are documented on [TensorMask]. The TensorMask
631 * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
632 *
633 * This is a shorthand for constructing the TensorView from this TensorView. See
634 * [`Tensor::mask`](Tensor::mask).
635 */
636 pub fn mask_mut<R, const P: usize>(
637 &mut self,
638 masks: [(Dimension, R); P],
639 ) -> Result<TensorView<T, TensorMask<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
640 where
641 R: Into<IndexRange>,
642 {
643 TensorMask::from(&mut self.source, masks).map(|mask| TensorView::from(mask))
644 }
645
646 /**
647 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
648 * range from view. Error cases are documented on [TensorMask]. The TensorMask
649 * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
650 *
651 * This is a shorthand for constructing the TensorView from this TensorView. See
652 * [`Tensor::mask`](Tensor::mask).
653 */
654 pub fn mask_owned<R, const P: usize>(
655 self,
656 masks: [(Dimension, R); P],
657 ) -> Result<TensorView<T, TensorMask<T, S, D>, D>, IndexRangeValidationError<D, P>>
658 where
659 R: Into<IndexRange>,
660 {
661 TensorMask::from(self.source, masks).map(|mask| TensorView::from(mask))
662 }
663
664 /**
665 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
666 * order. The data of this tensor and the dimension lengths remain unchanged.
667 *
668 * This is a shorthand for constructing the TensorView from this TensorView. See
669 * [`Tensor::reverse`](Tensor::reverse).
670 *
671 * # Panics
672 *
673 * - If a dimension name is not in the tensor's shape or is repeated.
674 */
675 #[track_caller]
676 pub fn reverse(&self, dimensions: &[Dimension]) -> TensorView<T, TensorReverse<T, &S, D>, D> {
677 TensorView::from(TensorReverse::from(&self.source, dimensions))
678 }
679
680 /**
681 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
682 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
683 * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
684 *
685 * This is a shorthand for constructing the TensorView from this TensorView. See
686 * [`Tensor::reverse`](Tensor::reverse).
687 *
688 * # Panics
689 *
690 * - If a dimension name is not in the tensor's shape or is repeated.
691 */
692 #[track_caller]
693 pub fn reverse_mut(
694 &mut self,
695 dimensions: &[Dimension],
696 ) -> TensorView<T, TensorReverse<T, &mut S, D>, D> {
697 TensorView::from(TensorReverse::from(&mut self.source, dimensions))
698 }
699
700 /**
701 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
702 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
703 * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
704 *
705 * This is a shorthand for constructing the TensorView from this TensorView. See
706 * [`Tensor::reverse`](Tensor::reverse).
707 *
708 * # Panics
709 *
710 * - If a dimension name is not in the tensor's shape or is repeated.
711 */
712 #[track_caller]
713 pub fn reverse_owned(
714 self,
715 dimensions: &[Dimension],
716 ) -> TensorView<T, TensorReverse<T, S, D>, D> {
717 TensorView::from(TensorReverse::from(self.source, dimensions))
718 }
719
720 /**
721 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
722 * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
723 * or need to clone the values anyway, you can use
724 * [`TensorView::elementwise`](TensorView::elementwise) instead.
725 *
726 * # Generics
727 *
728 * This method can be called with any right hand side that can be converted to a TensorView,
729 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
730 *
731 * # Panics
732 *
733 * If the two tensors have different shapes.
734 */
735 #[track_caller]
736 pub fn elementwise_reference<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
737 where
738 I: Into<TensorView<T, S2, D>>,
739 S2: TensorRef<T, D>,
740 M: Fn(&T, &T) -> T,
741 {
742 self.elementwise_reference_less_generic(rhs.into(), mapping_function)
743 }
744
745 /**
746 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
747 * mapped by a function. The mapping function also receives each index corresponding to the
748 * value pairs. The value pairs are not copied for you, if you're using `Copy` types
749 * or need to clone the values anyway, you can use
750 * [`TensorView::elementwise_with_index`](TensorView::elementwise_with_index) instead.
751 *
752 * # Generics
753 *
754 * This method can be called with any right hand side that can be converted to a TensorView,
755 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
756 *
757 * # Panics
758 *
759 * If the two tensors have different shapes.
760 */
761 #[track_caller]
762 pub fn elementwise_reference_with_index<S2, I, M>(
763 &self,
764 rhs: I,
765 mapping_function: M,
766 ) -> Tensor<T, D>
767 where
768 I: Into<TensorView<T, S2, D>>,
769 S2: TensorRef<T, D>,
770 M: Fn([usize; D], &T, &T) -> T,
771 {
772 self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
773 }
774
775 #[track_caller]
776 fn elementwise_reference_less_generic<S2, M>(
777 &self,
778 rhs: TensorView<T, S2, D>,
779 mapping_function: M,
780 ) -> Tensor<T, D>
781 where
782 S2: TensorRef<T, D>,
783 M: Fn(&T, &T) -> T,
784 {
785 let left_shape = self.shape();
786 let right_shape = rhs.shape();
787 if left_shape != right_shape {
788 panic!(
789 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
790 left_shape, right_shape
791 );
792 }
793 let mapped = self
794 .iter_reference()
795 .zip(rhs.iter_reference())
796 .map(|(x, y)| mapping_function(x, y))
797 .collect();
798 Tensor::from(left_shape, mapped)
799 }
800
801 #[track_caller]
802 fn elementwise_reference_less_generic_with_index<S2, M>(
803 &self,
804 rhs: TensorView<T, S2, D>,
805 mapping_function: M,
806 ) -> Tensor<T, D>
807 where
808 S2: TensorRef<T, D>,
809 M: Fn([usize; D], &T, &T) -> T,
810 {
811 let left_shape = self.shape();
812 let right_shape = rhs.shape();
813 if left_shape != right_shape {
814 panic!(
815 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
816 left_shape, right_shape
817 );
818 }
819 // we just checked both shapes were the same, so we don't need to propagate indexes
820 // for both tensors because they'll be identical
821 let mapped = self
822 .iter_reference()
823 .with_index()
824 .zip(rhs.iter_reference())
825 .map(|((i, x), y)| mapping_function(i, x, y))
826 .collect();
827 Tensor::from(left_shape, mapped)
828 }
829
830 /**
831 * Returns a TensorView which makes the order of the data in this tensor appear to be in
832 * a different order. The order of the dimension names is unchanged, although their lengths
833 * may swap.
834 *
835 * This is a shorthand for constructing the TensorView from this TensorView.
836 *
837 * See also: [transpose](TensorView::transpose), [TensorTranspose]
838 *
839 * # Panics
840 *
841 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
842 * order need not match.
843 */
844 pub fn transpose_view(
845 &self,
846 dimensions: [Dimension; D],
847 ) -> TensorView<T, TensorTranspose<T, &S, D>, D> {
848 TensorView::from(TensorTranspose::from(&self.source, dimensions))
849 }
850
851 /// Unverified constructor for interal use when we know the dimensions/data/strides are
852 /// the same as the existing instance and don't need reverification
853 pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> TensorView<T, Tensor<T, D>, D> {
854 let shape = self.shape();
855 let strides = crate::tensors::compute_strides(&shape);
856 TensorView::from(Tensor {
857 data,
858 shape,
859 strides,
860 })
861 }
862}
863
/**
 * TensorView iterator methods which require mutable or owned access via a [TensorMut] source.
 */
impl<T, S, const D: usize> TensorView<T, S, D>
where
    S: TensorMut<T, D>,
{
    /**
     * Returns an iterator over mutable references to the data in this TensorView.
     */
    pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<T, S, D> {
        TensorReferenceMutIterator::from(&mut self.source)
    }

    /**
     * Creates an iterator over the values in this TensorView.
     *
     * Consumes the TensorView.
     */
    pub fn iter_owned(self) -> TensorOwnedIterator<T, S, D>
    where
        T: Default,
    {
        // NOTE(review): T: Default is presumably required so the iterator can move
        // values out of the source, leaving defaults behind — confirm against
        // TensorOwnedIterator's documentation.
        TensorOwnedIterator::from(self.source())
    }
}
885
886/**
887 * TensorView methods which require only read access via a [TensorRef] source
888 * and a clonable type.
889 */
890impl<T, S, const D: usize> TensorView<T, S, D>
891where
892 T: Clone,
893 S: TensorRef<T, D>,
894{
895 /**
896 * Gets a copy of the first value in this tensor.
897 * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
898 * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
899 */
900 pub fn first(&self) -> T {
901 self.iter()
902 .next()
903 .expect("Tensors always have at least 1 element")
904 }
905
906 /**
907 * Returns a new Tensor which has the same data as this tensor, but with the order of data
908 * changed. The order of the dimension names is unchanged, although their lengths may swap.
909 *
910 * For example, with a `[("x", x), ("y", y)]` tensor you could call
911 * `transpose(["y", "x"])` which would return a new tensor with a shape of
912 * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
913 *
914 * This method need not shift *all* the dimensions though, you could also swap the width
915 * and height of images in a tensor with a shape of
916 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
917 * which would return a new tensor where all the images have been swapped over the diagonal.
918 *
919 * See also: [TensorAccess], [reorder](TensorView::reorder)
920 *
921 * # Panics
922 *
923 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
924 * order need not match (and if the order does match, this function is just an expensive
925 * clone).
926 */
927 #[track_caller]
928 pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
929 let shape = self.shape();
930 let mut reordered = self.reorder(dimensions);
931 // Transposition is essentially reordering, but we retain the dimension name ordering
932 // of the original order, this means we may swap dimension lengths, but the dimensions
933 // will not change order.
934 #[allow(clippy::needless_range_loop)]
935 for d in 0..D {
936 reordered.shape[d].0 = shape[d].0;
937 }
938 reordered
939 }
940
941 /**
942 * Returns a new Tensor which has the same data as this tensor, but with the order of the
943 * dimensions and corresponding order of data changed.
944 *
945 * For example, with a `[("x", x), ("y", y)]` tensor you could call
946 * `reorder(["y", "x"])` which would return a new tensor with a shape of
947 * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
948 *
949 * This method need not shift *all* the dimensions though, you could also swap the width
950 * and height of images in a tensor with a shape of
951 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
952 * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
953 * in the original.
954 *
955 * See also: [TensorAccess], [transpose](TensorView::transpose)
956 *
957 * # Panics
958 *
959 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
960 * order need not match (and if the order does match, this function is just an expensive
961 * clone).
962 */
963 #[track_caller]
964 pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
965 let reorderd = match TensorAccess::try_from(&self.source, dimensions) {
966 Ok(reordered) => reordered,
967 Err(_error) => panic!(
968 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
969 dimensions,
970 self.shape(),
971 ),
972 };
973 let reorderd_shape = reorderd.shape();
974 Tensor::from(reorderd_shape, reorderd.iter().collect())
975 }
976
977 /**
978 * Creates and returns a new tensor with all values from the original with the
979 * function applied to each. This can be used to change the type of the tensor
980 * such as creating a mask:
981 * ```
982 * use easy_ml::tensors::Tensor;
983 * use easy_ml::tensors::views::TensorView;
984 * let x = TensorView::from(Tensor::from([("a", 2), ("b", 2)], vec![
985 * 0.0, 1.2,
986 * 5.8, 6.9
987 * ]));
988 * let y = x.map(|element| element > 2.0);
989 * let result = Tensor::from([("a", 2), ("b", 2)], vec![
990 * false, false,
991 * true, true
992 * ]);
993 * assert_eq!(&y, &result);
994 * ```
995 */
996 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
997 let mapped = self.iter().map(mapping_function).collect();
998 Tensor::from(self.shape(), mapped)
999 }
1000
1001 /**
1002 * Creates and returns a new tensor with all values from the original and
1003 * the index of each value mapped by a function.
1004 */
1005 pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1006 let mapped = self
1007 .iter()
1008 .with_index()
1009 .map(|(i, x)| mapping_function(i, x))
1010 .collect();
1011 Tensor::from(self.shape(), mapped)
1012 }
1013
1014 /**
1015 * Returns an iterator over copies of the data in this TensorView.
1016 */
1017 pub fn iter(&self) -> TensorIterator<T, S, D> {
1018 TensorIterator::from(&self.source)
1019 }
1020
1021 /**
1022 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1023 * mapped by a function.
1024 *
1025 * ```
1026 * use easy_ml::tensors::Tensor;
1027 * use easy_ml::tensors::views::TensorView;
1028 * let lhs = TensorView::from(Tensor::from([("a", 4)], vec![1, 2, 3, 4]));
1029 * let rhs = TensorView::from(Tensor::from([("a", 4)], vec![0, 1, 2, 3]));
1030 * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1031 * assert_eq!(
1032 * multiplied,
1033 * Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1034 * );
1035 * ```
1036 *
1037 * # Generics
1038 *
1039 * This method can be called with any right hand side that can be converted to a TensorView,
1040 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1041 *
1042 * # Panics
1043 *
1044 * If the two tensors have different shapes.
1045 */
1046 #[track_caller]
1047 pub fn elementwise<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1048 where
1049 I: Into<TensorView<T, S2, D>>,
1050 S2: TensorRef<T, D>,
1051 M: Fn(T, T) -> T,
1052 {
1053 self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1054 mapping_function(lhs.clone(), rhs.clone())
1055 })
1056 }
1057
1058 /**
1059 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1060 * mapped by a function. The mapping function also receives each index corresponding to the
1061 * value pairs.
1062 *
1063 * # Generics
1064 *
1065 * This method can be called with any right hand side that can be converted to a TensorView,
1066 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1067 *
1068 * # Panics
1069 *
1070 * If the two tensors have different shapes.
1071 */
1072 #[track_caller]
1073 pub fn elementwise_with_index<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1074 where
1075 I: Into<TensorView<T, S2, D>>,
1076 S2: TensorRef<T, D>,
1077 M: Fn([usize; D], T, T) -> T,
1078 {
1079 self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1080 mapping_function(i, lhs.clone(), rhs.clone())
1081 })
1082 }
1083}
1084
1085/**
1086 * TensorView methods which require only read access via a scalar [TensorRef] source
1087 * and a clonable type.
1088 */
1089impl<T, S> TensorView<T, S, 0>
1090where
1091 T: Clone,
1092 S: TensorRef<T, 0>,
1093{
1094 /**
1095 * Returns a copy of the sole element in the 0 dimensional tensor.
1096 */
1097 pub fn scalar(&self) -> T {
1098 self.source.get_reference([]).unwrap().clone()
1099 }
1100}
1101
1102/**
1103 * TensorView methods which require mutable access via a [TensorMut] source and a [Default].
1104 */
1105impl<T, S> TensorView<T, S, 0>
1106where
1107 T: Default,
1108 S: TensorMut<T, 0>,
1109{
1110 /**
1111 * Returns the sole element in the 0 dimensional tensor.
1112 */
1113 pub fn into_scalar(self) -> T {
1114 TensorOwnedIterator::from(self.source).next().unwrap()
1115 }
1116}
1117
1118/**
1119 * TensorView methods which require mutable access via a [TensorMut] source.
1120 */
1121impl<T, S, const D: usize> TensorView<T, S, D>
1122where
1123 T: Clone,
1124 S: TensorMut<T, D>,
1125{
1126 /**
1127 * Applies a function to all values in the tensor view, modifying
1128 * the tensor in place.
1129 */
1130 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1131 self.iter_reference_mut()
1132 .for_each(|x| *x = mapping_function(x.clone()));
1133 }
1134
1135 /**
1136 * Applies a function to all values and each value's index in the tensor view, modifying
1137 * the tensor view in place.
1138 */
1139 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1140 self.iter_reference_mut()
1141 .with_index()
1142 .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1143 }
1144}
1145
1146impl<T, S> TensorView<T, S, 1>
1147where
1148 T: Numeric,
1149 for<'a> &'a T: NumericRef<T>,
1150 S: TensorRef<T, 1>,
1151{
1152 /**
1153 * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1154 * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1155 *
1156 * ```
1157 * use easy_ml::tensors::Tensor;
1158 * use easy_ml::tensors::views::TensorView;
1159 * let tensor_view = TensorView::from(Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]));
1160 * assert_eq!(tensor_view.scalar_product(&tensor_view), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1161 * ```
1162 *
1163 * # Generics
1164 *
1165 * This method can be called with any right hand side that can be converted to a TensorView,
1166 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1167 *
1168 * # Panics
1169 *
1170 * If the two vectors are not of equal length or their dimension names do not match.
1171 */
1172 // Would like this impl block to be in operations.rs too but then it would show first in the
1173 // TensorView docs which isn't ideal
1174 pub fn scalar_product<S2, I>(&self, rhs: I) -> T
1175 where
1176 I: Into<TensorView<T, S2, 1>>,
1177 S2: TensorRef<T, 1>,
1178 {
1179 self.scalar_product_less_generic(rhs.into())
1180 }
1181}
1182
/**
 * TensorView linear algebra methods for 2 dimensional (matrix) tensors of a [Numeric] type.
 * Each method delegates to the corresponding free function in the [linear_algebra] module.
 */
impl<T, S> TensorView<T, S, 2>
where
    T: Numeric,
    for<'a> &'a T: NumericRef<T>,
    S: TensorRef<T, 2>,
{
    /**
     * Returns the determinant of this square matrix, or None if the matrix
     * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
     */
    pub fn determinant(&self) -> Option<T> {
        linear_algebra::determinant_tensor::<T, _, _>(self)
    }

    /**
     * Computes the inverse of a matrix provided that it exists. To have an inverse a
     * matrix must be square (same number of rows and columns) and it must also have a
     * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
     */
    pub fn inverse(&self) -> Option<Tensor<T, 2>> {
        linear_algebra::inverse_tensor::<T, _, _>(self)
    }

    /**
     * Computes the covariance matrix for this feature matrix along the specified feature
     * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
     */
    pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
        linear_algebra::covariance::<T, _, _>(self, feature_dimension)
    }
}
1214
// Generates the `select` family of methods for a single concrete dimensionality $d,
// since `{ $d - 1 }` return types cannot currently be expressed generically over D.
macro_rules! tensor_view_select_impl {
    (impl TensorView $d:literal 1) => {
        impl<T, S> TensorView<T, S, $d>
        where
            S: TensorRef<T, $d>,
        {
            /**
             * Selects the provided dimension name and index pairs in this TensorView, returning a
             * TensorView which has fewer dimensions than this TensorView, with the removed dimensions
             * always indexed as the provided values.
             *
             * This is a shorthand for manually constructing the TensorView and
             * [TensorIndex]
             *
             * Note: due to limitations in Rust's const generics support, this method is only
             * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
             * back to manual construction to create `TensorIndex`es with multiple provided
             * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
             */
            #[track_caller]
            pub fn select(
                &self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &S, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(&self.source, provided_indexes))
            }

            /**
             * Selects the provided dimension name and index pairs in this TensorView, returning a
             * TensorView which has fewer dimensions than this TensorView, with the removed dimensions
             * always indexed as the provided values. The TensorIndex mutably borrows this
             * TensorView's source, and can therefore mutate it
             *
             * See [select](TensorView::select)
             */
            #[track_caller]
            pub fn select_mut(
                &mut self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, &mut S, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(&mut self.source, provided_indexes))
            }

            /**
             * Selects the provided dimension name and index pairs in this TensorView, returning a
             * TensorView which has fewer dimensions than this TensorView, with the removed dimensions
             * always indexed as the provided values. The TensorIndex takes ownership of this
             * TensorView's source, and can therefore mutate it
             *
             * See [select](TensorView::select)
             */
            #[track_caller]
            pub fn select_owned(
                self,
                provided_indexes: [(Dimension, usize); 1],
            ) -> TensorView<T, TensorIndex<T, S, $d, 1>, { $d - 1 }> {
                TensorView::from(TensorIndex::from(self.source, provided_indexes))
            }
        }
    };
}
1276
// Instantiate the select methods for each supported dimensionality; selecting reduces
// D by one, so D = 0 has no implementation.
tensor_view_select_impl!(impl TensorView 6 1);
tensor_view_select_impl!(impl TensorView 5 1);
tensor_view_select_impl!(impl TensorView 4 1);
tensor_view_select_impl!(impl TensorView 3 1);
tensor_view_select_impl!(impl TensorView 2 1);
tensor_view_select_impl!(impl TensorView 1 1);
1283
// Generates the `expand` family of methods for a single concrete dimensionality $d,
// since `{ $d + 1 }` return types cannot currently be expressed generically over D.
macro_rules! tensor_view_expand_impl {
    (impl Tensor $d:literal 1) => {
        impl<T, S> TensorView<T, S, $d>
        where
            S: TensorRef<T, $d>,
        {
            /**
             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this TensorView.
             *
             * This is a shorthand for manually constructing the TensorView and
             * [TensorExpansion]
             *
             * Note: due to limitations in Rust's const generics support, this method is only
             * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
             * fall back to manual construction to create `TensorExpansion`s with multiple provided
             * indexes if you need to increase dimensionality by more than 1 dimension at a time.
             */
            #[track_caller]
            pub fn expand(
                &self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &S, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(&self.source, extra_dimension_names))
            }

            /**
             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this TensorView. The TensorExpansion mutably borrows this
             * TensorView's source, and can therefore mutate it
             *
             * See [expand](TensorView::expand)
             */
            #[track_caller]
            pub fn expand_mut(
                &mut self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, &mut S, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(
                    &mut self.source,
                    extra_dimension_names,
                ))
            }

            /**
             * Expands the dimensionality of this tensor by adding dimensions of length 1 at
             * a particular position within the shape, returning a TensorView which has more
             * dimensions than this TensorView. The TensorExpansion takes ownership of this
             * TensorView's source, and can therefore mutate it
             *
             * See [expand](TensorView::expand)
             */
            #[track_caller]
            pub fn expand_owned(
                self,
                extra_dimension_names: [(usize, Dimension); 1],
            ) -> TensorView<T, TensorExpansion<T, S, $d, 1>, { $d + 1 }> {
                TensorView::from(TensorExpansion::from(self.source, extra_dimension_names))
            }
        }
    };
}
1348
// Instantiate the expand methods for each supported dimensionality; expanding increases
// D by one, so D = 6 (the maximum supported) has no implementation.
tensor_view_expand_impl!(impl Tensor 0 1);
tensor_view_expand_impl!(impl Tensor 1 1);
tensor_view_expand_impl!(impl Tensor 2 1);
tensor_view_expand_impl!(impl Tensor 3 1);
tensor_view_expand_impl!(impl Tensor 4 1);
tensor_view_expand_impl!(impl Tensor 5 1);
1355
1356/**
1357 * Debug implementations for TensorView additionally show the visible data and visible dimensions
1358 * reported by the source as fields. This is in addition to recursive debug content of the actual
1359 * source.
1360 */
1361impl<T, S, const D: usize> std::fmt::Debug for TensorView<T, S, D>
1362where
1363 T: std::fmt::Debug,
1364 S: std::fmt::Debug + TensorRef<T, D>,
1365{
1366 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1367 f.debug_struct("TensorView")
1368 .field("visible", &DebugSourceVisible::from(&self.source))
1369 .field("shape", &self.source.view_shape())
1370 .field("source", &self.source)
1371 .finish()
1372 }
1373}
1374
/// Helper wrapper whose Debug implementation prints only the data visible through the
/// source's [TensorRef] implementation, as a flat list in iteration order.
struct DebugSourceVisible<T, S, const D: usize> {
    source: S,
    // Never stored at runtime; ties the element type T to the wrapper so the Debug
    // impl can bound on T.
    _type: PhantomData<T>,
}
1379
1380impl<T, S, const D: usize> DebugSourceVisible<T, S, D>
1381where
1382 T: std::fmt::Debug,
1383 S: std::fmt::Debug + TensorRef<T, D>,
1384{
1385 fn from(source: S) -> DebugSourceVisible<T, S, D> {
1386 DebugSourceVisible {
1387 source,
1388 _type: PhantomData,
1389 }
1390 }
1391}
1392
1393impl<T, S, const D: usize> std::fmt::Debug for DebugSourceVisible<T, S, D>
1394where
1395 T: std::fmt::Debug,
1396 S: std::fmt::Debug + TensorRef<T, D>,
1397{
1398 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1399 f.debug_list()
1400 .entries(TensorReferenceIterator::from(&self.source))
1401 .finish()
1402 }
1403}
1404
#[test]
fn test_debug() {
    // Compact debug output should show the view's visible data and shape alongside
    // the recursively debugged source.
    let tensor = Tensor::from([("rows", 3), ("columns", 4)], (0..12).collect());
    let tensor_view = TensorView::from(&tensor);
    let formatted = format!("{:?}\n{:?}", tensor, tensor_view);
    assert_eq!(
        formatted,
        r#"Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] }
TensorView { visible: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], source: Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] } }"#
    )
}
1416
#[test]
fn test_debug_clipped() {
    // Pretty debug output for a range-restricted view should only list the elements
    // inside the range under `visible`, while the source still shows everything.
    let x = Tensor::from([("rows", 2), ("columns", 3)], (0..6).collect());
    let view = TensorView::from(&x)
        .range_owned([("columns", IndexRange::new(1, 2))])
        .unwrap();
    // Removed a leftover println! that duplicated this formatting and spammed test output.
    let debugged = format!("{:#?}\n{:#?}", x, view);
    assert_eq!(
        debugged,
        r#"Tensor {
    data: [
        0,
        1,
        2,
        3,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            3,
        ),
    ],
    strides: [
        3,
        1,
    ],
}
TensorView {
    visible: [
        1,
        2,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            2,
        ),
    ],
    source: TensorRange {
        source: Tensor {
            data: [
                0,
                1,
                2,
                3,
                4,
                5,
            ],
            shape: [
                (
                    "rows",
                    2,
                ),
                (
                    "columns",
                    3,
                ),
            ],
            strides: [
                3,
                1,
            ],
        },
        range: [
            IndexRange {
                start: 0,
                length: 2,
            },
            IndexRange {
                start: 1,
                length: 2,
            },
        ],
        _type: PhantomData<i32>,
    },
}"#
    )
}
1508
1509/**
1510 * Any tensor view of a Displayable type implements Display
1511 *
1512 * You can control the precision of the formatting using format arguments, i.e.
1513 * `format!("{:.3}", tensor)`
1514 */
1515impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorView<T, S, D>
1516where
1517 T: std::fmt::Display,
1518 S: TensorRef<T, D>,
1519{
1520 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1521 crate::tensors::display::format_view(&self.source, f)
1522 }
1523}
1524
1525/**
1526 * A Tensor can be converted to a TensorView of that Tensor.
1527 */
1528impl<T, const D: usize> From<Tensor<T, D>> for TensorView<T, Tensor<T, D>, D> {
1529 fn from(tensor: Tensor<T, D>) -> TensorView<T, Tensor<T, D>, D> {
1530 TensorView::from(tensor)
1531 }
1532}
1533
1534/**
1535 * A reference to a Tensor can be converted to a TensorView of that referenced Tensor.
1536 */
1537impl<'a, T, const D: usize> From<&'a Tensor<T, D>> for TensorView<T, &'a Tensor<T, D>, D> {
1538 fn from(tensor: &Tensor<T, D>) -> TensorView<T, &Tensor<T, D>, D> {
1539 TensorView::from(tensor)
1540 }
1541}
1542
1543/**
1544 * A mutable reference to a Tensor can be converted to a TensorView of that mutably referenced
1545 * Tensor.
1546 */
1547impl<'a, T, const D: usize> From<&'a mut Tensor<T, D>> for TensorView<T, &'a mut Tensor<T, D>, D> {
1548 fn from(tensor: &mut Tensor<T, D>) -> TensorView<T, &mut Tensor<T, D>, D> {
1549 TensorView::from(tensor)
1550 }
1551}
1552
1553/**
1554 * A reference to a TensorView can be converted to an owned TensorView with a reference to the
1555 * source type of that first TensorView.
1556 */
1557impl<'a, T, S, const D: usize> From<&'a TensorView<T, S, D>> for TensorView<T, &'a S, D>
1558where
1559 S: TensorRef<T, D>,
1560{
1561 fn from(tensor_view: &TensorView<T, S, D>) -> TensorView<T, &S, D> {
1562 TensorView::from(tensor_view.source_ref())
1563 }
1564}
1565
1566/**
1567 * A mutable reference to a TensorView can be converted to an owned TensorView with a mutable
1568 * reference to the source type of that first TensorView.
1569 */
1570impl<'a, T, S, const D: usize> From<&'a mut TensorView<T, S, D>> for TensorView<T, &'a mut S, D>
1571where
1572 S: TensorRef<T, D>,
1573{
1574 fn from(tensor_view: &mut TensorView<T, S, D>) -> TensorView<T, &mut S, D> {
1575 TensorView::from(tensor_view.source_ref_mut())
1576 }
1577}