easy_ml/tensors/indexing.rs
1/*!
2 * # Indexing
3 *
4 * Many libraries represent tensors as N dimensional arrays, however there is often some semantic
5 * meaning to each dimension. You may have a batch of 2000 images, each 100 pixels wide and high,
6 * with each pixel representing 3 numbers for rgb values. This can be represented as a
7 * 2000 x 100 x 100 x 3 tensor, but a 4 dimensional array does not track the semantic meaning
8 * of each dimension and associated index.
9 *
10 * 6 months later you could come back to the code and forget which order the dimensions were
11 * created in, at best getting the indexes out of bounds and causing a crash in your application,
12 * and at worst silently reading the wrong data without realising. *Was it width then height or
13 * height then width?*...
14 *
15 * Easy ML moves the N dimensional array to an implementation detail, and most of its APIs work
16 * on the names of each dimension in a tensor instead of just the order. Instead of a
17 * 2000 x 100 x 100 x 3 tensor in which the last element is at [1999, 99, 99, 2], Easy ML tracks
18 * the names of the dimensions, so you have a
19 * `[("batch", 2000), ("width", 100), ("height", 100), ("rgb", 3)]` shaped tensor.
20 *
21 * This can't stop you from getting the math wrong, but confusion over which dimension
22 * means what is reduced. Tensors carry around their pairs of dimension name and length
23 * so adding a `[("batch", 2000), ("width", 100), ("height", 100), ("rgb", 3)]` shaped tensor
 * to a `[("batch", 2000), ("height", 100), ("width", 100), ("rgb", 3)]` shaped tensor will fail
 * unless you reorder one first, and you could access an element as
 * `tensor.index_by(["batch", "width", "height", "rgb"]).get([1999, 0, 99, 2])` or
 * `tensor.index_by(["batch", "height", "width", "rgb"]).get([1999, 99, 0, 2])` and read the same data,
28 * because you index into dimensions based on their name, not just the order they are stored in
29 * memory.
30 *
31 * Even with a name for each dimension, at some point you still need to say what order you want
32 * to index each dimension with, and this is where [`TensorAccess`] comes in. It
33 * creates a mapping from the dimension name order you want to access elements with to the order
34 * the dimensions are stored as.
35 */
36
37use crate::differentiation::{Index, Primitive, Record, RecordTensor};
38use crate::numeric::Numeric;
39use crate::tensors::dimensions;
40use crate::tensors::dimensions::DimensionMappings;
41use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
42use crate::tensors::{Dimension, Tensor};
43
44use std::error::Error;
45use std::fmt;
46use std::iter::{ExactSizeIterator, FusedIterator};
47use std::marker::PhantomData;
48
49pub use crate::matrices::iterators::WithIndex;
50
/**
 * Access to the data in a Tensor with a particular order of dimension indexing. The order
 * affects the shape of the TensorAccess as well as the order of indexes you supply to read
 * or write values to the tensor.
 *
 * See the [module level documentation](crate::tensors::indexing) for more information.
 */
#[derive(Clone, Debug)]
pub struct TensorAccess<T, S, const D: usize> {
    // The tensor (or reference to a tensor) that values are read from and written to,
    // indexed in its own view_shape order.
    source: S,
    // Converts indexes given in this TensorAccess's dimension order into the order `source`
    // expects, and maps the source's shape back into the requested order for `shape`.
    dimension_mapping: DimensionMappings<D>,
    // T only appears in the bounds on S, so it is tracked here to make it a parameter of
    // the struct.
    _type: PhantomData<T>,
}
64
65impl<T, S, const D: usize> TensorAccess<T, S, D>
66where
67 S: TensorRef<T, D>,
68{
69 /**
70 * Creates a TensorAccess which can be indexed in the order of the supplied dimensions
71 * to read or write values from this tensor.
72 *
73 * # Panics
74 *
75 * If the set of dimensions supplied do not match the set of dimensions in this tensor's shape.
76 */
77 #[track_caller]
78 pub fn from(source: S, dimensions: [Dimension; D]) -> TensorAccess<T, S, D> {
79 match TensorAccess::try_from(source, dimensions) {
80 Err(error) => panic!("{}", error),
81 Ok(success) => success,
82 }
83 }
84
85 /**
86 * Creates a TensorAccess which can be indexed in the order of the supplied dimensions
87 * to read or write values from this tensor.
88 *
89 * Returns Err if the set of dimensions supplied do not match the set of dimensions in this
90 * tensor's shape.
91 */
92 pub fn try_from(
93 source: S,
94 dimensions: [Dimension; D],
95 ) -> Result<TensorAccess<T, S, D>, InvalidDimensionsError<D>> {
96 Ok(TensorAccess {
97 dimension_mapping: DimensionMappings::new(&source.view_shape(), &dimensions)
98 .ok_or_else(|| InvalidDimensionsError {
99 actual: source.view_shape(),
100 requested: dimensions,
101 })?,
102 source,
103 _type: PhantomData,
104 })
105 }
106
107 /**
108 * Creates a TensorAccess which is indexed in the same order as the dimensions in the view
109 * shape of the tensor it is created from.
110 *
111 * Hence if you create a TensorAccess directly from a Tensor by `from_source_order`
112 * this uses the order the dimensions were laid out in memory with.
113 *
114 * ```
115 * use easy_ml::tensors::Tensor;
116 * use easy_ml::tensors::indexing::TensorAccess;
117 * let tensor = Tensor::from([("x", 2), ("y", 2), ("z", 2)], vec![
118 * 1, 2,
119 * 3, 4,
120 *
121 * 5, 6,
122 * 7, 8
123 * ]);
124 * let xyz = tensor.index_by(["x", "y", "z"]);
125 * let also_xyz = TensorAccess::from_source_order(&tensor);
126 * let also_xyz = tensor.index();
127 * ```
128 */
129 pub fn from_source_order(source: S) -> TensorAccess<T, S, D> {
130 TensorAccess {
131 dimension_mapping: DimensionMappings::no_op_mapping(),
132 source,
133 _type: PhantomData,
134 }
135 }
136
137 /**
138 * Creates a TensorAccess which is indexed in the same order as the linear data layout
139 * dimensions in the tensor it is created from, or None if the source data layout
140 * is not linear.
141 *
142 * Hence if you use `from_memory_order` on a source that was originally big endian like
143 * [Tensor] this uses the order for efficient iteration through each step in memory
144 * when [iterating](TensorIterator).
145 */
146 pub fn from_memory_order(source: S) -> Option<TensorAccess<T, S, D>> {
147 let data_layout = match source.data_layout() {
148 DataLayout::Linear(order) => order,
149 _ => return None,
150 };
151 let shape = source.view_shape();
152 Some(TensorAccess::try_from(source, data_layout).unwrap_or_else(|_| panic!(
153 "Source implementation contained dimensions {:?} in data_layout that were not the same set as in the view_shape {:?} which breaks the contract of TensorRef",
154 data_layout, shape
155 )))
156 }
157
158 /**
159 * The shape this TensorAccess has with the dimensions mapped to the order the TensorAccess
160 * was created with, not necessarily the same order as in the underlying tensor.
161 */
162 pub fn shape(&self) -> [(Dimension, usize); D] {
163 self.dimension_mapping
164 .map_shape_to_requested(&self.source.view_shape())
165 }
166
167 pub fn source(self) -> S {
168 self.source
169 }
170
171 // # Safety
172 //
173 // Giving out a mutable reference to our source could allow it to be changed out from under us
174 // and make our dimmension mapping invalid. However, since the source implements TensorRef
175 // interior mutability is not allowed, so we can give out shared references without breaking
176 // our own integrity.
177 pub fn source_ref(&self) -> &S {
178 &self.source
179 }
180}
181
182/**
183 * An error indicating failure to create a TensorAccess because the requested dimension order
184 * does not match the shape in the source data.
185 */
186#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
187pub struct InvalidDimensionsError<const D: usize> {
188 pub actual: [(Dimension, usize); D],
189 pub requested: [Dimension; D],
190}
191
192impl<const D: usize> Error for InvalidDimensionsError<D> {}
193
194impl<const D: usize> fmt::Display for InvalidDimensionsError<D> {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(
197 f,
198 "Requested dimension order: {:?} does not match the shape in the source: {:?}",
199 &self.actual, &self.requested
200 )
201 }
202}
203
/// InvalidDimensionsError only holds static strings and lengths, so it can be shared
/// between threads.
#[test]
fn test_sync() {
    fn requires_sync<E: Sync>() {}
    requires_sync::<InvalidDimensionsError<3>>();
}

