// easy_ml/tensors/views/reverse.rs

1use crate::tensors::Dimension;
2use crate::tensors::dimensions;
3use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
4use std::marker::PhantomData;
5
/**
 * A view over a tensor that iterates some, or all, of its dimensions back to front.
 *
 * Any dimension not selected for reversal keeps its usual order. The shape of the view is the
 * same as the shape of its source; only which source element each index refers to changes.
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorReverse};
 * let ab = Tensor::from([("a", 2), ("b", 3)], (0..6).collect());
 * // Tensor::reverse and wrapping a TensorReverse in a TensorView give the same result
 * let reversed = ab.reverse(&["a"]);
 * let also_reversed = TensorView::from(TensorReverse::from(&ab, &["a"]));
 * assert_eq!(reversed, also_reversed);
 * // Only "a" is flipped, so the two rows swap places but each row keeps its own order
 * assert_eq!(
 *     reversed,
 *     Tensor::from(
 *         [("a", 2), ("b", 3)],
 *         vec![
 *             3, 4, 5,
 *             0, 1, 2,
 *         ]
 *     )
 * );
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorReverse<T, S, const D: usize> {
    // The tensor being viewed; indexes are translated before being forwarded to it.
    source: S,
    // One flag per dimension, in the source's shape order: true if that dimension is reversed.
    reversed: [bool; D],
    // Records the element type T without storing any values of it.
    _type: PhantomData<T>,
}

35impl<T, S, const D: usize> TensorReverse<T, S, D>
36where
37    S: TensorRef<T, D>,
38{
39    /**
40     * Creates a TensorReverse from a source and a list of dimension names to reverse the
41     * order of iteration for. The list cannot be more than the number of dimensions in the source
42     * but it does not need to contain every dimension in the source. Any dimensions in the source
43     * but not in the list of dimension names to reverse will continue to iterate in their normal
44     * order.
45     *
46     * # Panics
47     *
48     * - If a dimension name is not in the source's shape or is repeated.
49     */
50    #[track_caller]
51    pub fn from(source: S, dimensions: &[Dimension]) -> TensorReverse<T, S, D> {
52        if crate::tensors::dimensions::has_duplicates_names(dimensions) {
53            panic!("Dimension names must all be unique: {:?}", dimensions);
54        }
55        let shape = source.view_shape();
56        if let Some(dimension) = dimensions.iter().find(|d| !dimensions::contains(&shape, d)) {
57            panic!(
58                "Dimension names to reverse must be in the source: {:?} is not in {:?}",
59                dimension, shape
60            );
61        }
62        let reversed = std::array::from_fn(|i| dimensions.contains(&shape[i].0));
63        TensorReverse {
64            source,
65            reversed,
66            _type: PhantomData,
67        }
68    }
69
70    /**
71     * Consumes the TensorReverse, yielding the source it was created from.
72     */
73    pub fn source(self) -> S {
74        self.source
75    }
76
77    /**
78     * Gives a reference to the TensorReverse's source (in which the iteration order may be
79     * different).
80     */
81    pub fn source_ref(&self) -> &S {
82        &self.source
83    }
84
85    /**
86     * Gives a mutable reference to the TensorReverse's source (in which the iteration order may
87     * be different).
88     */
89    // # Safety
90    //
91    // Although we're giving out a mutable reference here and thus the Tensor could be modified
92    // by the caller, it's impossible to change the dimensionality of the source due to this
93    // being determind at compile time by const generics, so as we only need our reversed array to
94    // match the dimensionality of the source after any modifications we don't have any edge
95    // cases that could make it invalid.
96    pub fn source_ref_mut(&mut self) -> &mut S {
97        &mut self.source
98    }
99}
100
101pub(crate) fn reverse_indexes<const D: usize>(
102    indexes: &[usize; D],
103    shape: &[(Dimension, usize); D],
104    reversed: &[bool; D],
105) -> [usize; D] {
106    std::array::from_fn(|d| {
107        if reversed[d] {
108            let length = shape[d].1;
109            // TensorRef requires dimensions are not of 0 length, so this never underflows
110            let last_index = length - 1;
111            let index = indexes[d];
112            // swap dimension indexing, so 0 becomes length-1, and length-1 becomes 0
113            last_index - index
114        } else {
115            indexes[d]
116        }
117    })
118}
119
120// # Safety
121//
122// The type implementing TensorRef must implement it correctly, so by delegating to it
123// by only reversing some indexes and not introducing interior mutability, we implement
124// TensorRef correctly as well.
125/**
126 * A TensorReverse implements TensorRef, with the dimension names the TensorReverse was created
127 * with iterating in reverse order compared to the dimension names in the original source.
128 */
129unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorReverse<T, S, D>
130where
131    S: TensorRef<T, D>,
132{
133    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
134        self.source.get_reference(reverse_indexes(
135            &indexes,
136            &self.view_shape(),
137            &self.reversed,
138        ))
139    }
140
141    fn view_shape(&self) -> [(Dimension, usize); D] {
142        self.source.view_shape()
143    }
144
145    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
146        unsafe {
147            self.source.get_reference_unchecked(reverse_indexes(
148                &indexes,
149                &self.view_shape(),
150                &self.reversed,
151            ))
152        }
153    }
154
155    fn data_layout(&self) -> DataLayout<D> {
156        // There might be some specific cases where reversing maintains a linear order but
157        // in general I think reversing only some indexes is going to mean any attempt at being
158        // able to take a slice that matches up with our view_shape is gone.
159        DataLayout::Other
160    }
161}
162
163// # Safety
164//
165// The type implementing TensorMut must implement it correctly, so by delegating to it
166// by only reversing some indexes and not introducing interior mutability, we implement
167// TensorMut correctly as well.
168/**
169 * A TensorReverse implements TensorMut, with the dimension names the TensorReverse was created
170 * with iterating in reverse order compared to the dimension names in the original source.
171 */
172unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorReverse<T, S, D>
173where
174    S: TensorMut<T, D>,
175{
176    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
177        self.source.get_reference_mut(reverse_indexes(
178            &indexes,
179            &self.view_shape(),
180            &self.reversed,
181        ))
182    }
183
184    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
185        unsafe {
186            self.source.get_reference_unchecked_mut(reverse_indexes(
187                &indexes,
188                &self.view_shape(),
189                &self.reversed,
190            ))
191        }
192    }
193}
194
#[test]
fn test_reversed_tensors() {
    use crate::tensors::Tensor;
    // Shape a=2, b=3, c=2 holding the values 0 to 11
    let tensor = Tensor::from([("a", 2), ("b", 3), ("c", 2)], (0..12).collect());
    let in_order: Vec<_> = tensor.iter().collect();
    assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], in_order);

    // Reversing "a" swaps the two halves of 6 elements, and reversing "c" swaps the values
    // within each adjacent pair
    let reversed = tensor.reverse_owned(&["a", "c"]);
    let reversed_order: Vec<_> = reversed.iter().collect();
    assert_eq!(vec![7, 6, 9, 8, 11, 10, 1, 0, 3, 2, 5, 4], reversed_order);
}