// easy_ml/tensors/views/reshape.rs

1use crate::tensors;
2use crate::tensors::Dimension;
3use crate::tensors::dimensions;
4use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
5use std::marker::PhantomData;
6
/**
 * A new shape to override indexing an existing tensor. The dimensionality and individual
 * dimension lengths can be changed, but the total number of elements in the new shape must
 * match the existing tensor's shape.
 *
 * Elements are still in the same order (in memory) as
 * the source tensor, none of the data is moved around - though iteration might be in
 * a different order with the new shape.
 *
 * If you just need to rename dimensions without changing them, see
 * [TensorRename](tensors::views::TensorRename)
 *
 * This type's generics can be read as: a TensorReshape is generic over some element of type T
 * from an existing source of type S of dimensionality D, and this tensor has dimensionality D2,
 * which might be different to D, but will have the same total number of elements.
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorReshape, TensorView};
 * let tensor = Tensor::from([("a", 2), ("b", 2)], (0..4).collect());
 * let flat = TensorView::from(TensorReshape::from(tensor, [("i", 4)])); // or use tensor.reshape_view_owned
 * assert_eq!(flat, Tensor::from([("i", 4)], (0..4).collect()));
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorReshape<T, S, const D: usize, const D2: usize> {
    // The wrapped source this view reads through; all element access delegates to it.
    source: S,
    // The overriding shape exposed by this view; same total element count as the source's shape.
    shape: [(Dimension, usize); D2],
    // Strides for `shape` (last dimension fastest), used to flatten incoming indexes to a
    // single linear position.
    strides: [usize; D2],
    // Strides for the source's own shape, used to unflatten that linear position back into
    // an index of the source's dimensionality.
    source_strides: [usize; D],
    // Marker so the struct is generic over the element type T without storing one.
    _type: PhantomData<T>,
}
39
40impl<T, S, const D: usize, const D2: usize> TensorReshape<T, S, D, D2>
41where
42    S: TensorRef<T, D>,
43{
44    /**
45     * Creates a TensorReshape from a source and a new shape to override the
46     * view_shape with. The new shape must correspond to the same number of total
47     * elements, but it need not match on dimensionality or individual dimension lengths.
48     *
49     * If you just need to rename dimensions without changing them, see
50     * [TensorRename](tensors::views::TensorRename)
51     * If you don't need to change the dimensionality, see
52     * [from](TensorReshape::from_existing_dimensions)
53     *
54     * # Panics
55     *
56     * - If the new shape has a different number of elements to the existing
57     * shape in the source
58     * - If the new shape has duplicate dimension names
59     */
60    #[track_caller]
61    pub fn from(source: S, shape: [(Dimension, usize); D2]) -> TensorReshape<T, S, D, D2> {
62        if dimensions::has_duplicates(&shape) {
63            panic!("Dimension names must all be unique: {:?}", &shape);
64        }
65        let existing_one_dimensional_length = dimensions::elements(&source.view_shape());
66        let given_one_dimensional_length = dimensions::elements(&shape);
67        if given_one_dimensional_length != existing_one_dimensional_length {
68            panic!(
69                "Number of elements required by provided shape {:?} are {:?} but number of elements in source are: {:?} due to shape of {:?}",
70                &shape,
71                &given_one_dimensional_length,
72                &existing_one_dimensional_length,
73                &source.view_shape()
74            );
75        }
76        let source_strides = tensors::compute_strides(&source.view_shape());
77        TensorReshape {
78            source,
79            shape,
80            strides: tensors::compute_strides(&shape),
81            source_strides,
82            _type: PhantomData,
83        }
84    }
85
86    /**
87     * Consumes the TensorReshape, yielding the source it was created from.
88     */
89    #[allow(dead_code)]
90    pub fn source(self) -> S {
91        self.source
92    }
93
94    /**
95     * Gives a reference to the TensorReshape's source (in which the data has its original shape).
96     */
97    // # Safety
98    //
99    // Giving out a mutable reference to our source could allow it to be changed out from under us
100    // and make the number of elements in our shape invalid. However, since the source implements
101    // TensorRef interior mutability is not allowed, so we can give out shared references without
102    // breaking our own integrity.
103    #[allow(dead_code)]
104    pub fn source_ref(&self) -> &S {
105        &self.source
106    }
107}
108
109impl<T, S, const D: usize> TensorReshape<T, S, D, D>
110where
111    S: TensorRef<T, D>,
112{
113    /**
114     * Creates a TensorReshape from a source and new dimension lengths with the same dimensionality
115     * as the source to override the view_shape with. The new shape must correspond to the same
116     * number of total elements, but it need not match on individual dimension lengths.
117     *
118     * If you just need to rename dimensions without changing them, see
119     * [TensorRename](tensors::views::TensorRename)
120     * If you need to change the dimensionality, see [from](TensorReshape::from)
121     *
122     * # Panics
123     *
124     * - If the new shape has a different number of elements to the existing
125     * shape in the source
126     */
127    #[track_caller]
128    pub fn from_existing_dimensions(source: S, lengths: [usize; D]) -> TensorReshape<T, S, D, D> {
129        let previous_shape = source.view_shape();
130        let shape = std::array::from_fn(|n| (previous_shape[n].0, lengths[0]));
131        let existing_one_dimensional_length = dimensions::elements(&source.view_shape());
132        let given_one_dimensional_length = dimensions::elements(&shape);
133        if given_one_dimensional_length != existing_one_dimensional_length {
134            panic!(
135                "Number of elements required by provided shape {:?} are {:?} but number of elements in source are: {:?} due to shape of {:?}",
136                &shape,
137                &given_one_dimensional_length,
138                &existing_one_dimensional_length,
139                &source.view_shape()
140            );
141        }
142        let source_strides = tensors::compute_strides(&source.view_shape());
143        TensorReshape {
144            source,
145            shape,
146            strides: tensors::compute_strides(&shape),
147            source_strides,
148            _type: PhantomData,
149        }
150    }
151}
152
/// Converts a one dimensional position `nth` back into a D dimensional index, given the
/// strides of the D dimensional shape. Each dimension's index is how many whole strides
/// fit into the remaining steps; the remainder carries to the next (faster varying) dimension.
fn unflatten<const D: usize>(nth: usize, strides: &[usize; D]) -> [usize; D] {
    let mut index = [0; D];
    let mut remaining = nth;
    for (slot, &stride) in index.iter_mut().zip(strides.iter()) {
        // With a stride of 20, positions 0-19 map to index 0, 20-39 to index 1, and so on.
        *slot = remaining / stride;
        // What the division rounded off is the position within this dimension's stride,
        // which the following dimensions account for.
        remaining %= stride;
    }
    index
}
168
#[test]
fn unflatten_produces_indices_in_n_dimensions() {
    // Each case pairs the strides of a shape with the expected index for every flat
    // position 0..n, in order; the last dimension always varies fastest.
    let strides = tensors::compute_strides(&[("x", 2), ("y", 2)]);
    let expected: [[usize; 2]; 4] = [[0, 0], [0, 1], [1, 0], [1, 1]];
    for (nth, index) in expected.iter().enumerate() {
        assert_eq!(*index, unflatten(nth, &strides));
    }

    let strides = tensors::compute_strides(&[("x", 3), ("y", 2)]);
    let expected: [[usize; 2]; 6] = [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]];
    for (nth, index) in expected.iter().enumerate() {
        assert_eq!(*index, unflatten(nth, &strides));
    }

    let strides = tensors::compute_strides(&[("x", 2), ("y", 3)]);
    let expected: [[usize; 2]; 6] = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]];
    for (nth, index) in expected.iter().enumerate() {
        assert_eq!(*index, unflatten(nth, &strides));
    }

    // A trailing length 1 dimension never advances
    let strides = tensors::compute_strides(&[("x", 2), ("y", 3), ("z", 1)]);
    let expected: [[usize; 3]; 6] = [
        [0, 0, 0],
        [0, 1, 0],
        [0, 2, 0],
        [1, 0, 0],
        [1, 1, 0],
        [1, 2, 0],
    ];
    for (nth, index) in expected.iter().enumerate() {
        assert_eq!(*index, unflatten(nth, &strides));
    }

    // A leading length 1 dimension stays at 0
    let strides = tensors::compute_strides(&[("batch", 1), ("x", 2), ("y", 3)]);
    let expected: [[usize; 3]; 6] = [
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 2],
        [0, 1, 0],
        [0, 1, 1],
        [0, 1, 2],
    ];
    for (nth, index) in expected.iter().enumerate() {
        assert_eq!(*index, unflatten(nth, &strides));
    }

    let strides = tensors::compute_strides(&[("x", 2), ("y", 3), ("z", 2)]);
    let expected: [[usize; 3]; 12] = [
        [0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 1],
        [0, 2, 0],
        [0, 2, 1],
        [1, 0, 0],
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1],
        [1, 2, 0],
        [1, 2, 1],
    ];
    for (nth, index) in expected.iter().enumerate() {
        assert_eq!(*index, unflatten(nth, &strides));
    }
}
223
224// # Safety
225//
226// The type implementing TensorRef must implement it correctly, so by delegating to it
227// by only flattening the caller's index into a one dimensional one, and then unflattening
228// it back into the source dimensionality and not introducing interior mutability, we implement
229// TensorRef correctly as well.
230/**
231 * A TensorReshape implements TensorRef, with the data in the same order as the original source.
232 */
233unsafe impl<T, S, const D: usize, const D2: usize> TensorRef<T, D2> for TensorReshape<T, S, D, D2>
234where
235    S: TensorRef<T, D>,
236{
237    fn get_reference(&self, indexes: [usize; D2]) -> Option<&T> {
238        let one_dimensional_index =
239            tensors::get_index_direct(&indexes, &self.strides, &self.shape)?;
240        self.source
241            .get_reference(unflatten(one_dimensional_index, &self.source_strides))
242    }
243
244    fn view_shape(&self) -> [(Dimension, usize); D2] {
245        self.shape
246    }
247
248    unsafe fn get_reference_unchecked(&self, indexes: [usize; D2]) -> &T {
249        unsafe {
250            // It is the caller's responsibility to always call with indexes in range,
251            // therefore out of bounds lookups created by get_index_direct_unchecked should never
252            // happen.
253            let one_dimensional_index =
254                tensors::get_index_direct_unchecked(&indexes, &self.strides);
255            self.source
256                .get_reference_unchecked(unflatten(one_dimensional_index, &self.source_strides))
257        }
258    }
259
260    fn data_layout(&self) -> DataLayout<D2> {
261        // There might be some cases where assigning a new shape maintains a linear order
262        // but it seems like a lot of effort to maintain a correct mapping from the original
263        // linear order to the new one, given we can change even dimensionality in this mapping.
264        DataLayout::Other
265    }
266}
267
268/**
269 * A TensorReshape implements TensorMut, with the data in the same order as the original source.
270 */
271unsafe impl<T, S, const D: usize, const D2: usize> TensorMut<T, D2> for TensorReshape<T, S, D, D2>
272where
273    S: TensorMut<T, D>,
274{
275    fn get_reference_mut(&mut self, indexes: [usize; D2]) -> Option<&mut T> {
276        let one_dimensional_index =
277            tensors::get_index_direct(&indexes, &self.strides, &self.shape)?;
278        self.source
279            .get_reference_mut(unflatten(one_dimensional_index, &self.source_strides))
280    }
281
282    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D2]) -> &mut T {
283        unsafe {
284            // It is the caller's responsibility to always call with indexes in range,
285            // therefore out of bounds lookups created by get_index_direct_unchecked should never
286            // happen.
287            let one_dimensional_index =
288                tensors::get_index_direct_unchecked(&indexes, &self.strides);
289            self.source
290                .get_reference_unchecked_mut(unflatten(one_dimensional_index, &self.source_strides))
291        }
292    }
293}