/// InvalidDimensionsError only holds static strings and lengths, so it can be sent
/// between threads.
#[test]
fn test_send() {
    fn requires_send<E: Send>() {}
    requires_send::<InvalidDimensionsError<3>>();
}
215
216impl<T, S, const D: usize> TensorAccess<T, S, D>
217where
218 S: TensorRef<T, D>,
219{
220 /**
221 * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
222 * index if the index is in range. Otherwise returns None.
223 */
224 pub fn try_get_reference(&self, indexes: [usize; D]) -> Option<&T> {
225 self.source
226 .get_reference(self.dimension_mapping.map_dimensions_to_source(&indexes))
227 }
228
229 /**
230 * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
231 * index if the index is in range, panicking if the index is out of range.
232 */
233 // NOTE: Ideally `get_reference` would be used here for consistency, but that opens the
234 // minefield of TensorRef::get_reference and TensorAccess::get_ref being different signatures
235 // but the same name.
236 #[track_caller]
237 pub fn get_ref(&self, indexes: [usize; D]) -> &T {
238 match self.try_get_reference(indexes) {
239 Some(reference) => reference,
240 None => panic!(
241 "Unable to index with {:?}, Tensor dimensions are {:?}.",
242 indexes,
243 self.shape()
244 ),
245 }
246 }
247
248 /**
249 * Using the dimension ordering of the TensorAccess, gets a reference to the value at the
250 * index wihout any bounds checking.
251 *
252 * # Safety
253 *
254 * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
255 * resulting reference is not used. Valid indexes are defined as in [TensorRef]. Note that
256 * the order of the indexes needed here must match with
257 * [`TensorAccess::shape`](TensorAccess::shape) which may not neccessarily be the same
258 * as the `view_shape` of the `TensorRef` implementation this TensorAccess was created from).
259 *
260 * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
261 * [TensorRef]: TensorRef
262 */
263 // NOTE: This aliases with TensorRef::get_reference_unchecked but the TensorRef impl
264 // just calls this and the signatures match anyway, so there are no potential issues.
265 #[allow(clippy::missing_safety_doc)] // it's not missing
266 pub unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
267 unsafe {
268 self.source
269 .get_reference_unchecked(self.dimension_mapping.map_dimensions_to_source(&indexes))
270 }
271 }
272
273 /**
274 * Returns an iterator over references to the data in this TensorAccess, in the order of
275 * the TensorAccess shape.
276 */
277 pub fn iter_reference(&self) -> TensorReferenceIterator<'_, T, TensorAccess<T, S, D>, D> {
278 TensorReferenceIterator::from(self)
279 }
280}
281
282impl<T, S, const D: usize> TensorAccess<T, S, D>
283where
284 S: TensorRef<T, D>,
285 T: Clone,
286{
287 /**
288 * Using the dimension ordering of the TensorAccess, gets a copy of the value at the
289 * index if the index is in range, panicking if the index is out of range.
290 *
291 * For a non panicking API see [`try_get_reference`](TensorAccess::try_get_reference)
292 */
293 #[track_caller]
294 pub fn get(&self, indexes: [usize; D]) -> T {
295 match self.try_get_reference(indexes) {
296 Some(reference) => reference.clone(),
297 None => panic!(
298 "Unable to index with {:?}, Tensor dimensions are {:?}.",
299 indexes,
300 self.shape()
301 ),
302 }
303 }
304
305 /**
306 * Gets a copy of the first value in this tensor.
307 * For 0 dimensional tensors this is the only index `[]`, for 1 dimensional tensors this
308 * is `[0]`, for 2 dimensional tensors `[0,0]`, etcetera.
309 */
310 pub fn first(&self) -> T {
311 self.iter()
312 .next()
313 .expect("Tensors always have at least 1 element")
314 }
315
316 /**
317 * Creates and returns a new tensor with all values from the original with the
318 * function applied to each.
319 *
320 * Note: mapping methods are defined on [Tensor] and
321 * [TensorView](crate::tensors::views::TensorView) directly so you don't need to create a
322 * TensorAccess unless you want to do the mapping with a different dimension order.
323 */
324 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Tensor<U, D> {
325 let mapped = self.iter().map(mapping_function).collect();
326 Tensor::from(self.shape(), mapped)
327 }
328
329 /**
330 * Creates and returns a new tensor with all values from the original and
331 * the index of each value mapped by a function. The indexes passed to the mapping
332 * function always increment the rightmost index, starting at all 0s, using the dimension
333 * order that the TensorAccess is indexed by, not neccessarily the index order the
334 * original source uses.
335 *
336 * Note: mapping methods are defined on [Tensor] and
337 * [TensorView](crate::tensors::views::TensorView) directly so you don't need to create a
338 * TensorAccess unless you want to do the mapping with a different dimension order.
339 */
340 pub fn map_with_index<U>(&self, mapping_function: impl Fn([usize; D], T) -> U) -> Tensor<U, D> {
341 let mapped = self
342 .iter()
343 .with_index()
344 .map(|(i, x)| mapping_function(i, x))
345 .collect();
346 Tensor::from(self.shape(), mapped)
347 }
348
349 /**
350 * Returns an iterator over copies of the data in this TensorAccess, in the order of
351 * the TensorAccess shape.
352 */
353 pub fn iter(&self) -> TensorIterator<'_, T, TensorAccess<T, S, D>, D> {
354 TensorIterator::from(self)
355 }
356}
357
358impl<T, S, const D: usize> TensorAccess<T, S, D>
359where
360 S: TensorMut<T, D>,
361{
362 /**
363 * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
364 * the index if the index is in range. Otherwise returns None.
365 */
366 pub fn try_get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
367 self.source
368 .get_reference_mut(self.dimension_mapping.map_dimensions_to_source(&indexes))
369 }
370
371 /**
372 * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
373 * the index if the index is in range, panicking if the index is out of range.
374 */
375 // NOTE: Ideally `get_reference_mut` would be used here for consistency, but that opens the
376 // minefield of TensorMut::get_reference_mut and TensorAccess::get_ref_mut being different
377 // signatures but the same name.
378 #[track_caller]
379 pub fn get_ref_mut(&mut self, indexes: [usize; D]) -> &mut T {
380 match self.try_get_reference_mut(indexes) {
381 Some(reference) => reference,
382 // can't provide a better error because the borrow checker insists that returning
383 // a reference in the Some branch means our mutable borrow prevents us calling
384 // self.shape() and a bad error is better than cloning self.shape() on every call
385 None => panic!("Unable to index with {:?}", indexes),
386 }
387 }
388
389 /**
390 * Using the dimension ordering of the TensorAccess, gets a mutable reference to the value at
391 * the index wihout any bounds checking.
392 *
393 * # Safety
394 *
395 * Calling this method with an out-of-bounds index is *[undefined behavior]* even if the
396 * resulting reference is not used. Valid indexes are defined as in [TensorRef]. Note that
397 * the order of the indexes needed here must match with
398 * [`TensorAccess::shape`](TensorAccess::shape) which may not neccessarily be the same
399 * as the `view_shape` of the `TensorRef` implementation this TensorAccess was created from).
400 *
401 * [undefined behavior]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
402 * [TensorRef]: TensorRef
403 */
404 // NOTE: This aliases with TensorRef::get_reference_unchecked_mut but the TensorMut impl
405 // just calls this and the signatures match anyway, so there are no potential issues.
406 #[allow(clippy::missing_safety_doc)] // it's not missing
407 pub unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
408 unsafe {
409 self.source.get_reference_unchecked_mut(
410 self.dimension_mapping.map_dimensions_to_source(&indexes),
411 )
412 }
413 }
414
415 /**
416 * Returns an iterator over mutable references to the data in this TensorAccess, in the order
417 * of the TensorAccess shape.
418 */
419 pub fn iter_reference_mut(
420 &mut self,
421 ) -> TensorReferenceMutIterator<'_, T, TensorAccess<T, S, D>, D> {
422 TensorReferenceMutIterator::from(self)
423 }
424}
425
426impl<T, S, const D: usize> TensorAccess<T, S, D>
427where
428 S: TensorMut<T, D>,
429 T: Clone,
430{
431 /**
432 * Applies a function to all values in the tensor, modifying
433 * the tensor in place.
434 */
435 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
436 self.iter_reference_mut()
437 .for_each(|x| *x = mapping_function(x.clone()));
438 }
439
440 /**
441 * Applies a function to all values and each value's index in the tensor, modifying
442 * the tensor in place. The indexes passed to the mapping function always increment
443 * the rightmost index, starting at all 0s, using the dimension order that the
444 * TensorAccess is indexed by, not neccessarily the index order the original source uses.
445 */
446 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn([usize; D], T) -> T) {
447 self.iter_reference_mut()
448 .with_index()
449 .for_each(|(i, x)| *x = mapping_function(i, x.clone()));
450 }
451}
452
453impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &RecordTensor<'a, T, S, D>, D>
454where
455 T: Numeric + Primitive,
456 S: TensorRef<(T, Index), D>,
457{
458 /**
459 * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
460 * as a Record if the index is in range, panicking if the index is out of range.
461 *
462 * If you need to access all the data as records instead of just a specific index you should
463 * probably use one of the iterator APIs instead.
464 *
465 * See also: [iter_as_records](RecordTensor::iter_as_records)
466 *
467 * # Panics
468 *
469 * If the index is out of range.
470 *
471 * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
472 *
473 * ```
474 * use easy_ml::differentiation::RecordTensor;
475 * use easy_ml::differentiation::WengertList;
476 * use easy_ml::tensors::Tensor;
477 *
478 * let list = WengertList::new();
479 * let X = RecordTensor::variables(
480 * &list,
481 * Tensor::from(
482 * [("r", 2), ("c", 3)],
483 * vec![
484 * 3.0, 4.0, 5.0,
485 * 1.0, 4.0, 9.0,
486 * ]
487 * )
488 * );
489 * let x = X.index_by(["c", "r"]).get_as_record([2, 0]);
490 * assert_eq!(x.number, 5.0);
491 * ```
492 */
493 #[track_caller]
494 pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
495 Record::from_existing(self.get(indexes), self.source.history())
496 }
497
498 /**
499 * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
500 * as a Record if the index is in range. Otherwise returns None.
501 *
502 * If you need to access all the data as records instead of just a specific index you should
503 * probably use one of the iterator APIs instead.
504 *
505 * See also: [iter_as_records](RecordTensor::iter_as_records)
506 */
507 pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
508 self.try_get_reference(indexes)
509 .map(|r| Record::from_existing(r.clone(), self.source.history()))
510 }
511}
512
513impl<'a, T, S, const D: usize> TensorAccess<(T, Index), RecordTensor<'a, T, S, D>, D>
514where
515 T: Numeric + Primitive,
516 S: TensorRef<(T, Index), D>,
517{
518 /**
519 * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
520 * as a Record if the index is in range, panicking if the index is out of range.
521 *
522 * If you need to access all the data as records instead of just a specific index you should
523 * probably use one of the iterator APIs instead.
524 *
525 * See also: [iter_as_records](RecordTensor::iter_as_records)
526 *
527 * # Panics
528 *
529 * If the index is out of range.
530 *
531 * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
532 */
533 #[track_caller]
534 pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
535 Record::from_existing(self.get(indexes), self.source.history())
536 }
537
538 /**
539 * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
540 * as a Record if the index is in range. Otherwise returns None.
541 *
542 * If you need to access all the data as records instead of just a specific index you should
543 * probably use one of the iterator APIs instead.
544 *
545 * See also: [iter_as_records](RecordTensor::iter_as_records)
546 */
547 pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
548 self.try_get_reference(indexes)
549 .map(|r| Record::from_existing(r.clone(), self.source.history()))
550 }
551}
552
553impl<'a, T, S, const D: usize> TensorAccess<(T, Index), &mut RecordTensor<'a, T, S, D>, D>
554where
555 T: Numeric + Primitive,
556 S: TensorRef<(T, Index), D>,
557{
558 /**
559 * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
560 * as a Record if the index is in range, panicking if the index is out of range.
561 *
562 * If you need to access all the data as records instead of just a specific index you should
563 * probably use one of the iterator APIs instead.
564 *
565 * See also: [iter_as_records](RecordTensor::iter_as_records)
566 *
567 * # Panics
568 *
569 * If the index is out of range.
570 *
571 * For a non panicking API see [try_get_as_record](TensorAccess::try_get_as_record)
572 */
573 #[track_caller]
574 pub fn get_as_record(&self, indexes: [usize; D]) -> Record<'a, T> {
575 Record::from_existing(self.get(indexes), self.source.history())
576 }
577
578 /**
579 * Using the dimension ordering of the TensorAccess, returns a copy of the data at the index
580 * as a Record if the index is in range. Otherwise returns None.
581 *
582 * If you need to access all the data as records instead of just a specific index you should
583 * probably use one of the iterator APIs instead.
584 *
585 * See also: [iter_as_records](RecordTensor::iter_as_records)
586 */
587 pub fn try_get_as_record(&self, indexes: [usize; D]) -> Option<Record<'a, T>> {
588 self.try_get_reference(indexes)
589 .map(|r| Record::from_existing(r.clone(), self.source.history()))
590 }
591}
592
// # Safety
//
// The type implementing TensorRef inside the TensorAccess must implement it correctly, so by
// delegating to it without changing anything other than the order we index it, we implement
// TensorRef correctly as well.
/**
 * A TensorAccess implements TensorRef, with the dimension order and indexing matching that of the
 * TensorAccess shape.
 */
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorAccess<T, S, D>
where
    S: TensorRef<T, D>,
{
    // Delegates to the inherent `try_get_reference`, which maps `indexes` from the
    // TensorAccess dimension order to the source's order before looking the value up.
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.try_get_reference(indexes)
    }

    // The view shape is the TensorAccess shape: the source's shape reordered into the
    // dimension order this TensorAccess was created with.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        self.shape()
    }

    // Method resolution prefers inherent methods over trait methods, so this calls the
    // inherent `TensorAccess::get_reference_unchecked` (which maps the indexes into the
    // source's order) rather than recursing into this trait method.
    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: our caller upholds the TensorRef contract for `indexes` against our
        // view_shape, which is exactly the precondition of the inherent method.
        unsafe { self.get_reference_unchecked(indexes) }
    }

    // NOTE(review): this match is equivalent to returning `self.source.data_layout()`
    // unchanged; presumably it is written out variant by variant so that adding a new
    // DataLayout variant fails to compile here and forces a review — confirm.
    fn data_layout(&self) -> DataLayout<D> {
        match self.source.data_layout() {
            // We might have reordered the view_shape but we didn't rearrange the memory or change
            // what each dimension name refers to in memory, so the data layout remains as is.
            DataLayout::Linear(order) => DataLayout::Linear(order),
            DataLayout::NonLinear => DataLayout::NonLinear,
            DataLayout::Other => DataLayout::Other,
        }
    }
}
628
// # Safety
//
// The type implementing TensorMut inside the TensorAccess must implement it correctly, so by
// delegating to it without changing anything other than the order we index it, we implement
// TensorMut correctly as well.
/**
 * A TensorAccess implements TensorMut, with the dimension order and indexing matching that of the
 * TensorAccess shape.
 */
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorAccess<T, S, D>
where
    S: TensorMut<T, D>,
{
    // Delegates to the inherent `try_get_reference_mut`, which maps `indexes` from the
    // TensorAccess dimension order to the source's order before looking the value up.
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.try_get_reference_mut(indexes)
    }

    // Method resolution prefers inherent methods over trait methods, so this calls the
    // inherent `TensorAccess::get_reference_unchecked_mut` rather than recursing into this
    // trait method.
    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: our caller upholds the TensorMut contract for `indexes` against our
        // view_shape, which is exactly the precondition of the inherent method.
        unsafe { self.get_reference_unchecked_mut(indexes) }
    }
}
650
651/**
652 * Any tensor access of a Displayable type implements Display
653 *
654 * You can control the precision of the formatting using format arguments, i.e.
655 * `format!("{:.3}", tensor)`
656 */
657impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorAccess<T, S, D>
658where
659 T: std::fmt::Display,
660 S: TensorRef<T, D>,
661{
662 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
663 crate::tensors::display::format_view(&self, f)?;
664 writeln!(f)?;
665 write!(f, "Data Layout = {:?}", self.data_layout())
666 }
667}
668
669/**
670 * An iterator over all indexes in a shape.
671 *
672 * First the all 0 index is iterated, then each iteration increments the rightmost index.
673 * For a shape of `[("a", 2), ("b", 2), ("c", 2)]` this will yield indexes in order of: `[0,0,0]`,
674 * `[0,0,1]`, `[0,1,0]`, `[0,1,1]`, `[1,0,0]`, `[1,0,1]`, `[1,1,0]`, `[1,1,1]`,
675 *
676 * You don't typically need to use this directly, as tensors have iterators that iterate over
677 * them and return values to you (using this under the hood), but `ShapeIterator` can be useful
678 * if you need to hold a mutable reference to a tensor while iterating as `ShapeIterator` does
679 * not borrow the tensor. NB: if you do index into a tensor you're mutably borrowing using
680 * `ShapeIterator` directly, take care to ensure you don't accidentally reshape the tensor and
681 * continue to use indexes from `ShapeIterator` as they would then be invalid.
682 */
683#[derive(Clone, Debug)]
684pub struct ShapeIterator<const D: usize> {
685 shape: [(Dimension, usize); D],
686 indexes: [usize; D],
687 finished: bool,
688}
689
690impl<const D: usize> ShapeIterator<D> {
691 /**
692 * Constructs a ShapeIterator for a shape.
693 *
694 * If the shape has any dimensions with a length of zero, the iterator will immediately
695 * return None on [`next()`](Iterator::next).
696 */
697 pub fn from(shape: [(Dimension, usize); D]) -> ShapeIterator<D> {
698 // If we're given an invalid shape (shape input is not neccessarily going to meet the no
699 // 0 lengths contract of TensorRef because that's not actually required here), return
700 // a finished iterator
701 // Since this is an iterator over an owned shape, it's not going to become invalid later
702 // when we start iterating so this is the only check we need.
703 let starting_index_valid = shape.iter().all(|(_, l)| *l > 0);
704 ShapeIterator {
705 shape,
706 indexes: [0; D],
707 finished: !starting_index_valid,
708 }
709 }
710}
711
712impl<const D: usize> Iterator for ShapeIterator<D> {
713 type Item = [usize; D];
714
715 fn next(&mut self) -> Option<Self::Item> {
716 iter(&mut self.finished, &mut self.indexes, &self.shape)
717 }
718
719 fn size_hint(&self) -> (usize, Option<usize>) {
720 size_hint(self.finished, &self.indexes, &self.shape)
721 }
722}
723
724// Once we hit the end we mark ourselves as finished so we're always Fused.
725impl<const D: usize> FusedIterator for ShapeIterator<D> {}
726// We can always calculate the exact number of steps remaining because the shape and indexes are
727// private fields that are only mutated by `next` to count up.
728impl<const D: usize> ExactSizeIterator for ShapeIterator<D> {}
729
730/// Common index order iterator logic
731fn iter<const D: usize>(
732 finished: &mut bool,
733 indexes: &mut [usize; D],
734 shape: &[(Dimension, usize); D],
735) -> Option<[usize; D]> {
736 if *finished {
737 return None;
738 }
739
740 let value = Some(*indexes);
741
742 if D > 0 {
743 // Increment index of final dimension. In the 2D case, we iterate through a row by
744 // incrementing through every column index.
745 indexes[D - 1] += 1;
746 for d in (1..D).rev() {
747 if indexes[d] == shape[d].1 {
748 // ran to end of this dimension with our index
749 // In the 2D case, we finished indexing through every column in the row,
750 // and it's now time to move onto the next row.
751 indexes[d] = 0;
752 indexes[d - 1] += 1;
753 }
754 }
755 // Check if we ran past the final index
756 if indexes[0] == shape[0].1 {
757 *finished = true;
758 }
759 } else {
760 *finished = true;
761 }
762
763 value
764}
765
766/// Common index order iterator logic
767fn iter_back<const D: usize>(
768 finished: &mut bool,
769 indexes: &mut [usize; D],
770 shape: &[(Dimension, usize); D],
771) -> Option<[usize; D]> {
772 if *finished {
773 return None;
774 }
775
776 let value = Some(*indexes);
777
778 if D > 0 {
779 let mut bounds = [false; D];
780
781 // Decrement index of final dimension. In the 2D case, we iterate through a row by
782 // decrementing through every column index.
783 if indexes[D - 1] == 0 {
784 bounds[D - 1] = true;
785 } else {
786 indexes[D - 1] -= 1;
787 }
788 for d in (1..D).rev() {
789 if bounds[d] {
790 // ran to start of this dimension with our index
791 // In the 2D case, we finished indexing through every column in the row,
792 // and it's now time to move onto the next row.
793 indexes[d] = shape[d].1 - 1;
794 if indexes[d - 1] == 0 {
795 bounds[d - 1] = true;
796 } else {
797 indexes[d - 1] -= 1;
798 }
799 }
800 }
801 // Check if we reached the first index
802 if bounds[0] {
803 *finished = true;
804 }
805 } else {
806 *finished = true;
807 }
808
809 value
810}
811
812/// Common size hint logic
813fn size_hint<const D: usize>(
814 finished: bool,
815 indexes: &[usize; D],
816 shape: &[(Dimension, usize); D],
817) -> (usize, Option<usize>) {
818 if finished {
819 return (0, Some(0));
820 }
821
822 let remaining = if D > 0 {
823 let total = dimensions::elements(shape);
824 let strides = crate::tensors::compute_strides(shape);
825 let seen = crate::tensors::get_index_direct_unchecked(indexes, &strides);
826 total - seen
827 } else {
828 1
829 // If D == 0 and we're not finished we've not returned the sole index yet so there's
830 // exactly 1 left
831 };
832
833 (remaining, Some(remaining))
834}
835
836/// Common size hint logic
837fn double_ended_size_hint<const D: usize>(
838 finished: bool,
839 forward_indexes: &[usize; D],
840 back_indexes: &[usize; D],
841 shape: &[(Dimension, usize); D],
842) -> (usize, Option<usize>) {
843 if finished {
844 return (0, Some(0));
845 }
846
847 let remaining = if D > 0 {
848 //let total = dimensions::elements(shape);
849 let strides = crate::tensors::compute_strides(shape);
850 let progress_forward =
851 crate::tensors::get_index_direct_unchecked(forward_indexes, &strides);
852 let progress_backward = crate::tensors::get_index_direct_unchecked(back_indexes, &strides);
853 // progress_forward will range from 0 if we've not iterated forward at all yet
854 // through to the total-1 if we are on the final index at the end.
855 // likewise progress_backward starts at total-1 and finishes at 0 when on the first
856 // index.
857 // To calculate total left going forward (as in forward only case) and then
858 // subtract the total already seen backward we'd have:
859 // (total - progress_forward) - ((total - 1) - progress_backward)
860 // This cancels to
861 1 + progress_backward - progress_forward
862 } else {
863 1
864 // If D == 0 and we're not finished we've not returned the sole index yet so there's
865 // exactly 1 left
866 };
867
868 (remaining, Some(remaining))
869}
870
871/**
872 * An iterator over all indexes in a shape which can iterate in both directions.
873 *
874 * Going forwards, first the all 0 index is iterated, then each iteration increments the rightmost
875 * index.
876 * For a shape of `[("a", 2), ("b", 2), ("c", 2)]` this will yield indexes in order of: `[0,0,0]`,
877 * `[0,0,1]`, `[0,1,0]`, `[0,1,1]`, `[1,0,0]`, `[1,0,1]`, `[1,1,0]`, `[1,1,1]`,
878 * When iterating backwards, the indexes are yielded in reverse. Indexes do not cross,
879 * iteration is over when they indexes meet in the middle.
880 */
881#[derive(Clone, Debug)]
882pub struct DoubleEndedShapeIterator<const D: usize> {
883 shape: [(Dimension, usize); D],
884 forward_indexes: [usize; D],
885 back_indexes: [usize; D],
886 finished: bool,
887}
888
889impl<const D: usize> DoubleEndedShapeIterator<D> {
890 /**
891 * Constructs a DoubleEndedShapeIterator for a shape.
892 *
893 * If the shape has any dimensions with a length of zero, the iterator will immediately
894 * return None on [`next()`](Iterator::next) or
895 * [`next_back()`](DoubleEndedIterator::next_back()).
896 */
897 pub fn from(shape: [(Dimension, usize); D]) -> DoubleEndedShapeIterator<D> {
898 // If we're given an invalid shape (shape input is not neccessarily going to meet the no
899 // 0 lengths contract of TensorRef because that's not actually required here), return
900 // a finished iterator
901 // Since this is an iterator over an owned shape, it's not going to become invalid later
902 // when we start iterating so this is the only check we need.
903 let starting_index_valid = shape.iter().all(|(_, l)| *l > 0);
904 DoubleEndedShapeIterator {
905 shape,
906 forward_indexes: [0; D],
907 back_indexes: shape.map(|(_, l)| l - 1),
908 finished: !starting_index_valid,
909 }
910 }
911}
912
/// Returns true when the forward and backward cursors point at the same index. Whatever is
/// yielded next from either end is then the last index, and the caller marks itself finished.
fn overlapping_iterators<const D: usize>(
    forward_indexes: &[usize; D],
    back_indexes: &[usize; D],
) -> bool {
    forward_indexes
        .iter()
        .zip(back_indexes.iter())
        .all(|(front, back)| front == back)
}
919
920impl<const D: usize> Iterator for DoubleEndedShapeIterator<D> {
921 type Item = [usize; D];
922
923 fn next(&mut self) -> Option<Self::Item> {
924 let will_finish = overlapping_iterators(&self.forward_indexes, &self.back_indexes);
925 let item = iter(&mut self.finished, &mut self.forward_indexes, &self.shape);
926 if will_finish {
927 self.finished = true;
928 }
929 item
930 }
931
932 fn size_hint(&self) -> (usize, Option<usize>) {
933 double_ended_size_hint(
934 self.finished,
935 &self.forward_indexes,
936 &self.back_indexes,
937 &self.shape,
938 )
939 }
940}
941
942impl<const D: usize> DoubleEndedIterator for DoubleEndedShapeIterator<D> {
943 fn next_back(&mut self) -> Option<Self::Item> {
944 let will_finish = overlapping_iterators(&self.forward_indexes, &self.back_indexes);
945 let item = iter_back(&mut self.finished, &mut self.back_indexes, &self.shape);
946 if will_finish {
947 self.finished = true;
948 }
949 item
950 }
951}
952
// Once the two ends meet (or there was nothing to iterate) we mark ourselves as finished, so
// we're always Fused.
impl<const D: usize> FusedIterator for DoubleEndedShapeIterator<D> {}
// We can always calculate the exact number of steps remaining because the shape and indexes are
// private fields. After construction they are only mutated by `next` (counting forward_indexes
// up) and `next_back` (counting back_indexes down), so the distance between the two ends is
// always known.
impl<const D: usize> ExactSizeIterator for DoubleEndedShapeIterator<D> {}
958
959/**
960 * An iterator over copies of all values in a tensor.
961 *
962 * First the all 0 index is iterated, then each iteration increments the rightmost index.
963 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
964 * this will take a single step in memory on each iteration, akin to iterating through the
965 * flattened data of the tensor.
966 *
967 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
968 * will still iterate the rightmost index allowing iteration through dimensions in a different
969 * order to how they are stored, but no longer taking a single step in memory on each
970 * iteration (which may be less cache friendly for the CPU).
971 *
972 * ```
973 * use easy_ml::tensors::Tensor;
974 * let tensor_0 = Tensor::from_scalar(1);
975 * let tensor_1 = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
976 * let tensor_2 = Tensor::from([("a", 2), ("b", 3)], vec![
977 * // two rows, three columns
978 * 1, 2, 3,
979 * 4, 5, 6
980 * ]);
981 * let tensor_3 = Tensor::from([("a", 2), ("b", 1), ("c", 2)], vec![
 * // two slices of one row by two columns, stacked on top of each other
983 * 1,
984 * 2,
985 *
986 * 3,
987 * 4
988 * ]);
989 * let tensor_access_0 = tensor_0.index_by([]);
990 * let tensor_access_1 = tensor_1.index_by(["a"]);
991 * let tensor_access_2 = tensor_2.index_by(["a", "b"]);
992 * let tensor_access_2_rev = tensor_2.index_by(["b", "a"]);
993 * let tensor_access_3 = tensor_3.index_by(["a", "b", "c"]);
994 * let tensor_access_3_rev = tensor_3.index_by(["c", "b", "a"]);
995 * assert_eq!(
996 * tensor_0.iter().collect::<Vec<i32>>(),
997 * vec![1]
998 * );
999 * assert_eq!(
1000 * tensor_access_0.iter().collect::<Vec<i32>>(),
1001 * vec![1]
1002 * );
1003 * assert_eq!(
1004 * tensor_1.iter().collect::<Vec<i32>>(),
1005 * vec![1, 2, 3, 4, 5, 6, 7]
1006 * );
1007 * assert_eq!(
1008 * tensor_access_1.iter().collect::<Vec<i32>>(),
1009 * vec![1, 2, 3, 4, 5, 6, 7]
1010 * );
1011 * assert_eq!(
1012 * tensor_2.iter().collect::<Vec<i32>>(),
1013 * vec![1, 2, 3, 4, 5, 6]
1014 * );
1015 * assert_eq!(
1016 * tensor_access_2.iter().collect::<Vec<i32>>(),
1017 * vec![1, 2, 3, 4, 5, 6]
1018 * );
1019 * assert_eq!(
1020 * tensor_access_2.iter().rev().collect::<Vec<i32>>(),
1021 * vec![6, 5, 4, 3, 2, 1]
1022 * );
1023 * assert_eq!(
1024 * tensor_access_2_rev.iter().collect::<Vec<i32>>(),
1025 * vec![1, 4, 2, 5, 3, 6]
1026 * );
1027 * assert_eq!(
1028 * tensor_3.iter().collect::<Vec<i32>>(),
1029 * vec![1, 2, 3, 4]
1030 * );
1031 * assert_eq!(
1032 * tensor_3.iter().rev().collect::<Vec<i32>>(),
1033 * vec![4, 3, 2, 1]
1034 * );
1035 * assert_eq!(
1036 * tensor_access_3.iter().collect::<Vec<i32>>(),
1037 * vec![1, 2, 3, 4]
1038 * );
1039 * assert_eq!(
1040 * tensor_access_3_rev.iter().collect::<Vec<i32>>(),
1041 * vec![1, 3, 2, 4]
1042 * );
1043 * ```
1044 */
#[derive(Debug)]
pub struct TensorIterator<'a, T, S, const D: usize> {
    // Yields each index into the source's view shape, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The tensor being read; the value at each index is cloned out of it.
    source: &'a S,
    // Ties the element type T to this iterator without storing a T.
    _type: PhantomData<T>,
}
1051
1052impl<'a, T, S, const D: usize> TensorIterator<'a, T, S, D>
1053where
1054 T: Clone,
1055 S: TensorRef<T, D>,
1056{
1057 pub fn from(source: &S) -> TensorIterator<'_, T, S, D> {
1058 TensorIterator {
1059 shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1060 source,
1061 _type: PhantomData,
1062 }
1063 }
1064
1065 /**
1066 * Constructs an iterator which also yields the indexes of each element in
1067 * this iterator.
1068 */
1069 pub fn with_index(self) -> WithIndex<Self> {
1070 WithIndex { iterator: self }
1071 }
1072}
1073
1074impl<'a, T, S, const D: usize> From<TensorIterator<'a, T, S, D>>
1075 for WithIndex<TensorIterator<'a, T, S, D>>
1076where
1077 T: Clone,
1078 S: TensorRef<T, D>,
1079{
1080 fn from(iterator: TensorIterator<'a, T, S, D>) -> Self {
1081 iterator.with_index()
1082 }
1083}
1084
1085impl<'a, T, S, const D: usize> Iterator for TensorIterator<'a, T, S, D>
1086where
1087 T: Clone,
1088 S: TensorRef<T, D>,
1089{
1090 type Item = T;
1091
1092 fn next(&mut self) -> Option<Self::Item> {
1093 // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1094 // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1095 // immutable reference to our tensor source, it can't be resized which ensures
1096 // DoubleEndedShapeIterator can always yield valid indexes for our iteration.
1097 self.shape_iterator
1098 .next()
1099 .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
1100 }
1101
1102 fn size_hint(&self) -> (usize, Option<usize>) {
1103 self.shape_iterator.size_hint()
1104 }
1105}
1106
1107impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorIterator<'a, T, S, D>
1108where
1109 T: Clone,
1110 S: TensorRef<T, D>,
1111{
1112 fn next_back(&mut self) -> Option<Self::Item> {
1113 // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1114 // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1115 // immutable reference to our tensor source, it can't be resized which ensures
1116 // DoubleEndedShapeIterator can always yield valid indexes for our iteration.
1117 self.shape_iterator
1118 .next_back()
1119 .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) }.clone())
1120 }
1121}
1122
// All iteration state lives in the DoubleEndedShapeIterator, which is Fused, so we are too.
impl<'a, T, S, const D: usize> FusedIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}

