easy_ml/tensors/views.rs
1/*!
2 * Generic views into a tensor.
3 *
4 * The concept of a view into a tensor is built from the low level [TensorRef] and
5 * [TensorMut] traits which define having read and read/write access to Tensor data
6 * respectively, and the high level API implemented on the [TensorView] struct.
7 *
8 * Since a Tensor is itself a TensorRef, the APIs for the traits are a little verbose to
9 * avoid name clashes with methods defined on the Tensor and TensorView types. You should
10 * typically use TensorRef and TensorMut implementations via the TensorView struct which provides
11 * an API closely resembling Tensor.
12 */
13
14use std::marker::PhantomData;
15
16use crate::linear_algebra;
17use crate::numeric::{Numeric, NumericRef};
18use crate::tensors::dimensions;
19use crate::tensors::indexing::{
20 TensorAccess, TensorIterator, TensorOwnedIterator, TensorReferenceIterator,
21 TensorReferenceMutIterator, TensorTranspose,
22};
23use crate::tensors::{Dimension, Tensor};
24
25mod indexes;
26mod map;
27mod ranges;
28mod renamed;
29mod reshape;
30mod reverse;
31pub mod traits;
32mod zip;
33
34pub use indexes::*;
35pub(crate) use map::*;
36pub use ranges::*;
37pub use renamed::*;
38pub use reshape::*;
39pub use reverse::*;
40pub use zip::*;
41
42/**
43* A shared/immutable reference to a tensor (or a portion of it) of some type and number of
44* dimensions.
45*
46* # Indexing
47*
48* A TensorRef has a shape of type `[(Dimension, usize); D]`. This defines the valid indexes along
49* each dimension name and length pair from 0 inclusive to the length exclusive. If the shape was
50* `[("r", 2), ("c", 2)]` the indexes used would be `[0,0]`, `[0,1]`, `[1,0]` and `[1,1]`.
51* Although the dimension name in each pair is used for many high level APIs, for TensorRef the
52* order of dimensions is used, and the indexes (`[usize; D]`) these trait methods are called with
53* must be in the same order as the shape. In general some `[("a", a), ("b", b), ("c", c)...]`
54* shape is indexed from `[0,0,0...]` through to `[a - 1, b - 1, c - 1...]`, regardless of how the
55* data is actually laid out in memory.
56*
57* # Safety
58*
59* In order to support returning references without bounds checking in a useful way, the
60* implementing type is required to uphold several invariants that cannot be checked by
61* the compiler.
62*
63* 1 - Any valid index as described in Indexing will yield a safe reference when calling
64* `get_reference_unchecked` and `get_reference_unchecked_mut`.
65*
66* 2 - The view shape that defines which indexes are valid may not be changed by a shared reference
67* to the TensorRef implementation. ie, the tensor may not be resized while a mutable reference is
68* held to it, except by that reference.
69*
70* 3 - All dimension names in the `view_shape` must be unique.
71*
72* 4 - All dimension lengths in the `view_shape` must be non zero.
73*
74* 5 - `data_layout` must return values correctly as documented on [`DataLayout`]
75*
76* Essentially, interior mutability causes problems, since code looping through the range of valid
77* indexes in a TensorRef needs to be able to rely on that range of valid indexes not changing.
78* This is trivially the case by default since a [Tensor] does not have any form of
79* interior mutability, and therefore an iterator holding a shared reference to a Tensor prevents
80* that tensor being resized. However, a type *wrongly* implementing TensorRef could introduce
81* interior mutability by putting the Tensor in an `Arc<Mutex<>>` which would allow another thread
82* to resize a tensor while an iterator was looping through previously valid indexes on a different
83* thread. This is the same contract as
84* [`NoInteriorMutability`](crate::matrices::views::NoInteriorMutability) used in the matrix APIs.
85*
86* Note that it is okay to be able to resize any TensorRef implementation if that always requires
87* an exclusive reference to the TensorRef/Tensor, since the exclusivity prevents the above
88* scenario.
89*/
90pub unsafe trait TensorRef<T, const D: usize> {
91 /**
92 * Gets a reference to the value at the index if the index is in range. Otherwise returns None.
93 */
94 fn get_reference(&self, indexes: [usize; D]) -> Option<&T>;
95
96 /**
97 * The shape this tensor has. See [dimensions] for an overview.
98 * The product of the lengths in the pairs define how many elements are in the tensor
99 * (or the portion of it that is visible).
100 */
101 fn view_shape(&self) -> [(Dimension, usize); D];
102
103 /**
104 * Gets a reference to the value at the index without doing any bounds checking. For a safe
105 * alternative see [get_reference](TensorRef::get_reference).
106 *
107 * # Safety
108 *
109 * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
110 * resulting reference is not used. Valid indexes are defined as in [TensorRef].
111 *
112 * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
113 * [TensorRef]: TensorRef
114 */
115 #[allow(clippy::missing_safety_doc)] // it's not missing
116 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T;
117
118 /**
119 * The way the data in this tensor is laid out in memory. In particular,
120 * [`Linear`](DataLayout) has several requirements on what is returned that must be upheld
121 * by implementations of this trait.
122 *
123 * For a [Tensor] this would return `DataLayout::Linear` in the same order as the `view_shape`,
124 * since the data is in a single line and the `view_shape` is in most significant dimension
125 * to least. Many views however, create shapes that do not correspond to linear memory, either
126 * by combining non array data with tensors, or hiding dimensions. Similarly, a row major
127 * Tensor might be transposed to a column major TensorView, so the view shape could be reversed
128 * compared to the order of significance of the dimensions in memory.
129 *
130 * In general, an array of dimension names matching the view shape order is big endian, and
131 * will be iterated through efficiently by [TensorIterator], and an array of dimension names
132 * in reverse of the view shape is in little endian order.
133 *
134 * The implementation of this trait must ensure that if it returns `DataLayout::Linear`
135 * the set of dimension names are returned in the order of most significant to least and match
136 * the set of the dimension names returned by `view_shape`.
137 */
138 fn data_layout(&self) -> DataLayout<D>;
139}
140
141/**
142 * How the data in the tensor is laid out in memory.
143 */
144#[derive(Clone, Debug, Eq, PartialEq)]
145pub enum DataLayout<const D: usize> {
146 /**
147 * The data is laid out in linear storage in memory, such that we could take a slice over the
148 * entire data specified by our `view_shape`.
149 *
150 * The `D` length array specifies the dimensions in the `view_shape` in the order of most
151 * significant dimension (in memory) to least.
152 *
153 * In general, an array of dimension names in the same order as the view shape is big endian
154 * order (implying the order of dimensions in the view shape is most significant to least),
155 * and will be iterated through efficiently by [TensorIterator], and an array of
156 * dimension names in reverse of the view shape is in little endian order
157 * (implying the order of dimensions in the view shape is least significant to most).
158 *
159 * In memory, the data will have some order such that if we want repeatedly take 1 step
160 * through memory from the first value to the last there will be a most significant dimension
161 * that always goes up, through to a least significant dimension with the most rapidly varying
162 * index.
163 *
164 * In most of Easy ML's Tensors, the `view_shape` dimensions would already be in the order of
165 * most significant dimension to least (since [Tensor] stores its data in big endian order),
166 * so the list of dimension names will just increment match the order of the view shape.
167 *
168 * For example, a tensor with a view shape of `[("batch", 2), ("row", 2), ("column", 3)]` that
169 * stores its data in most significant to least would be (indexed) like:
170 * ```ignore
171 * [
172 * (0,0,0), (0,0,1), (0,0,2),
173 * (0,1,0), (0,1,1), (0,1,2),
174 * (1,0,0), (1,0,1), (1,0,2),
175 * (1,1,0), (1,1,1), (1,1,2)
176 * ]
177 * ```
178 *
179 * To take one step in memory, we would increment the right most dimension index ("column"),
180 * counting our way up through to the left most dimension index ("batch"). If we changed
181 * this tensor to `[("column", 3), ("row", 2), ("batch", 2)]` so that the `view_shape` was
182 * swapped to least significant dimension to most but the data remained in the same order,
183 * our tensor would still have a DataLayout of `Linear(["batch", "row", "column"])`, since
184 * the indexes in the transposed `view_shape` correspond to an actual memory layout that's
185 * completely reversed. Alternatively, you could say we reversed the view shape but the
186 * memory layout never changed:
187 * ```ignore
188 * [
189 * (0,0,0), (1,0,0), (2,0,0),
190 * (0,1,0), (1,1,0), (2,1,0),
191 * (0,0,1), (1,0,1), (2,0,1),
192 * (0,1,1), (1,1,1), (2,1,1)
193 * ]
194 * ```
195 *
196 * To take one step in memory, we now need to increment the left most dimension index (on
197 * the view shape) ("column"), counting our way in reverse to the right most dimension
198 * index ("batch").
199 *
200 * That `["batch", "row", "column"]` is also exactly the order you would need to swap your
201 * dimensions on the `view_shape` to get back to most significant to least. A [TensorAccess]
202 * could reorder the tensor by this array order to get back to most significant to least
203 * ordering on the `view_shape` in order to iterate through the data efficiently.
204 */
205 Linear([Dimension; D]),
206 /**
207 * The data is not laid out in linear storage in memory.
208 */
209 NonLinear,
210 /**
211 * The data is not laid out in a linear or non linear way, or we don't know how it's laid
212 * out.
213 */
214 Other,
215}
216
217/**
218 * A unique/mutable reference to a tensor (or a portion of it) of some type.
219 *
220 * # Safety
221 *
222 * See [TensorRef].
223 */
224pub unsafe trait TensorMut<T, const D: usize>: TensorRef<T, D> {
225 /**
226 * Gets a mutable reference to the value at the index, if the index is in range. Otherwise
227 * returns None.
228 */
229 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>;
230
231 /**
232 * Gets a mutable reference to the value at the index without doing any bounds checking.
233 * For a safe alternative see [get_reference_mut](TensorMut::get_reference_mut).
234 *
235 * # Safety
236 *
237 * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
238 * resulting reference is not used. Valid indexes are defined as in [TensorRef].
239 *
240 * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
241 * [TensorRef]: TensorRef
242 */
243 #[allow(clippy::missing_safety_doc)] // it's not missing
244 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T;
245}
246
/**
 * A view into some or all of a tensor.
 *
 * A TensorView has a similar relationship to a [`Tensor`] as a `&str` has to a `String`, or an
 * array slice to an array. A TensorView cannot resize its source, and may span only a portion
 * of the source Tensor in each dimension.
 *
 * However a TensorView is generic not only over the type of the data in the Tensor,
 * but also over the way the Tensor is 'sliced' and the two are orthogonal to each other.
 *
 * TensorView closely mirrors the API of Tensor.
 * Methods that create a new tensor do not return a TensorView, they return a Tensor.
 *
 * A TensorView is [Clone] whenever its source `S` is, regardless of whether `T` is.
 */
