// easy_ml/tensors/views/reverse.rs
1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
/**
 * A view over a tensor that walks any chosen subset of its dimensions (possibly all of them)
 * from the last index to the first, while every other dimension keeps its usual order.
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorReverse};
 * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
 * let reversed = ab.reverse(&["a"]);
 * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
 * assert_eq!(reversed, also_reversed);
 * assert_eq!(
 *     reversed,
 *     Tensor::from(
 *         [("a", 2), ("b", 3)],
 *         vec![
 *             3, 4, 5,
 *             0, 1, 2,
 *         ]
 *     )
 * );
 * ```
 */
#[derive(Debug, Clone)]
pub struct TensorReverse<T, S, const D: usize> {
    // The tensor (or view) being read through this view.
    source: S,
    // One flag per dimension of the source; `true` means that dimension is walked backwards.
    reversed: [bool; D],
    _type: PhantomData<T>,
}
34
35impl<T, S, const D: usize> TensorReverse<T, S, D>
36where
37 S: TensorRef<T, D>,
38{
39 /**
40 * Creates a TensorReverse from a source and a list of dimension names to reverse the
41 * order of iteration for. The list cannot be more than the number of dimensions in the source
42 * but it does not need to contain every dimension in the source. Any dimensions in the source
43 * but not in the list of dimension names to reverse will continue to iterate in their normal
44 * order.
45 *
46 * # Panics
47 *
48 * - If a dimension name is not in the source's shape or is repeated.
49 */
50 #[track_caller]
51 pub fn from(source: S, dimensions: &[Dimension]) -> TensorReverse<T, S, D> {
52 if crate::tensors::dimensions::has_duplicates_names(dimensions) {
53 panic!("Dimension names must all be unique: {:?}", dimensions);
54 }
55 let shape = source.view_shape();
56 if let Some(dimension) = dimensions.iter().find(|d| !dimensions::contains(&shape, d)) {
57 panic!(
58 "Dimension names to reverse must be in the source: {:?} is not in {:?}",
59 dimension, shape
60 );
61 }
62 let reversed = std::array::from_fn(|i| dimensions.contains(&shape[i].0));
63 TensorReverse {
64 source,
65 reversed,
66 _type: PhantomData,
67 }
68 }
69
70 /**
71 * Consumes the TensorReverse, yielding the source it was created from.
72 */
73 pub fn source(self) -> S {
74 self.source
75 }
76
77 /**
78 * Gives a reference to the TensorReverse's source (in which the iteration order may be
79 * different).
80 */
81 pub fn source_ref(&self) -> &S {
82 &self.source
83 }
84
85 /**
86 * Gives a mutable reference to the TensorReverse's source (in which the iteration order may
87 * be different).
88 */
89 // # Safety
90 //
91 // Although we're giving out a mutable reference here and thus the Tensor could be modified
92 // by the caller, it's impossible to change the dimensionality of the source due to this
93 // being determind at compile time by const generics, so as we only need our reversed array to
94 // match the dimensionality of the source after any modifications we don't have any edge
95 // cases that could make it invalid.
96 pub fn source_ref_mut(&mut self) -> &mut S {
97 &mut self.source
98 }
99}
100
101pub(crate) fn reverse_indexes<const D: usize>(
102 indexes: &[usize; D],
103 shape: &[(Dimension, usize); D],
104 reversed: &[bool; D],
105) -> [usize; D] {
106 std::array::from_fn(|d| {
107 if reversed[d] {
108 let length = shape[d].1;
109 // TensorRef requires dimensions are not of 0 length, so this never underflows
110 let last_index = length - 1;
111 let index = indexes[d];
112 // swap dimension indexing, so 0 becomes length-1, and length-1 becomes 0
113 last_index - index
114 } else {
115 indexes[d]
116 }
117 })
118}
119
120// # Safety
121//
122// The type implementing TensorRef must implement it correctly, so by delegating to it
123// by only reversing some indexes and not introducing interior mutability, we implement
124// TensorRef correctly as well.
125/**
126 * A TensorReverse implements TensorRef, with the dimension names the TensorReverse was created
127 * with iterating in reverse order compared to the dimension names in the original source.
128 */
129unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorReverse<T, S, D>
130where
131 S: TensorRef<T, D>,
132{
133 fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
134 self.source.get_reference(reverse_indexes(
135 &indexes,
136 &self.view_shape(),
137 &self.reversed,
138 ))
139 }
140
141 fn view_shape(&self) -> [(Dimension, usize); D] {
142 self.source.view_shape()
143 }
144
145 unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
146 unsafe {
147 self.source.get_reference_unchecked(reverse_indexes(
148 &indexes,
149 &self.view_shape(),
150 &self.reversed,
151 ))
152 }
153 }
154
155 fn data_layout(&self) -> DataLayout<D> {
156 // There might be some specific cases where reversing maintains a linear order but
157 // in general I think reversing only some indexes is going to mean any attempt at being
158 // able to take a slice that matches up with our view_shape is gone.
159 DataLayout::Other
160 }
161}
162
163// # Safety
164//
165// The type implementing TensorMut must implement it correctly, so by delegating to it
166// by only reversing some indexes and not introducing interior mutability, we implement
167// TensorMut correctly as well.
168/**
169 * A TensorReverse implements TensorMut, with the dimension names the TensorReverse was created
170 * with iterating in reverse order compared to the dimension names in the original source.
171 */
172unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorReverse<T, S, D>
173where
174 S: TensorMut<T, D>,
175{
176 fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
177 self.source.get_reference_mut(reverse_indexes(
178 &indexes,
179 &self.view_shape(),
180 &self.reversed,
181 ))
182 }
183
184 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
185 unsafe {
186 self.source.get_reference_unchecked_mut(reverse_indexes(
187 &indexes,
188 &self.view_shape(),
189 &self.reversed,
190 ))
191 }
192 }
193}
194
#[test]
fn test_reversed_tensors() {
    use crate::tensors::Tensor;
    // Shape [a=2, b=3, c=2] filled with 0..12, so the element at [a, b, c] is 6a + 2b + c.
    let tensor = Tensor::from([("a", 2), ("b", 3), ("c", 2)], (0..12).collect());
    let in_order = tensor.iter().collect::<Vec<_>>();
    assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], in_order);
    // Reversing "a" and "c" means the view's [a, b, c] reads the source's [1 - a, b, 1 - c].
    let reversed = tensor.reverse_owned(&["a", "c"]);
    let reversed_order = reversed.iter().collect::<Vec<_>>();
    assert_eq!(vec![7, 6, 9, 8, 11, 10, 1, 0, 3, 2, 5, 4], reversed_order);
}