// size_hint delegates to the DoubleEndedShapeIterator, which always reports an exact count.
impl<'a, T, S, const D: usize> ExactSizeIterator for TensorIterator<'a, T, S, D>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
1136
1137impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorIterator<'a, T, S, D>>
1138where
1139 T: Clone,
1140 S: TensorRef<T, D>,
1141{
1142 type Item = ([usize; D], T);
1143
1144 fn next(&mut self) -> Option<Self::Item> {
1145 let index = self.iterator.shape_iterator.forward_indexes;
1146 self.iterator.next().map(|x| (index, x))
1147 }
1148
1149 fn size_hint(&self) -> (usize, Option<usize>) {
1150 self.iterator.size_hint()
1151 }
1152}
1153
1154impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorIterator<'a, T, S, D>>
1155where
1156 T: Clone,
1157 S: TensorRef<T, D>,
1158{
1159 fn next_back(&mut self) -> Option<Self::Item> {
1160 let index = self.iterator.shape_iterator.back_indexes;
1161 self.iterator.next_back().map(|x| (index, x))
1162 }
1163}
1164
// WithIndex only pairs each item with its index and delegates to the wrapped TensorIterator,
// so it inherits being Fused.
impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}

// size_hint delegates to the wrapped TensorIterator, which is exact.
impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorIterator<'a, T, S, D>>
where
    T: Clone,
    S: TensorRef<T, D>,
{
}
1178
1179/**
1180 * An iterator over references to all values in a tensor.
1181 *
1182 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1183 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1184 * this will take a single step in memory on each iteration, akin to iterating through the
1185 * flattened data of the tensor.
1186 *
1187 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1188 * will still iterate the rightmost index allowing iteration through dimensions in a different
1189 * order to how they are stored, but no longer taking a single step in memory on each
1190 * iteration (which may be less cache friendly for the CPU).
1191 *
1192 * ```
1193 * use easy_ml::tensors::Tensor;
1194 * let tensor_0 = Tensor::from_scalar(1);
1195 * let tensor_1 = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
1196 * let tensor_2 = Tensor::from([("a", 2), ("b", 3)], vec![
1197 * // two rows, three columns
1198 * 1, 2, 3,
1199 * 4, 5, 6
1200 * ]);
1201 * let tensor_3 = Tensor::from([("a", 2), ("b", 1), ("c", 2)], vec![
 * // two slices of one row by two columns, stacked on top of each other
1203 * 1,
1204 * 2,
1205 *
1206 * 3,
1207 * 4
1208 * ]);
1209 * let tensor_access_0 = tensor_0.index_by([]);
1210 * let tensor_access_1 = tensor_1.index_by(["a"]);
1211 * let tensor_access_2 = tensor_2.index_by(["a", "b"]);
1212 * let tensor_access_2_rev = tensor_2.index_by(["b", "a"]);
1213 * let tensor_access_3 = tensor_3.index_by(["a", "b", "c"]);
1214 * let tensor_access_3_rev = tensor_3.index_by(["c", "b", "a"]);
1215 * assert_eq!(
1216 * tensor_0.iter_reference().cloned().collect::<Vec<i32>>(),
1217 * vec![1]
1218 * );
1219 * assert_eq!(
1220 * tensor_access_0.iter_reference().cloned().collect::<Vec<i32>>(),
1221 * vec![1]
1222 * );
1223 * assert_eq!(
1224 * tensor_1.iter_reference().cloned().collect::<Vec<i32>>(),
1225 * vec![1, 2, 3, 4, 5, 6, 7]
1226 * );
1227 * assert_eq!(
1228 * tensor_access_1.iter_reference().cloned().collect::<Vec<i32>>(),
1229 * vec![1, 2, 3, 4, 5, 6, 7]
1230 * );
1231 * assert_eq!(
1232 * tensor_2.iter_reference().cloned().collect::<Vec<i32>>(),
1233 * vec![1, 2, 3, 4, 5, 6]
1234 * );
1235 * assert_eq!(
1236 * tensor_2.iter_reference().rev().cloned().collect::<Vec<i32>>(),
1237 * vec![6, 5, 4, 3, 2, 1]
1238 * );
1239 * assert_eq!(
1240 * tensor_access_2.iter_reference().cloned().collect::<Vec<i32>>(),
1241 * vec![1, 2, 3, 4, 5, 6]
1242 * );
1243 * assert_eq!(
1244 * tensor_access_2_rev.iter_reference().cloned().collect::<Vec<i32>>(),
1245 * vec![1, 4, 2, 5, 3, 6]
1246 * );
1247 * assert_eq!(
1248 * tensor_3.iter_reference().cloned().collect::<Vec<i32>>(),
1249 * vec![1, 2, 3, 4]
1250 * );
1251 * assert_eq!(
1252 * tensor_3.iter_reference().rev().cloned().collect::<Vec<i32>>(),
1253 * vec![4, 3, 2, 1]
1254 * );
1255 * assert_eq!(
1256 * tensor_access_3.iter_reference().cloned().collect::<Vec<i32>>(),
1257 * vec![1, 2, 3, 4]
1258 * );
1259 * assert_eq!(
1260 * tensor_access_3_rev.iter_reference().cloned().collect::<Vec<i32>>(),
1261 * vec![1, 3, 2, 4]
1262 * );
1263 * ```
1264 */
#[derive(Debug)]
pub struct TensorReferenceIterator<'a, T, S, const D: usize> {
    // Yields each index into the source's view shape, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The tensor being read; references into it are handed out for the lifetime 'a.
    source: &'a S,
    // Ties the `&'a T` item type to this iterator without storing one.
    _type: PhantomData<&'a T>,
}
1271
1272impl<'a, T, S, const D: usize> TensorReferenceIterator<'a, T, S, D>
1273where
1274 S: TensorRef<T, D>,
1275{
1276 pub fn from(source: &S) -> TensorReferenceIterator<'_, T, S, D> {
1277 TensorReferenceIterator {
1278 shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1279 source,
1280 _type: PhantomData,
1281 }
1282 }
1283
1284 /**
1285 * Constructs an iterator which also yields the indexes of each element in
1286 * this iterator.
1287 */
1288 pub fn with_index(self) -> WithIndex<Self> {
1289 WithIndex { iterator: self }
1290 }
1291}
1292
1293impl<'a, T, S, const D: usize> From<TensorReferenceIterator<'a, T, S, D>>
1294 for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1295where
1296 S: TensorRef<T, D>,
1297{
1298 fn from(iterator: TensorReferenceIterator<'a, T, S, D>) -> Self {
1299 iterator.with_index()
1300 }
1301}
1302
1303impl<'a, T, S, const D: usize> Iterator for TensorReferenceIterator<'a, T, S, D>
1304where
1305 S: TensorRef<T, D>,
1306{
1307 type Item = &'a T;
1308
1309 fn next(&mut self) -> Option<Self::Item> {
1310 // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1311 // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1312 // immutable reference to our tensor source, it can't be resized which ensures
1313 // DoubleEndedIterator can always yield valid indexes for our iteration.
1314 self.shape_iterator
1315 .next()
1316 .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
1317 }
1318
1319 fn size_hint(&self) -> (usize, Option<usize>) {
1320 self.shape_iterator.size_hint()
1321 }
1322}
1323
1324impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorReferenceIterator<'a, T, S, D>
1325where
1326 S: TensorRef<T, D>,
1327{
1328 fn next_back(&mut self) -> Option<Self::Item> {
1329 // Safety: Our iterator only iterates over the correct indexes into our tensor's shape as
1330 // defined by TensorRef. Since TensorRef promises no interior mutability and we hold an
1331 // immutable reference to our tensor source, it can't be resized which ensures
1332 // DoubleEndedIterator can always yield valid indexes for our iteration.
1333 self.shape_iterator
1334 .next_back()
1335 .map(|indexes| unsafe { self.source.get_reference_unchecked(indexes) })
1336 }
1337}
1338
// All iteration state lives in the DoubleEndedShapeIterator, which is Fused, so we are too.
impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceIterator<'a, T, S, D> where
    S: TensorRef<T, D>
{
}