pub struct TensorView<T, S, const D: usize> {
    source: S,
    _type: PhantomData<T>,
}

// A manual impl rather than #[derive(Clone)]: the derive would demand `T: Clone` even though
// `T` only appears inside `PhantomData`, which is always Clone. Only the source needs cloning.
impl<T, S, const D: usize> Clone for TensorView<T, S, D>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        TensorView {
            source: self.source.clone(),
            _type: PhantomData,
        }
    }
}
265
266/**
267 * TensorView methods which require only read access via a [TensorRef] source.
268 */
269impl<T, S, const D: usize> TensorView<T, S, D>
270where
271 S: TensorRef<T, D>,
272{
273 /**
274 * Creates a TensorView from a source of some type.
275 *
276 * The lifetime of the source determines the lifetime of the TensorView created. If the
277 * TensorView is created from a reference to a Tensor, then the TensorView cannot live
278 * longer than the Tensor referenced.
279 */
280 pub fn from(source: S) -> TensorView<T, S, D> {
281 TensorView {
282 source,
283 _type: PhantomData,
284 }
285 }
286
287 /**
288 * Consumes the tensor view, yielding the source it was created from.
289 */
290 pub fn source(self) -> S {
291 self.source
292 }
293
294 /**
295 * Gives a reference to the tensor view's source.
296 */
297 pub fn source_ref(&self) -> &S {
298 &self.source
299 }
300
301 /**
302 * Gives a mutable reference to the tensor view's source.
303 */
304 pub fn source_ref_mut(&mut self) -> &mut S {
305 &mut self.source
306 }
307
308 /**
309 * The shape of this tensor view. Since Tensors are named Tensors, their shape is not just a
310 * list of length along each dimension, but instead a list of pairs of names and lengths.
311 *
312 * Note that a TensorView may have a shape which is different than the Tensor it is providing
313 * access to the data of. The TensorView might be [masking dimensions](TensorIndex) or
314 * elements from the shape, or exposing [false ones](TensorExpansion).
315 *
316 * See also
317 * - [dimensions]
318 * - [indexing](crate::tensors::indexing)
319 */
320 pub fn shape(&self) -> [(Dimension, usize); D] {
321 self.source.view_shape()
322 }
323
324 /**
325 * Returns the length of the dimension name provided, if one is present in the tensor view.
326 *
327 * See also
328 * - [dimensions]
329 * - [indexing](crate::tensors::indexing)
330 */
331 pub fn length_of(&self, dimension: Dimension) -> Option<usize> {
332 dimensions::length_of(&self.source.view_shape(), dimension)
333 }
334
335 /**
336 * Returns the last index of the dimension name provided, if one is present in the tensor view.
337 *
338 * This is always 1 less than the length, the 'index' in this sense is based on what the
339 * Tensor's shape is, not any implementation index.
340 *
341 * See also
342 * - [dimensions]
343 * - [indexing](crate::tensors::indexing)
344 */
345 pub fn last_index_of(&self, dimension: Dimension) -> Option<usize> {
346 dimensions::last_index_of(&self.source.view_shape(), dimension)
347 }
348
349 /**
350 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
351 * to read values from this tensor view.
352 *
353 * # Panics
354 *
355 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
356 */
357 #[track_caller]
358 pub fn index_by(&self, dimensions: [Dimension; D]) -> TensorAccess<T, &S, D> {
359 TensorAccess::from(&self.source, dimensions)
360 }
361
362 /**
363 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
364 * to read or write values from this tensor view.
365 *
366 * # Panics
367 *
368 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
369 */
370 #[track_caller]
371 pub fn index_by_mut(&mut self, dimensions: [Dimension; D]) -> TensorAccess<T, &mut S, D> {
372 TensorAccess::from(&mut self.source, dimensions)
373 }
374
375 /**
376 * Returns a TensorAccess which can be indexed in the order of the supplied dimensions
377 * to read or write values from this tensor view.
378 *
379 * # Panics
380 *
381 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
382 */
383 #[track_caller]
384 pub fn index_by_owned(self, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
385 TensorAccess::from(self.source, dimensions)
386 }
387
388 /**
389 * Creates a TensorAccess which will index into the dimensions of the source this TensorView
390 * was created with in the same order as they were declared.
391 * See [TensorAccess::from_source_order].
392 */
393 pub fn index(&self) -> TensorAccess<T, &S, D> {
394 TensorAccess::from_source_order(&self.source)
395 }
396
397 /**
398 * Creates a TensorAccess which will index into the dimensions of the source this TensorView
399 * was created with in the same order as they were declared. The TensorAccess mutably borrows
400 * the source, and can therefore mutate it if it implements TensorMut.
401 * See [TensorAccess::from_source_order].
402 */
403 pub fn index_mut(&mut self) -> TensorAccess<T, &mut S, D> {
404 TensorAccess::from_source_order(&mut self.source)
405 }
406
407 /**
408 * Creates a TensorAccess which will index into the dimensions this Tensor was
409 * created with in the same order as they were provided. The TensorAccess takes ownership
410 * of the Tensor, and can therefore mutate it. The TensorAccess takes ownership of
411 * the source, and can therefore mutate it if it implements TensorMut.
412 * See [TensorAccess::from_source_order].
413 */
414 pub fn index_owned(self) -> TensorAccess<T, S, D> {
415 TensorAccess::from_source_order(self.source)
416 }
417
418 /**
419 * Returns an iterator over references to the data in this TensorView.
420 */
421 pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, S, D> {
422 TensorReferenceIterator::from(&self.source)
423 }
424
425 /**
426 * Returns a TensorView with the dimension names of the shape renamed to the provided
427 * dimensions. The data of this tensor and the dimension lengths and order remain unchanged.
428 *
429 * This is a shorthand for constructing the TensorView from this TensorView. See
430 * [`Tensor::rename_view`](Tensor::rename_view).
431 *
432 * # Panics
433 *
434 * - If a dimension name is not unique
435 */
436 #[track_caller]
437 pub fn rename_view(
438 &self,
439 dimensions: [Dimension; D],
440 ) -> TensorView<T, TensorRename<T, &S, D>, D> {
441 TensorView::from(TensorRename::from(&self.source, dimensions))
442 }
443
444 /**
445 * Returns a TensorView with the dimensions changed to the provided shape without moving any
446 * data around. The new Tensor may also have a different number of dimensions.
447 *
448 * This is a shorthand for constructing the TensorView from this TensorView. See
449 * [`Tensor::reshape_view`](Tensor::reshape_view).
450 *
451 * # Panics
452 *
453 * - If the number of provided elements in the new shape does not match the product of the
454 * dimension lengths in the existing tensor's shape.
455 * - If a dimension name is not unique
456 */
457 pub fn reshape<const D2: usize>(
458 &self,
459 shape: [(Dimension, usize); D2],
460 ) -> TensorView<T, TensorReshape<T, &S, D, D2>, D2> {
461 TensorView::from(TensorReshape::from(&self.source, shape))
462 }
463
464 /**
465 * Returns a TensorView with the dimensions changed to the provided shape without moving any
466 * data around. The new Tensor may also have a different number of dimensions. The
467 * TensorReshape mutably borrows the source, and can therefore mutate it if it implements
468 * TensorMut.
469 *
470 * This is a shorthand for constructing the TensorView from this TensorView. See
471 * [`Tensor::reshape_view`](Tensor::reshape_view).
472 *
473 * # Panics
474 *
475 * - If the number of provided elements in the new shape does not match the product of the
476 * dimension lengths in the existing tensor's shape.
477 * - If a dimension name is not unique
478 */
479 pub fn reshape_mut<const D2: usize>(
480 &mut self,
481 shape: [(Dimension, usize); D2],
482 ) -> TensorView<T, TensorReshape<T, &mut S, D, D2>, D2> {
483 TensorView::from(TensorReshape::from(&mut self.source, shape))
484 }
485
486 /**
487 * Returns a TensorView with the dimensions changed to the provided shape without moving any
488 * data around. The new Tensor may also have a different number of dimensions. The
489 * TensorReshape takes ownership of the source, and can therefore mutate it if it implements
490 * TensorMut.
491 *
492 * This is a shorthand for constructing the TensorView from this TensorView. See
493 * [`Tensor::reshape_view`](Tensor::reshape_view).
494 *
495 * # Panics
496 *
497 * - If the number of provided elements in the new shape does not match the product of the
498 * dimension lengths in the existing tensor's shape.
499 * - If a dimension name is not unique
500 */
501 pub fn reshape_owned<const D2: usize>(
502 self,
503 shape: [(Dimension, usize); D2],
504 ) -> TensorView<T, TensorReshape<T, S, D, D2>, D2> {
505 TensorView::from(TensorReshape::from(self.source, shape))
506 }
507
508 /**
509 * Given the dimension name, returns a view of this tensor reshaped to one dimension
510 * with a length equal to the number of elements in this tensor.
511 */
512 pub fn flatten(&self, dimension: Dimension) -> TensorView<T, TensorReshape<T, &S, D, 1>, 1> {
513 self.reshape([(dimension, dimensions::elements(&self.shape()))])
514 }
515
516 /**
517 * Given the dimension name, returns a view of this tensor reshaped to one dimension
518 * with a length equal to the number of elements in this tensor.
519 */
520 pub fn flatten_mut(
521 &mut self,
522 dimension: Dimension,
523 ) -> TensorView<T, TensorReshape<T, &mut S, D, 1>, 1> {
524 self.reshape_mut([(dimension, dimensions::elements(&self.shape()))])
525 }
526
527 /**
528 * Given the dimension name, returns a view of this tensor reshaped to one dimension
529 * with a length equal to the number of elements in this tensor.
530 *
531 * If you intend to query the tensor a lot after creating the view, consider
532 * using [flatten_into_tensor](TensorView::flatten_into_tensor) instead as it will have
533 * less overhead to index after creation.
534 */
535 pub fn flatten_owned(
536 self,
537 dimension: Dimension,
538 ) -> TensorView<T, TensorReshape<T, S, D, 1>, 1> {
539 let length = dimensions::elements(&self.shape());
540 self.reshape_owned([(dimension, length)])
541 }
542
543 /**
544 * Given the dimension name, returns a new tensor reshaped to one dimension
545 * with a length equal to the number of elements in this tensor and all elements
546 * copied into the new tensor.
547 */
548 pub fn flatten_into_tensor(self, dimension: Dimension) -> Tensor<T, 1>
549 where
550 T: Clone,
551 {
552 // TODO: Want a specialisation here to use Tensor::reshape_owned when we
553 // know that our source type is Tensor
554 let length = dimensions::elements(&self.shape());
555 Tensor::from([(dimension, length)], self.iter().collect())
556 }
557
558 /**
559 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
560 * range from view. Error cases are documented on [TensorRange].
561 *
562 * This is a shorthand for constructing the TensorView from this TensorView. See
563 * [`Tensor::range`](Tensor::range).
564 */
565 pub fn range<R, const P: usize>(
566 &self,
567 ranges: [(Dimension, R); P],
568 ) -> Result<TensorView<T, TensorRange<T, &S, D>, D>, IndexRangeValidationError<D, P>>
569 where
570 R: Into<IndexRange>,
571 {
572 TensorRange::from(&self.source, ranges).map(|range| TensorView::from(range))
573 }
574
575 /**
576 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
577 * range from view. Error cases are documented on [TensorRange]. The TensorRange
578 * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
579 *
580 * This is a shorthand for constructing the TensorView from this TensorView. See
581 * [`Tensor::range`](Tensor::range).
582 */
583 pub fn range_mut<R, const P: usize>(
584 &mut self,
585 ranges: [(Dimension, R); P],
586 ) -> Result<TensorView<T, TensorRange<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
587 where
588 R: Into<IndexRange>,
589 {
590 TensorRange::from(&mut self.source, ranges).map(|range| TensorView::from(range))
591 }
592
593 /**
594 * Returns a TensorView with a range taken in P dimensions, hiding the values **outside** the
595 * range from view. Error cases are documented on [TensorRange]. The TensorRange
596 * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
597 *
598 * This is a shorthand for constructing the TensorView from this TensorView. See
599 * [`Tensor::range`](Tensor::range).
600 */
601 pub fn range_owned<R, const P: usize>(
602 self,
603 ranges: [(Dimension, R); P],
604 ) -> Result<TensorView<T, TensorRange<T, S, D>, D>, IndexRangeValidationError<D, P>>
605 where
606 R: Into<IndexRange>,
607 {
608 TensorRange::from(self.source, ranges).map(|range| TensorView::from(range))
609 }
610
611 /**
612 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
613 * range from view. Error cases are documented on [TensorMask].
614 *
615 * This is a shorthand for constructing the TensorView from this TensorView. See
616 * [`Tensor::mask`](Tensor::mask).
617 */
618 pub fn mask<R, const P: usize>(
619 &self,
620 masks: [(Dimension, R); P],
621 ) -> Result<TensorView<T, TensorMask<T, &S, D>, D>, IndexRangeValidationError<D, P>>
622 where
623 R: Into<IndexRange>,
624 {
625 TensorMask::from(&self.source, masks).map(|mask| TensorView::from(mask))
626 }
627
628 /**
629 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
630 * range from view. Error cases are documented on [TensorMask]. The TensorMask
631 * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
632 *
633 * This is a shorthand for constructing the TensorView from this TensorView. See
634 * [`Tensor::mask`](Tensor::mask).
635 */
636 pub fn mask_mut<R, const P: usize>(
637 &mut self,
638 masks: [(Dimension, R); P],
639 ) -> Result<TensorView<T, TensorMask<T, &mut S, D>, D>, IndexRangeValidationError<D, P>>
640 where
641 R: Into<IndexRange>,
642 {
643 TensorMask::from(&mut self.source, masks).map(|mask| TensorView::from(mask))
644 }
645
646 /**
647 * Returns a TensorView with a mask taken in P dimensions, hiding the values **inside** the
648 * range from view. Error cases are documented on [TensorMask]. The TensorMask
649 * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
650 *
651 * This is a shorthand for constructing the TensorView from this TensorView. See
652 * [`Tensor::mask`](Tensor::mask).
653 */
654 pub fn mask_owned<R, const P: usize>(
655 self,
656 masks: [(Dimension, R); P],
657 ) -> Result<TensorView<T, TensorMask<T, S, D>, D>, IndexRangeValidationError<D, P>>
658 where
659 R: Into<IndexRange>,
660 {
661 TensorMask::from(self.source, masks).map(|mask| TensorView::from(mask))
662 }
663
664 /**
665 * Returns a TensorView with a mask taken in the provided dimension, hiding
666 * all but the start_and_end number of values at the start and end of the
667 * dimension from view.
668 *
669 * This is a shorthand for constructing the TensorView from this TensorView. See
670 * [`Tensor::start_and_end_of`](Tensor::start_and_end_of).
671 *
672 * # Panics
673 *
674 * - If the start_and_end value is 0 - this is not a valid mask as it would
675 * hide all elements
676 * - If the dimension is not in the tensor's shape.
677 */
678 #[track_caller]
679 pub fn start_and_end_of(
680 &self,
681 dimension: Dimension,
682 start_and_end: usize,
683 ) -> TensorView<T, TensorMask<T, &S, D>, D> {
684 TensorMask::panicking_start_and_end_of(&self.source, dimension, start_and_end)
685 }
686
687 /**
688 * Returns a TensorView with a mask taken in the provided dimension, hiding
689 * all but the start_and_end number of values at the start and end of the
690 * dimension from view. The TensorMask mutably borrows the source, and can
691 * therefore mutate it if it implements TensorMut.
692 *
693 * This is a shorthand for constructing the TensorView from this TensorView. See
694 * [`Tensor::start_and_end_of_mut`](Tensor::start_and_end_of_mut).
695 *
696 * # Panics
697 *
698 * - If the start_and_end value is 0 - this is not a valid mask as it would
699 * hide all elements
700 * - If the dimension is not in the tensor's shape.
701 */
702 #[track_caller]
703 pub fn start_and_end_of_mut(
704 &mut self,
705 dimension: Dimension,
706 start_and_end: usize,
707 ) -> TensorView<T, TensorMask<T, &mut S, D>, D> {
708 TensorMask::panicking_start_and_end_of(&mut self.source, dimension, start_and_end)
709 }
710
711 /**
712 * Returns a TensorView with a mask taken in the provided dimension, hiding
713 * all but the start_and_end number of values at the start and end of the
714 * dimension from view. The TensorMask takes ownership of the source, and
715 * can therefore mutate it if it implements TensorMut.
716 *
717 * This is a shorthand for constructing the TensorView from this TensorView. See
718 * [`Tensor::start_and_end_of_owned`](Tensor::start_and_end_of_owned).
719 *
720 * # Panics
721 *
722 * - If the start_and_end value is 0 - this is not a valid mask as it would
723 * hide all elements
724 * - If the dimension is not in the tensor's shape.
725 */
726 #[track_caller]
727 pub fn start_and_end_of_owned(
728 self,
729 dimension: Dimension,
730 start_and_end: usize,
731 ) -> TensorView<T, TensorMask<T, S, D>, D> {
732 TensorMask::panicking_start_and_end_of(self.source, dimension, start_and_end)
733 }
734
735 /**
736 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
737 * order. The data of this tensor and the dimension lengths remain unchanged.
738 *
739 * This is a shorthand for constructing the TensorView from this TensorView. See
740 * [`Tensor::reverse`](Tensor::reverse).
741 *
742 * # Panics
743 *
744 * - If a dimension name is not in the tensor's shape or is repeated.
745 */
746 #[track_caller]
747 pub fn reverse(&self, dimensions: &[Dimension]) -> TensorView<T, TensorReverse<T, &S, D>, D> {
748 TensorView::from(TensorReverse::from(&self.source, dimensions))
749 }
750
751 /**
752 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
753 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
754 * mutably borrows the source, and can therefore mutate it if it implements TensorMut.
755 *
756 * This is a shorthand for constructing the TensorView from this TensorView. See
757 * [`Tensor::reverse`](Tensor::reverse).
758 *
759 * # Panics
760 *
761 * - If a dimension name is not in the tensor's shape or is repeated.
762 */
763 #[track_caller]
764 pub fn reverse_mut(
765 &mut self,
766 dimensions: &[Dimension],
767 ) -> TensorView<T, TensorReverse<T, &mut S, D>, D> {
768 TensorView::from(TensorReverse::from(&mut self.source, dimensions))
769 }
770
771 /**
772 * Returns a TensorView with the dimension names provided of the shape reversed in iteration
773 * order. The data of this tensor and the dimension lengths remain unchanged. The TensorReverse
774 * takes ownership of the source, and can therefore mutate it if it implements TensorMut.
775 *
776 * This is a shorthand for constructing the TensorView from this TensorView. See
777 * [`Tensor::reverse`](Tensor::reverse).
778 *
779 * # Panics
780 *
781 * - If a dimension name is not in the tensor's shape or is repeated.
782 */
783 #[track_caller]
784 pub fn reverse_owned(
785 self,
786 dimensions: &[Dimension],
787 ) -> TensorView<T, TensorReverse<T, S, D>, D> {
788 TensorView::from(TensorReverse::from(self.source, dimensions))
789 }
790
791 /**
792 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
793 * mapped by a function. The value pairs are not copied for you, if you're using `Copy` types
794 * or need to clone the values anyway, you can use
795 * [`TensorView::elementwise`](TensorView::elementwise) instead.
796 *
797 * # Generics
798 *
799 * This method can be called with any right hand side that can be converted to a TensorView,
800 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
801 *
802 * # Panics
803 *
804 * If the two tensors have different shapes.
805 */
806 #[track_caller]
807 pub fn elementwise_reference<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
808 where
809 I: Into<TensorView<T, S2, D>>,
810 S2: TensorRef<T, D>,
811 M: Fn(&T, &T) -> T,
812 {
813 self.elementwise_reference_less_generic(rhs.into(), mapping_function)
814 }
815
816 /**
817 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
818 * mapped by a function. The mapping function also receives each index corresponding to the
819 * value pairs. The value pairs are not copied for you, if you're using `Copy` types
820 * or need to clone the values anyway, you can use
821 * [`TensorView::elementwise_with_index`](TensorView::elementwise_with_index) instead.
822 *
823 * # Generics
824 *
825 * This method can be called with any right hand side that can be converted to a TensorView,
826 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
827 *
828 * # Panics
829 *
830 * If the two tensors have different shapes.
831 */
832 #[track_caller]
833 pub fn elementwise_reference_with_index<S2, I, M>(
834 &self,
835 rhs: I,
836 mapping_function: M,
837 ) -> Tensor<T, D>
838 where
839 I: Into<TensorView<T, S2, D>>,
840 S2: TensorRef<T, D>,
841 M: Fn([usize; D], &T, &T) -> T,
842 {
843 self.elementwise_reference_less_generic_with_index(rhs.into(), mapping_function)
844 }
845
846 #[track_caller]
847 fn elementwise_reference_less_generic<S2, M>(
848 &self,
849 rhs: TensorView<T, S2, D>,
850 mapping_function: M,
851 ) -> Tensor<T, D>
852 where
853 S2: TensorRef<T, D>,
854 M: Fn(&T, &T) -> T,
855 {
856 let left_shape = self.shape();
857 let right_shape = rhs.shape();
858 if left_shape != right_shape {
859 panic!(
860 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
861 left_shape, right_shape
862 );
863 }
864 let mapped = self
865 .iter_reference()
866 .zip(rhs.iter_reference())
867 .map(|(x, y)| mapping_function(x, y))
868 .collect();
869 Tensor::from(left_shape, mapped)
870 }
871
872 #[track_caller]
873 fn elementwise_reference_less_generic_with_index<S2, M>(
874 &self,
875 rhs: TensorView<T, S2, D>,
876 mapping_function: M,
877 ) -> Tensor<T, D>
878 where
879 S2: TensorRef<T, D>,
880 M: Fn([usize; D], &T, &T) -> T,
881 {
882 let left_shape = self.shape();
883 let right_shape = rhs.shape();
884 if left_shape != right_shape {
885 panic!(
886 "Dimensions of left and right tensors are not the same: (left: {:?}, right: {:?})",
887 left_shape, right_shape
888 );
889 }
890 // we just checked both shapes were the same, so we don't need to propagate indexes
891 // for both tensors because they'll be identical
892 let mapped = self
893 .iter_reference()
894 .with_index()
895 .zip(rhs.iter_reference())
896 .map(|((i, x), y)| mapping_function(i, x, y))
897 .collect();
898 Tensor::from(left_shape, mapped)
899 }
900
901 /**
902 * Returns a TensorView which makes the order of the data in this tensor appear to be in
903 * a different order. The order of the dimension names is unchanged, although their lengths
904 * may swap.
905 *
906 * This is a shorthand for constructing the TensorView from this TensorView.
907 *
908 * See also: [transpose](TensorView::transpose), [TensorTranspose]
909 *
910 * # Panics
911 *
912 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
913 * order need not match.
914 */
915 pub fn transpose_view(
916 &self,
917 dimensions: [Dimension; D],
918 ) -> TensorView<T, TensorTranspose<T, &S, D>, D> {
919 TensorView::from(TensorTranspose::from(&self.source, dimensions))
920 }
921
922 /// Unverified constructor for interal use when we know the dimensions/data/strides are
923 /// the same as the existing instance and don't need reverification
924 pub(crate) fn new_with_same_shape(&self, data: Vec<T>) -> TensorView<T, Tensor<T, D>, D> {
925 let shape = self.shape();
926 let strides = crate::tensors::compute_strides(&shape);
927 TensorView::from(Tensor {
928 data,
929 shape,
930 strides,
931 })
932 }
933}
934
935impl<T, S, const D: usize> TensorView<T, S, D>
936where
937 S: TensorMut<T, D>,
938{
939 /**
940 * Returns an iterator over mutable references to the data in this TensorView.
941 */
942 pub fn iter_reference_mut(&mut self) -> TensorReferenceMutIterator<'_, T, S, D> {
943 TensorReferenceMutIterator::from(&mut self.source)
944 }
945
946 /**
947 * Creates an iterator over the values in this TensorView.
948 */
949 pub fn iter_owned(self) -> TensorOwnedIterator<T, S, D>
950 where
951 T: Default,
952 {
953 TensorOwnedIterator::from(self.source())
954 }
955}
956
957/**
958 * TensorView methods which require only read access via a [TensorRef] source
959 * and a clonable type.
960 */
961impl<T, S, const D: usize> TensorView<T, S, D>
962where
963 T: Clone,
964 S: TensorRef<T, D>,
965{
966 /**
967 * Gets a copy of the first value in this tensor.
968 * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
969 * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
970 */
971 pub fn first(&self) -> T {
972 self.iter()
973 .next()
974 .expect("Tensors always have at least 1 element")
975 }
976
977 /**
978 * Returns a new Tensor which has the same data as this tensor, but with the order of data
979 * changed. The order of the dimension names is unchanged, although their lengths may swap.
980 *
981 * For example, with a `[("x", x), ("y", y)]` tensor you could call
982 * `transpose(["y", "x"])` which would return a new tensor with a shape of
983 * `[("x", y), ("y", x)]` where every (x,y) of its data corresponds to (y,x) in the original.
984 *
985 * This method need not shift *all* the dimensions though, you could also swap the width
986 * and height of images in a tensor with a shape of
987 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `transpose(["batch", "w", "h", "c"])`
988 * which would return a new tensor where all the images have been swapped over the diagonal.
989 *
990 * See also: [TensorAccess], [reorder](TensorView::reorder)
991 *
992 * # Panics
993 *
994 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
995 * order need not match (and if the order does match, this function is just an expensive
996 * clone).
997 */
998 #[track_caller]
999 pub fn transpose(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1000 let shape = self.shape();
1001 let mut reordered = self.reorder(dimensions);
1002 // Transposition is essentially reordering, but we retain the dimension name ordering
1003 // of the original order, this means we may swap dimension lengths, but the dimensions
1004 // will not change order.
1005 #[allow(clippy::needless_range_loop)]
1006 for d in 0..D {
1007 reordered.shape[d].0 = shape[d].0;
1008 }
1009 reordered
1010 }
1011
1012 /**
1013 * Returns a new Tensor which has the same data as this tensor, but with the order of the
1014 * dimensions and corresponding order of data changed.
1015 *
1016 * For example, with a `[("x", x), ("y", y)]` tensor you could call
1017 * `reorder(["y", "x"])` which would return a new tensor with a shape of
1018 * `[("y", y), ("x", x)]` where every (y,x) of its data corresponds to (x,y) in the original.
1019 *
1020 * This method need not shift *all* the dimensions though, you could also swap the width
1021 * and height of images in a tensor with a shape of
1022 * `[("batch", b), ("h", h), ("w", w), ("c", c)]` via `reorder(["batch", "w", "h", "c"])`
1023 * which would return a new tensor where every (b,w,h,c) of its data corresponds to (b,h,w,c)
1024 * in the original.
1025 *
1026 * See also: [TensorAccess], [transpose](TensorView::transpose)
1027 *
1028 * # Panics
1029 *
1030 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1031 * order need not match (and if the order does match, this function is just an expensive
1032 * clone).
1033 */
1034 #[track_caller]
1035 pub fn reorder(&self, dimensions: [Dimension; D]) -> Tensor<T, D> {
1036 let reorderd = match TensorAccess::try_from(&self.source, dimensions) {
1037 Ok(reordered) => reordered,
1038 Err(_error) => panic!(
1039 "Dimension names provided {:?} must be the same set of dimension names in the tensor: {:?}",
1040 dimensions,
1041 self.shape(),
1042 ),
1043 };
1044 let reorderd_shape = reorderd.shape();
1045 Tensor::from(reorderd_shape, reorderd.iter().collect())
1046 }
1047
1048 /**
1049 * Creates and returns a new tensor with all values from the original with the
1050 * function applied to each. This can be used to change the type of the tensor
1051 * such as creating a mask:
1052 * ```
1053 * use easy_ml::tensors::Tensor;
1054 * use easy_ml::tensors::views::TensorView;
1055 * let x = TensorView::from(Tensor::from([("a", 2), ("b", 2)], vec![
1056 * 0.0, 1.2,
1057 * 5.8, 6.9
1058 * ]));
1059 * let y = x.map(|element| element > 2.0);
1060 * let result = Tensor::from([("a", 2), ("b", 2)], vec![
1061 * false, false,
1062 * true, true
1063 * ]);
1064 * assert_eq!(&y, &result);
1065 * ```
1066 */
1067 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
1068 let mapped = self.iter().map(mapping_function).collect();
1069 Tensor::from(self.shape(), mapped)
1070 }
1071
1072 /**
1073 * Creates and returns a new tensor with all values from the original and
1074 * the index of each value mapped by a function.
1075 */
1076 pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
1077 let mapped = self
1078 .iter()
1079 .with_index()
1080 .map(|(i, x)| mapping_function(i, x))
1081 .collect();
1082 Tensor::from(self.shape(), mapped)
1083 }
1084
1085 /**
1086 * Returns an iterator over copies of the data in this TensorView.
1087 */
1088 pub fn iter(&self) -> TensorIterator<'_, T, S, D> {
1089 TensorIterator::from(&self.source)
1090 }
1091
1092 /**
1093 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1094 * mapped by a function.
1095 *
1096 * ```
1097 * use easy_ml::tensors::Tensor;
1098 * use easy_ml::tensors::views::TensorView;
1099 * let lhs = TensorView::from(Tensor::from([("a", 4)], vec![1, 2, 3, 4]));
1100 * let rhs = TensorView::from(Tensor::from([("a", 4)], vec![0, 1, 2, 3]));
1101 * let multiplied = lhs.elementwise(&rhs, |l, r| l * r);
1102 * assert_eq!(
1103 * multiplied,
1104 * Tensor::from([("a", 4)], vec![0, 2, 6, 12])
1105 * );
1106 * ```
1107 *
1108 * # Generics
1109 *
1110 * This method can be called with any right hand side that can be converted to a TensorView,
1111 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1112 *
1113 * # Panics
1114 *
1115 * If the two tensors have different shapes.
1116 */
1117 #[track_caller]
1118 pub fn elementwise<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1119 where
1120 I: Into<TensorView<T, S2, D>>,
1121 S2: TensorRef<T, D>,
1122 M: Fn(T, T) -> T,
1123 {
1124 self.elementwise_reference_less_generic(rhs.into(), |lhs, rhs| {
1125 mapping_function(lhs.clone(), rhs.clone())
1126 })
1127 }
1128
1129 /**
1130 * Creates and returns a new tensor with all value pairs of two tensors with the same shape
1131 * mapped by a function. The mapping function also receives each index corresponding to the
1132 * value pairs.
1133 *
1134 * # Generics
1135 *
1136 * This method can be called with any right hand side that can be converted to a TensorView,
1137 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1138 *
1139 * # Panics
1140 *
1141 * If the two tensors have different shapes.
1142 */
1143 #[track_caller]
1144 pub fn elementwise_with_index<S2, I, M>(&self, rhs: I, mapping_function: M) -> Tensor<T, D>
1145 where
1146 I: Into<TensorView<T, S2, D>>,
1147 S2: TensorRef<T, D>,
1148 M: Fn([usize; D], T, T) -> T,
1149 {
1150 self.elementwise_reference_less_generic_with_index(rhs.into(), |i, lhs, rhs| {
1151 mapping_function(i, lhs.clone(), rhs.clone())
1152 })
1153 }
1154}
1155
1156/**
1157 * TensorView methods which require only read access via a scalar [TensorRef] source
1158 * and a clonable type.
1159 */
1160impl<T, S> TensorView<T, S, 0>
1161where
1162 T: Clone,
1163 S: TensorRef<T, 0>,
1164{
1165 /**
1166 * Returns a copy of the sole element in the 0 dimensional tensor.
1167 */
1168 pub fn scalar(&self) -> T {
1169 self.source.get_reference([]).unwrap().clone()
1170 }
1171}
1172
1173/**
1174 * TensorView methods which require mutable access via a [TensorMut] source and a [Default].
1175 */
1176impl<T, S> TensorView<T, S, 0>
1177where
1178 T: Default,
1179 S: TensorMut<T, 0>,
1180{
1181 /**
1182 * Returns the sole element in the 0 dimensional tensor.
1183 */
1184 pub fn into_scalar(self) -> T {
1185 TensorOwnedIterator::from(self.source).next().unwrap()
1186 }
1187}
1188
1189/**
1190 * TensorView methods which require mutable access via a [TensorMut] source.
1191 */
1192impl<T, S, const D: usize> TensorView<T, S, D>
1193where
1194 T: Clone,
1195 S: TensorMut<T, D>,
1196{
1197 /**
1198 * Applies a function to all values in the tensor view, modifying
1199 * the tensor in place.
1200 */
1201 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1202 self.iter_reference_mut()
1203 .for_each(|x| *x = mapping_function(x.clone()));
1204 }
1205
1206 /**
1207 * Applies a function to all values and each value's index in the tensor view, modifying
1208 * the tensor view in place.
1209 */
1210 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
1211 self.iter_reference_mut()
1212 .with_index()
1213 .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
1214 }
1215}
1216
1217impl<T, S> TensorView<T, S, 1>
1218where
1219 T: Numeric,
1220 for<'a> &'a T: NumericRef<T>,
1221 S: TensorRef<T, 1>,
1222{
1223 /**
1224 * Computes the scalar product of two equal length vectors. For two vectors `[a,b,c]` and
1225 * `[d,e,f]`, returns `a*d + b*e + c*f`. This is also known as the dot product.
1226 *
1227 * ```
1228 * use easy_ml::tensors::Tensor;
1229 * use easy_ml::tensors::views::TensorView;
1230 * let tensor_view = TensorView::from(Tensor::from([("sequence", 5)], vec![3, 4, 5, 6, 7]));
1231 * assert_eq!(tensor_view.scalar_product(&tensor_view), 3*3 + 4*4 + 5*5 + 6*6 + 7*7);
1232 * ```
1233 *
1234 * # Generics
1235 *
1236 * This method can be called with any right hand side that can be converted to a TensorView,
1237 * which includes `Tensor`, `&Tensor`, `&mut Tensor` as well as references to a `TensorView`.
1238 *
1239 * # Panics
1240 *
1241 * If the two vectors are not of equal length or their dimension names do not match.
1242 */
1243 // Would like this impl block to be in operations.rs too but then it would show first in the
1244 // TensorView docs which isn't ideal
1245 pub fn scalar_product<S2, I>(&self, rhs: I) -> T
1246 where
1247 I: Into<TensorView<T, S2, 1>>,
1248 S2: TensorRef<T, 1>,
1249 {
1250 self.scalar_product_less_generic(rhs.into())
1251 }
1252}
1253
1254impl<T, S> TensorView<T, S, 2>
1255where
1256 T: Numeric,
1257 for<'a> &'a T: NumericRef<T>,
1258 S: TensorRef<T, 2>,
1259{
1260 /**
1261 * Returns the determinant of this square matrix, or None if the matrix
1262 * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant_tensor())
1263 */
1264 pub fn determinant(&self) -> Option<T> {
1265 linear_algebra::determinant_tensor::<T, _, _>(self)
1266 }
1267
1268 /**
1269 * Computes the inverse of a matrix provided that it exists. To have an inverse a
1270 * matrix must be square (same number of rows and columns) and it must also have a
1271 * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse_tensor())
1272 */
1273 pub fn inverse(&self) -> Option<Tensor<T, 2>> {
1274 linear_algebra::inverse_tensor::<T, _, _>(self)
1275 }
1276
1277 /**
1278 * Computes the covariance matrix for this feature matrix along the specified feature
1279 * dimension in this matrix. See [`linear_algebra`](crate::linear_algebra::covariance()).
1280 */
1281 pub fn covariance(&self, feature_dimension: Dimension) -> Tensor<T, 2> {
1282 linear_algebra::covariance::<T, _, _>(self, feature_dimension)
1283 }
1284}
1285
1286macro_rules! tensor_view_select_impl {
1287 (impl TensorView $d:literal 1) => {
1288 impl<T, S> TensorView<T, S, $d>
1289 where
1290 S: TensorRef<T, $d>,
1291 {
1292 /**
1293 * Selects the provided dimension name and index pairs in this TensorView, returning a
1294 * TensorView which has fewer dimensions than this TensorView, with the removed dimensions
1295 * always indexed as the provided values.
1296 *
1297 * This is a shorthand for manually constructing the TensorView and
1298 * [TensorIndex]
1299 *
1300 * Note: due to limitations in Rust's const generics support, this method is only
1301 * implemented for `provided_indexes` of length 1 and `D` from 1 to 6. You can fall
1302 * back to manual construction to create `TensorIndex`es with multiple provided
1303 * indexes if you need to reduce dimensionality by more than 1 dimension at a time.
1304 */
1305 #[track_caller]
1306 pub fn select(
1307 &self,
1308 provided_indexes: [(Dimension, usize); 1],
1309 ) -> TensorView<T, TensorIndex<T, &S, $d, 1>, { $d - 1 }> {
1310 TensorView::from(TensorIndex::from(&self.source, provided_indexes))
1311 }
1312
1313 /**
1314 * Selects the provided dimension name and index pairs in this TensorView, returning a
1315 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1316 * always indexed as the provided values. The TensorIndex mutably borrows this
1317 * Tensor, and can therefore mutate it if it implements TensorMut.
1318 *
1319 * See [select](TensorView::select)
1320 */
1321 #[track_caller]
1322 pub fn select_mut(
1323 &mut self,
1324 provided_indexes: [(Dimension, usize); 1],
1325 ) -> TensorView<T, TensorIndex<T, &mut S, $d, 1>, { $d - 1 }> {
1326 TensorView::from(TensorIndex::from(&mut self.source, provided_indexes))
1327 }
1328
1329 /**
1330 * Selects the provided dimension name and index pairs in this TensorView, returning a
1331 * TensorView which has fewer dimensions than this Tensor, with the removed dimensions
1332 * always indexed as the provided values. The TensorIndex takes ownership of this
1333 * Tensor, and can therefore mutate it if it implements TensorMut.
1334 *
1335 * See [select](TensorView::select)
1336 */
1337 #[track_caller]
1338 pub fn select_owned(
1339 self,
1340 provided_indexes: [(Dimension, usize); 1],
1341 ) -> TensorView<T, TensorIndex<T, S, $d, 1>, { $d - 1 }> {
1342 TensorView::from(TensorIndex::from(self.source, provided_indexes))
1343 }
1344 }
1345 };
1346}
1347
1348tensor_view_select_impl!(impl TensorView 6 1);
1349tensor_view_select_impl!(impl TensorView 5 1);
1350tensor_view_select_impl!(impl TensorView 4 1);
1351tensor_view_select_impl!(impl TensorView 3 1);
1352tensor_view_select_impl!(impl TensorView 2 1);
1353tensor_view_select_impl!(impl TensorView 1 1);
1354
1355macro_rules! tensor_view_expand_impl {
1356 (impl Tensor $d:literal 1) => {
1357 impl<T, S> TensorView<T, S, $d>
1358 where
1359 S: TensorRef<T, $d>,
1360 {
1361 /**
1362 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
1363 * a particular position within the shape, returning a TensorView which has more
1364 * dimensions than this TensorView.
1365 *
1366 * This is a shorthand for manually constructing the TensorView and
1367 * [TensorExpansion]
1368 *
1369 * Note: due to limitations in Rust's const generics support, this method is only
1370 * implemented for `extra_dimension_names` of length 1 and `D` from 0 to 5. You can
1371 * fall back to manual construction to create `TensorExpansion`s with multiple provided
1372 * indexes if you need to increase dimensionality by more than 1 dimension at a time.
1373 */
1374 #[track_caller]
1375 pub fn expand(
1376 &self,
1377 extra_dimension_names: [(usize, Dimension); 1],
1378 ) -> TensorView<T, TensorExpansion<T, &S, $d, 1>, { $d + 1 }> {
1379 TensorView::from(TensorExpansion::from(&self.source, extra_dimension_names))
1380 }
1381
1382 /**
1383 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
1384 * a particular position within the shape, returning a TensorView which has more
1385 * dimensions than this Tensor. The TensorIndex mutably borrows this
1386 * Tensor, and can therefore mutate it if it implements TensorMut.
1387 *
1388 * See [expand](Tensor::expand)
1389 */
1390 #[track_caller]
1391 pub fn expand_mut(
1392 &mut self,
1393 extra_dimension_names: [(usize, Dimension); 1],
1394 ) -> TensorView<T, TensorExpansion<T, &mut S, $d, 1>, { $d + 1 }> {
1395 TensorView::from(TensorExpansion::from(
1396 &mut self.source,
1397 extra_dimension_names,
1398 ))
1399 }
1400
1401 /**
1402 * Expands the dimensionality of this tensor by adding dimensions of length 1 at
1403 * a particular position within the shape, returning a TensorView which has more
1404 * dimensions than this Tensor. The TensorIndex takes ownership of this
1405 * Tensor, and can therefore mutate it if it implements TensorMut.
1406 *
1407 * See [expand](Tensor::expand)
1408 */
1409 #[track_caller]
1410 pub fn expand_owned(
1411 self,
1412 extra_dimension_names: [(usize, Dimension); 1],
1413 ) -> TensorView<T, TensorExpansion<T, S, $d, 1>, { $d + 1 }> {
1414 TensorView::from(TensorExpansion::from(self.source, extra_dimension_names))
1415 }
1416 }
1417 };
1418}
1419
1420tensor_view_expand_impl!(impl Tensor 0 1);
1421tensor_view_expand_impl!(impl Tensor 1 1);
1422tensor_view_expand_impl!(impl Tensor 2 1);
1423tensor_view_expand_impl!(impl Tensor 3 1);
1424tensor_view_expand_impl!(impl Tensor 4 1);
1425tensor_view_expand_impl!(impl Tensor 5 1);
1426
1427/**
1428 * Debug implementations for TensorView additionally show the visible data and visible dimensions
1429 * reported by the source as fields. This is in addition to recursive debug content of the actual
1430 * source.
1431 */
1432impl<T, S, const D: usize> std::fmt::Debug for TensorView<T, S, D>
1433where
1434 T: std::fmt::Debug,
1435 S: std::fmt::Debug + TensorRef<T, D>,
1436{
1437 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1438 f.debug_struct("TensorView")
1439 .field("visible", &DebugSourceVisible::from(&self.source))
1440 .field("shape", &self.source.view_shape())
1441 .field("source", &self.source)
1442 .finish()
1443 }
1444}
1445
1446struct DebugSourceVisible<T, S, const D: usize> {
1447 source: S,
1448 _type: PhantomData<T>,
1449}
1450
1451impl<T, S, const D: usize> DebugSourceVisible<T, S, D>
1452where
1453 T: std::fmt::Debug,
1454 S: std::fmt::Debug + TensorRef<T, D>,
1455{
1456 fn from(source: S) -> DebugSourceVisible<T, S, D> {
1457 DebugSourceVisible {
1458 source,
1459 _type: PhantomData,
1460 }
1461 }
1462}
1463
1464impl<T, S, const D: usize> std::fmt::Debug for DebugSourceVisible<T, S, D>
1465where
1466 T: std::fmt::Debug,
1467 S: std::fmt::Debug + TensorRef<T, D>,
1468{
1469 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1470 f.debug_list()
1471 .entries(TensorReferenceIterator::from(&self.source))
1472 .finish()
1473 }
1474}
1475
#[test]
fn test_debug() {
    let tensor = Tensor::from([("rows", 3), ("columns", 4)], (0..12).collect());
    let tensor_view = TensorView::from(&tensor);
    let output = format!("{:?}\n{:?}", tensor, tensor_view);
    let expected = r#"Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] }
TensorView { visible: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], source: Tensor { data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape: [("rows", 3), ("columns", 4)], strides: [4, 1] } }"#;
    assert_eq!(output, expected)
}
1487
#[test]
fn test_debug_clipped() {
    let x = Tensor::from([("rows", 2), ("columns", 3)], (0..6).collect());
    // Only columns 1 and 2 are visible through the range, so `visible` lists 1, 2, 4, 5
    // while `source` still shows all of the underlying data.
    let view = TensorView::from(&x)
        .range_owned([("columns", IndexRange::new(1, 2))])
        .unwrap();
    let debugged = format!("{:#?}\n{:#?}", x, view);
    assert_eq!(
        debugged,
        r#"Tensor {
    data: [
        0,
        1,
        2,
        3,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            3,
        ),
    ],
    strides: [
        3,
        1,
    ],
}
TensorView {
    visible: [
        1,
        2,
        4,
        5,
    ],
    shape: [
        (
            "rows",
            2,
        ),
        (
            "columns",
            2,
        ),
    ],
    source: TensorRange {
        source: Tensor {
            data: [
                0,
                1,
                2,
                3,
                4,
                5,
            ],
            shape: [
                (
                    "rows",
                    2,
                ),
                (
                    "columns",
                    3,
                ),
            ],
            strides: [
                3,
                1,
            ],
        },
        range: [
            IndexRange {
                start: 0,
                length: 2,
            },
            IndexRange {
                start: 1,
                length: 2,
            },
        ],
        _type: PhantomData<i32>,
    },
}"#
    )
}
1579
1580/**
1581 * Any tensor view of a Displayable type implements Display
1582 *
1583 * You can control the precision of the formatting using format arguments, i.e.
1584 * `format!("{:.3}", tensor)`
1585 */
1586impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorView<T, S, D>
1587where
1588 T: std::fmt::Display,
1589 S: TensorRef<T, D>,
1590{
1591 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1592 crate::tensors::display::format_view(&self.source, f)
1593 }
1594}
1595
1596/**
1597 * A Tensor can be converted to a TensorView of that Tensor.
1598 */
1599impl<T, const D: usize> From<Tensor<T, D>> for TensorView<T, Tensor<T, D>, D> {
1600 fn from(tensor: Tensor<T, D>) -> TensorView<T, Tensor<T, D>, D> {
1601 TensorView::from(tensor)
1602 }
1603}
1604
1605/**
1606 * A reference to a Tensor can be converted to a TensorView of that referenced Tensor.
1607 */
1608impl<'a, T, const D: usize> From<&'a Tensor<T, D>> for TensorView<T, &'a Tensor<T, D>, D> {
1609 fn from(tensor: &Tensor<T, D>) -> TensorView<T, &Tensor<T, D>, D> {
1610 TensorView::from(tensor)
1611 }
1612}
1613
1614/**
1615 * A mutable reference to a Tensor can be converted to a TensorView of that mutably referenced
1616 * Tensor.
1617 */
1618impl<'a, T, const D: usize> From<&'a mut Tensor<T, D>> for TensorView<T, &'a mut Tensor<T, D>, D> {
1619 fn from(tensor: &mut Tensor<T, D>) -> TensorView<T, &mut Tensor<T, D>, D> {
1620 TensorView::from(tensor)
1621 }
1622}
1623
1624/**
1625 * A reference to a TensorView can be converted to an owned TensorView with a reference to the
1626 * source type of that first TensorView.
1627 */
1628impl<'a, T, S, const D: usize> From<&'a TensorView<T, S, D>> for TensorView<T, &'a S, D>
1629where
1630 S: TensorRef<T, D>,
1631{
1632 fn from(tensor_view: &TensorView<T, S, D>) -> TensorView<T, &S, D> {
1633 TensorView::from(tensor_view.source_ref())
1634 }
1635}
1636
1637/**
1638 * A mutable reference to a TensorView can be converted to an owned TensorView with a mutable
1639 * reference to the source type of that first TensorView.
1640 */
1641impl<'a, T, S, const D: usize> From<&'a mut TensorView<T, S, D>> for TensorView<T, &'a mut S, D>
1642where
1643 S: TensorRef<T, D>,
1644{
1645 fn from(tensor_view: &mut TensorView<T, S, D>) -> TensorView<T, &mut S, D> {
1646 TensorView::from(tensor_view.source_ref_mut())
1647 }
1648}