// size_hint delegates to the DoubleEndedShapeIterator, which always reports an exact count.
impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceIterator<'a, T, S, D> where
    S: TensorRef<T, D>
{
}
1348
1349impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1350where
1351 S: TensorRef<T, D>,
1352{
1353 type Item = ([usize; D], &'a T);
1354
1355 fn next(&mut self) -> Option<Self::Item> {
1356 let index = self.iterator.shape_iterator.forward_indexes;
1357 self.iterator.next().map(|x| (index, x))
1358 }
1359
1360 fn size_hint(&self) -> (usize, Option<usize>) {
1361 self.iterator.size_hint()
1362 }
1363}
1364
1365impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>>
1366where
1367 S: TensorRef<T, D>,
1368{
1369 fn next_back(&mut self) -> Option<Self::Item> {
1370 let index = self.iterator.shape_iterator.back_indexes;
1371 self.iterator.next_back().map(|x| (index, x))
1372 }
1373}
1374
// WithIndex only pairs each item with its index and delegates to the wrapped
// TensorReferenceIterator, so it inherits being Fused.
impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
    S: TensorRef<T, D>
{
}

// size_hint delegates to the wrapped TensorReferenceIterator, which is exact.
impl<'a, T, S, const D: usize> ExactSizeIterator for WithIndex<TensorReferenceIterator<'a, T, S, D>> where
    S: TensorRef<T, D>
{
}
1384
1385/**
1386 * An iterator over mutable references to all values in a tensor.
1387 *
1388 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1389 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1390 * this will take a single step in memory on each iteration, akin to iterating through the
1391 * flattened data of the tensor.
1392 *
1393 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1394 * will still iterate the rightmost index allowing iteration through dimensions in a different
1395 * order to how they are stored, but no longer taking a single step in memory on each
1396 * iteration (which may be less cache friendly for the CPU).
1397 *
1398 * ```
1399 * use easy_ml::tensors::Tensor;
1400 * let mut tensor = Tensor::from([("a", 7)], vec![ 1, 2, 3, 4, 5, 6, 7 ]);
1401 * let doubled = tensor.map(|x| 2 * x);
1402 * // mutating a tensor in place can also be done with Tensor::map_mut and
1403 * // Tensor::map_mut_with_index
1404 * for elem in tensor.iter_reference_mut() {
1405 * *elem = 2 * *elem;
1406 * }
1407 * assert_eq!(
1408 * tensor,
1409 * doubled,
1410 * );
1411 * ```
1412 */
#[derive(Debug)]
pub struct TensorReferenceMutIterator<'a, T, S, const D: usize> {
    // Yields each index into the source's view shape, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // Exclusive borrow of the tensor; mutable references into it are handed out for 'a.
    source: &'a mut S,
    // Ties the `&'a mut T` item type to this iterator without storing one.
    _type: PhantomData<&'a mut T>,
}
1419
1420impl<'a, T, S, const D: usize> TensorReferenceMutIterator<'a, T, S, D>
1421where
1422 S: TensorMut<T, D>,
1423{
1424 pub fn from(source: &mut S) -> TensorReferenceMutIterator<'_, T, S, D> {
1425 TensorReferenceMutIterator {
1426 shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1427 source,
1428 _type: PhantomData,
1429 }
1430 }
1431
1432 /**
1433 * Constructs an iterator which also yields the indexes of each element in
1434 * this iterator.
1435 */
1436 pub fn with_index(self) -> WithIndex<Self> {
1437 WithIndex { iterator: self }
1438 }
1439}
1440
1441impl<'a, T, S, const D: usize> From<TensorReferenceMutIterator<'a, T, S, D>>
1442 for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1443where
1444 S: TensorMut<T, D>,
1445{
1446 fn from(iterator: TensorReferenceMutIterator<'a, T, S, D>) -> Self {
1447 iterator.with_index()
1448 }
1449}
1450
1451impl<'a, T, S, const D: usize> Iterator for TensorReferenceMutIterator<'a, T, S, D>
1452where
1453 S: TensorMut<T, D>,
1454{
1455 type Item = &'a mut T;
1456
1457 fn next(&mut self) -> Option<Self::Item> {
1458 self.shape_iterator.next().map(|indexes| {
1459 unsafe {
1460 // Safety: We are not allowed to give out overlapping mutable references,
1461 // but since we will always increment the counter on every call to next()
1462 // and stop when we reach the end no references will overlap.
1463 // The compiler doesn't know this, so transmute the lifetime for it.
1464 // Safety: DoubleEndedShapeIterator only iterates over the correct indexes into our
1465 // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1466 // mutability and we hold an exclusive reference to our tensor source, it can't
1467 // be resized (except by us - and we don't) which ensures DoubleEndedShapeIterator
1468 // can always yield valid indexes for our iteration.
1469 std::mem::transmute::<&mut T, &mut T>(
1470 self.source.get_reference_unchecked_mut(indexes)
1471 )
1472 }
1473 })
1474 }
1475
1476 fn size_hint(&self) -> (usize, Option<usize>) {
1477 self.shape_iterator.size_hint()
1478 }
1479}
1480
1481impl<'a, T, S, const D: usize> DoubleEndedIterator for TensorReferenceMutIterator<'a, T, S, D>
1482where
1483 S: TensorMut<T, D>,
1484{
1485 fn next_back(&mut self) -> Option<Self::Item> {
1486 self.shape_iterator.next_back().map(|indexes| {
1487 unsafe {
1488 // Safety: We are not allowed to give out overlapping mutable references,
1489 // but since we will always increment the counter on every call to next()
1490 // and stop when we reach the end no references will overlap.
1491 // The compiler doesn't know this, so transmute the lifetime for it.
1492 // Safety: Our iterator only iterates over the correct indexes into our
1493 // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1494 // mutability and we hold an exclusive reference to our tensor source, it can't
1495 // be resized (except by us - and we don't) which ensures DoubleEndedShapeIterator
1496 // can always yield valid indexes for our iteration.
1497 std::mem::transmute::<&mut T, &mut T>(
1498 self.source.get_reference_unchecked_mut(indexes)
1499 )
1500 }
1501 })
1502 }
1503}
1504
// All iteration state lives in the DoubleEndedShapeIterator, which is Fused, so we are too.
impl<'a, T, S, const D: usize> FusedIterator for TensorReferenceMutIterator<'a, T, S, D> where
    S: TensorMut<T, D>
{
}

// size_hint delegates to the DoubleEndedShapeIterator, which always reports an exact count.
impl<'a, T, S, const D: usize> ExactSizeIterator for TensorReferenceMutIterator<'a, T, S, D> where
    S: TensorMut<T, D>
{
}
1514
1515impl<'a, T, S, const D: usize> Iterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1516where
1517 S: TensorMut<T, D>,
1518{
1519 type Item = ([usize; D], &'a mut T);
1520
1521 fn next(&mut self) -> Option<Self::Item> {
1522 let index = self.iterator.shape_iterator.forward_indexes;
1523 self.iterator.next().map(|x| (index, x))
1524 }
1525
1526 fn size_hint(&self) -> (usize, Option<usize>) {
1527 self.iterator.size_hint()
1528 }
1529}
1530
1531impl<'a, T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
1532where
1533 S: TensorMut<T, D>,
1534{
1535 fn next_back(&mut self) -> Option<Self::Item> {
1536 let index = self.iterator.shape_iterator.back_indexes;
1537 self.iterator.next_back().map(|x| (index, x))
1538 }
1539}
1540
1541
// WithIndex only pairs each item with its index and delegates to the wrapped
// TensorReferenceMutIterator, so it inherits being Fused.
impl<'a, T, S, const D: usize> FusedIterator for WithIndex<TensorReferenceMutIterator<'a, T, S, D>> where
    S: TensorMut<T, D>
{
}

// size_hint delegates to the wrapped TensorReferenceMutIterator, which is exact.
impl<'a, T, S, const D: usize> ExactSizeIterator
    for WithIndex<TensorReferenceMutIterator<'a, T, S, D>>
where
    S: TensorMut<T, D>,
{
}
1553
1554/**
1555 * An iterator over all values in an owned tensor.
1556 *
1557 * This iterator does not clone the values, it returns the actual values stored in the tensor.
1558 * There is no such method to return `T` by value from a [TensorRef]/[TensorMut], to do
1559 * this it [replaces](std::mem::replace) the values with dummy values. Hence it can only be
1560 * created for types that implement [Default] or [ZeroOne](crate::numeric::ZeroOne)
1561 * from [Numeric](crate::numeric) which provide a means to create dummy values.
1562 *
1563 * First the all 0 index is iterated, then each iteration increments the rightmost index.
1564 * For [Tensor] or [TensorRef]s which do not reorder the underlying Tensor
1565 * this will take a single step in memory on each iteration, akin to iterating through the
1566 * flattened data of the tensor.
1567 *
1568 * If the TensorRef reorders the tensor data (e.g. [TensorAccess]) this iterator
1569 * will still iterate the rightmost index allowing iteration through dimensions in a different
1570 * order to how they are stored, but no longer taking a single step in memory on each
1571 * iteration (which may be less cache friendly for the CPU).
1572 *
1573 * ```
1574 * use easy_ml::tensors::Tensor;
1575 *
1576 * #[derive(Debug, Default, Eq, PartialEq)]
1577 * struct NoClone(i32);
1578 *
1579 * let tensor = Tensor::from([("a", 3)], vec![ NoClone(1), NoClone(2), NoClone(3) ]);
1580 * let values = tensor.iter_owned(); // will use T::default() for dummy values
1581 * assert_eq!(vec![ NoClone(1), NoClone(2), NoClone(3) ], values.collect::<Vec<NoClone>>());
1582 * ```
1583 */
#[derive(Debug)]
pub struct TensorOwnedIterator<T, S, const D: usize> {
    // Yields each index into the source's view shape, from either end.
    shape_iterator: DoubleEndedShapeIterator<D>,
    // The tensor, held by value, whose elements are moved out as we iterate.
    source: S,
    // Creates the dummy value swapped into each slot when its element is moved out.
    producer: fn() -> T,
}
1590
1591impl<T, S, const D: usize> TensorOwnedIterator<T, S, D>
1592where
1593 S: TensorMut<T, D>,
1594{
1595 /**
1596 * Creates the TensorOwnedIterator from a source where the default values will be provided
1597 * by [Default::default]. This constructor is also used by the convenience
1598 * methods on [Tensor::iter_owned](Tensor::iter_owned) and
1599 * [TensorView::iter_owned](crate::tensors::views::TensorView::iter_owned).
1600 */
1601 pub fn from(source: S) -> TensorOwnedIterator<T, S, D>
1602 where
1603 T: Default,
1604 {
1605 TensorOwnedIterator {
1606 shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1607 source,
1608 producer: || T::default(),
1609 }
1610 }
1611
1612 /**
1613 * Creates the TensorOwnedIterator from a source where the default values will be provided
1614 * by [ZeroOne::zero](crate::numeric::ZeroOne::zero).
1615 */
1616 pub fn from_numeric(source: S) -> TensorOwnedIterator<T, S, D>
1617 where
1618 T: crate::numeric::ZeroOne,
1619 {
1620 TensorOwnedIterator {
1621 shape_iterator: DoubleEndedShapeIterator::from(source.view_shape()),
1622 source,
1623 producer: || T::zero(),
1624 }
1625 }
1626
1627 /**
1628 * Constructs an iterator which also yields the indexes of each element in
1629 * this iterator.
1630 */
1631 pub fn with_index(self) -> WithIndex<Self> {
1632 WithIndex { iterator: self }
1633 }
1634}
1635
1636impl<T, S, const D: usize> From<TensorOwnedIterator<T, S, D>>
1637 for WithIndex<TensorOwnedIterator<T, S, D>>
1638where
1639 S: TensorMut<T, D>,
1640{
1641 fn from(iterator: TensorOwnedIterator<T, S, D>) -> Self {
1642 iterator.with_index()
1643 }
1644}
1645
1646impl<T, S, const D: usize> Iterator for TensorOwnedIterator<T, S, D>
1647where
1648 S: TensorMut<T, D>,
1649{
1650 type Item = T;
1651
1652 fn next(&mut self) -> Option<Self::Item> {
1653 self.shape_iterator.next().map(|indexes| {
1654 let producer = self.producer;
1655 let dummy = producer();
1656 // Safety: DoubleEndedShapeIterator only iterates over the correct indexes into our
1657 // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1658 // mutability and we hold our tensor source by value, it can't be resized (except by
1659 // us - and we don't) which ensures it can always yield valid indexes for
1660 // our iteration.
1661 std::mem::replace(
1662 unsafe { self.source.get_reference_unchecked_mut(indexes) },
1663 dummy,
1664 )
1665 })
1666 }
1667
1668 fn size_hint(&self) -> (usize, Option<usize>) {
1669 self.shape_iterator.size_hint()
1670 }
1671}
1672
1673impl<T, S, const D: usize> DoubleEndedIterator for TensorOwnedIterator<T, S, D>
1674where
1675 S: TensorMut<T, D>,
1676{
1677 fn next_back(&mut self) -> Option<Self::Item> {
1678 self.shape_iterator.next_back().map(|indexes| {
1679 let producer = self.producer;
1680 let dummy = producer();
1681 // Safety: DoubleEndedShapeIterator only iterates over the correct indexes into our
1682 // tensor's shape as defined by TensorRef. Since TensorRef promises no interior
1683 // mutability and we hold our tensor source by value, it can't be resized (except by
1684 // us - and we don't) which ensures it can always yield valid indexes for
1685 // our iteration.
1686 std::mem::replace(
1687 unsafe { self.source.get_reference_unchecked_mut(indexes) },
1688 dummy,
1689 )
1690 })
1691 }
1692}
1693
// All iteration state lives in the DoubleEndedShapeIterator, which is Fused, so we are too.
impl<T, S, const D: usize> FusedIterator for TensorOwnedIterator<T, S, D> where S: TensorMut<T, D> {}

// size_hint delegates to the DoubleEndedShapeIterator, which always reports an exact count.
impl<T, S, const D: usize> ExactSizeIterator for TensorOwnedIterator<T, S, D> where
    S: TensorMut<T, D>
{
}
1700
1701impl<T, S, const D: usize> Iterator for WithIndex<TensorOwnedIterator<T, S, D>>
1702where
1703 S: TensorMut<T, D>,
1704{
1705 type Item = ([usize; D], T);
1706
1707 fn next(&mut self) -> Option<Self::Item> {
1708 let index = self.iterator.shape_iterator.forward_indexes;
1709 self.iterator.next().map(|x| (index, x))
1710 }
1711
1712 fn size_hint(&self) -> (usize, Option<usize>) {
1713 self.iterator.size_hint()
1714 }
1715}
1716
1717impl<T, S, const D: usize> DoubleEndedIterator for WithIndex<TensorOwnedIterator<T, S, D>>
1718where
1719 S: TensorMut<T, D>,
1720{
1721 fn next_back(&mut self) -> Option<Self::Item> {
1722 let index = self.iterator.shape_iterator.back_indexes;
1723 self.iterator.next_back().map(|x| (index, x))
1724 }
1725}
1726
// WithIndex only pairs each item with its index and delegates to the wrapped
// TensorOwnedIterator, so it inherits being Fused.
impl<T, S, const D: usize> FusedIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
    S: TensorMut<T, D>
{
}

// size_hint delegates to the wrapped TensorOwnedIterator, which is exact.
impl<T, S, const D: usize> ExactSizeIterator for WithIndex<TensorOwnedIterator<T, S, D>> where
    S: TensorMut<T, D>
{
}
1736
1737/**
1738 * A TensorTranspose makes the data in the tensor it is created from appear to be in a different
1739 * order, swapping the lengths of each named dimension to match the new order but leaving the
1740 * dimension name order unchanged.
1741 *
1742 * ```
1743 * use easy_ml::tensors::Tensor;
1744 * use easy_ml::tensors::indexing::TensorTranspose;
1745 * use easy_ml::tensors::views::TensorView;
1746 * let tensor = Tensor::from([("batch", 2), ("rows", 3), ("columns", 2)], vec![
1747 * 1, 2,
1748 * 3, 4,
1749 * 5, 6,
1750 *
1751 * 7, 8,
1752 * 9, 0,
1753 * 1, 2
1754 * ]);
1755 * let transposed = TensorView::from(TensorTranspose::from(&tensor, ["batch", "columns", "rows"]));
1756 * assert_eq!(
1757 * transposed,
1758 * Tensor::from([("batch", 2), ("rows", 2), ("columns", 3)], vec![
1759 * 1, 3, 5,
1760 * 2, 4, 6,
1761 *
1762 * 7, 9, 1,
1763 * 8, 0, 2
1764 * ])
1765 * );
1766 * let also_transposed = tensor.transpose_view(["batch", "columns", "rows"]);
1767 * ```
1768 */
#[derive(Clone)]
pub struct TensorTranspose<T, S, const D: usize> {
    // Reorders indexing into the source to the requested dimension order. The transpose
    // reports the source's dimension names in their original order, paired with the lengths
    // from this access's shape.
    access: TensorAccess<T, S, D>,
}
1773
1774impl<T: fmt::Debug, S: fmt::Debug, const D: usize> fmt::Debug for TensorTranspose<T, S, D> {
1775 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1776 f.debug_struct("TensorTranspose")
1777 .field("source", &self.access.source)
1778 .field("dimension_mapping", &self.access.dimension_mapping)
1779 .field("_type", &self.access._type)
1780 .finish()
1781 }
1782}
1783
1784impl<T, S, const D: usize> TensorTranspose<T, S, D>
1785where
1786 S: TensorRef<T, D>,
1787{
1788 /**
1789 * Creates a TensorTranspose which makes the data appear in the order of the
1790 * supplied dimensions. The order of the dimension names is unchanged, although their lengths
1791 * may swap.
1792 *
1793 * # Panics
1794 *
1795 * If the set of dimensions in the tensor does not match the set of dimensions provided. The
1796 * order need not match.
1797 */
1798 #[track_caller]
1799 pub fn from(source: S, dimensions: [Dimension; D]) -> TensorTranspose<T, S, D> {
1800 TensorTranspose {
1801 access: match TensorAccess::try_from(source, dimensions) {
1802 Err(error) => panic!("{}", error),
1803 Ok(success) => success,
1804 },
1805 }
1806 }
1807
1808 /**
1809 * Creates a TensorTranspose which makes the data to appear in the order of the
1810 * supplied dimensions. The order of the dimension names is unchanged, although their lengths
1811 * may swap.
1812 *
1813 * Returns Err if the set of dimensions supplied do not match the set of dimensions in this
1814 * tensor's shape.
1815 */
1816 pub fn try_from(
1817 source: S,
1818 dimensions: [Dimension; D],
1819 ) -> Result<TensorTranspose<T, S, D>, InvalidDimensionsError<D>> {
1820 TensorAccess::try_from(source, dimensions).map(|access| TensorTranspose { access })
1821 }
1822
1823 /**
1824 * The shape of this TensorTranspose appears to rearrange the data to the order of supplied
1825 * dimensions. The actual data in the underlying tensor and the order of the dimension names
1826 * on this TensorTranspose remains unchanged, although the lengths of the dimensions in this
1827 * shape of may swap compared to the source's shape.
1828 */
1829 pub fn shape(&self) -> [(Dimension, usize); D] {
1830 let names = self.access.source.view_shape();
1831 let order = self.access.shape();
1832 std::array::from_fn(|d| (names[d].0, order[d].1))
1833 }
1834
1835 pub fn source(self) -> S {
1836 self.access.source
1837 }
1838
1839 // # Safety
1840 //
1841 // Giving out a mutable reference to our source could allow it to be changed out from under us
1842 // and make our dimmension mapping invalid. However, since the source implements TensorRef
1843 // interior mutability is not allowed, so we can give out shared references without breaking
1844 // our own integrity.
1845 pub fn source_ref(&self) -> &S {
1846 &self.access.source
1847 }
1848}
1849
1850// # Safety
1851//
1852// The TensorAccess must implement TensorRef correctly, so by delegating to it without changing
1853// anything other than the order of the dimension names we expose, we implement
1854// TensoTensorRefrMut correctly as well.
1855/**
1856 * A TensorTranspose implements TensorRef, with the dimension order and indexing matching that
1857 * of the TensorTranspose shape.
1858 */
1859unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorTranspose<T, S, D>
1860where
1861 S: TensorRef<T, D>,
1862{
1863 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
1864 // we didn't change the lengths of any dimension in our shape from the TensorAccess so we
1865 // can delegate to the tensor access for non named indexing here
1866 self.access.try_get_reference(indexes)
1867 }
1868
1869 fn view_shape(&self) -> [(Dimension, usize); D] {
1870 self.shape()
1871 }
1872
1873 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
1874 unsafe { self.access.get_reference_unchecked(indexes) }
1875 }
1876
1877 fn data_layout(&self) -> DataLayout<D> {
1878 let data_layout = self.access.data_layout();
1879 match data_layout {
1880 DataLayout::Linear(order) => DataLayout::Linear(
1881 self.access
1882 .dimension_mapping
1883 .map_linear_data_layout_to_transposed(&order),
1884 ),
1885 _ => data_layout,
1886 }
1887 }
1888}
1889
1890// # Safety
1891//
1892// The TensorAccess must implement TensorMut correctly, so so by delegating to it without changing
1893// anything other than the order of the dimension names we expose, we implement, we implement
1894// TensorMut correctly as well.
1895/**
1896 * A TensorTranspose implements TensorMut, with the dimension order and indexing matching that of
1897 * the TensorTranspose shape.
1898 */
1899unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorTranspose<T, S, D>
1900where
1901 S: TensorMut<T, D>,
1902{
1903 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
1904 self.access.try_get_reference_mut(indexes)
1905 }
1906
1907 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
1908 unsafe { self.access.get_reference_unchecked_mut(indexes) }
1909 }
1910}
1911
1912/**
1913 * Any tensor transpose of a Displayable type implements Display
1914 *
1915 * You can control the precision of the formatting using format arguments, i.e.
1916 * `format!("{:.3}", tensor)`
1917 */
1918impl<T: std::fmt::Display, S, const D: usize> std::fmt::Display for TensorTranspose<T, S, D>
1919where
1920 T: std::fmt::Display,
1921 S: TensorRef<T, D>,
1922{
1923 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1924 crate::tensors::display::format_view(&self, f)?;
1925 writeln!(f)?;
1926 write!(f, "Data Layout = {:?}", self.data_layout())
1927 }
1928}