1use crate::tensors;
2use crate::tensors::Dimension;
3use crate::tensors::dimensions;
4use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
5use std::marker::PhantomData;
6
/// A view over a source tensor that exposes the same elements under a different shape.
///
/// An index into this view is flattened to a single position using `strides` (computed
/// from `shape`), and that position is then unflattened using `source_strides` (computed
/// from the source's shape) to obtain the index into the source. The constructors check
/// that `shape` and the source's shape describe the same total number of elements.
#[derive(Clone, Debug)]
pub struct TensorReshape<T, S, const D: usize, const D2: usize> {
    /// The tensor being viewed.
    source: S,
    /// The shape this view presents, with D2 dimensions.
    shape: [(Dimension, usize); D2],
    /// Strides for `shape`, used to flatten indexes given to this view.
    strides: [usize; D2],
    /// Strides for the source's shape as it was at construction time, used to turn a flat
    /// position back into an index into the source.
    source_strides: [usize; D],
    /// Ties the element type T to this view without storing a T.
    _type: PhantomData<T>,
}
39
40impl<T, S, const D: usize, const D2: usize> TensorReshape<T, S, D, D2>
41where
42 S: TensorRef<T, D>,
43{
44 #[track_caller]
61 pub fn from(source: S, shape: [(Dimension, usize); D2]) -> TensorReshape<T, S, D, D2> {
62 if dimensions::has_duplicates(&shape) {
63 panic!("Dimension names must all be unique: {:?}", &shape);
64 }
65 let existing_one_dimensional_length = dimensions::elements(&source.view_shape());
66 let given_one_dimensional_length = dimensions::elements(&shape);
67 if given_one_dimensional_length != existing_one_dimensional_length {
68 panic!(
69 "Number of elements required by provided shape {:?} are {:?} but number of elements in source are: {:?} due to shape of {:?}",
70 &shape,
71 &given_one_dimensional_length,
72 &existing_one_dimensional_length,
73 &source.view_shape()
74 );
75 }
76 let source_strides = tensors::compute_strides(&source.view_shape());
77 TensorReshape {
78 source,
79 shape,
80 strides: tensors::compute_strides(&shape),
81 source_strides,
82 _type: PhantomData,
83 }
84 }
85
86 #[allow(dead_code)]
90 pub fn source(self) -> S {
91 self.source
92 }
93
94 #[allow(dead_code)]
104 pub fn source_ref(&self) -> &S {
105 &self.source
106 }
107}
108
109impl<T, S, const D: usize> TensorReshape<T, S, D, D>
110where
111 S: TensorRef<T, D>,
112{
113 #[track_caller]
128 pub fn from_existing_dimensions(source: S, lengths: [usize; D]) -> TensorReshape<T, S, D, D> {
129 let previous_shape = source.view_shape();
130 let shape = std::array::from_fn(|n| (previous_shape[n].0, lengths[0]));
131 let existing_one_dimensional_length = dimensions::elements(&source.view_shape());
132 let given_one_dimensional_length = dimensions::elements(&shape);
133 if given_one_dimensional_length != existing_one_dimensional_length {
134 panic!(
135 "Number of elements required by provided shape {:?} are {:?} but number of elements in source are: {:?} due to shape of {:?}",
136 &shape,
137 &given_one_dimensional_length,
138 &existing_one_dimensional_length,
139 &source.view_shape()
140 );
141 }
142 let source_strides = tensors::compute_strides(&source.view_shape());
143 TensorReshape {
144 source,
145 shape,
146 strides: tensors::compute_strides(&shape),
147 source_strides,
148 _type: PhantomData,
149 }
150 }
151}
152
/// Converts the `nth` flat position into a D-dimensional index using `strides`.
///
/// Each component is how many whole strides fit into what remains of the position, and
/// the remainder is carried on to the next (smaller) stride. Assumes every stride is
/// non-zero.
fn unflatten<const D: usize>(nth: usize, strides: &[usize; D]) -> [usize; D] {
    let mut remainder = nth;
    let mut position = [0; D];
    for (slot, &stride) in position.iter_mut().zip(strides) {
        *slot = remainder / stride;
        remainder %= stride;
    }
    position
}
168
#[test]
fn unflatten_produces_indices_in_n_dimensions() {
    // Asserts that unflattening flat positions 0, 1, 2, ... yields `expected` in order.
    fn check<const D: usize>(strides: [usize; D], expected: &[[usize; D]]) {
        for (nth, index) in expected.iter().enumerate() {
            assert_eq!(*index, unflatten(nth, &strides));
        }
    }

    check(
        tensors::compute_strides(&[("x", 2), ("y", 2)]),
        &[[0, 0], [0, 1], [1, 0], [1, 1]],
    );

    check(
        tensors::compute_strides(&[("x", 3), ("y", 2)]),
        &[[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]],
    );

    check(
        tensors::compute_strides(&[("x", 2), ("y", 3)]),
        &[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
    );

    check(
        tensors::compute_strides(&[("x", 2), ("y", 3), ("z", 1)]),
        &[
            [0, 0, 0],
            [0, 1, 0],
            [0, 2, 0],
            [1, 0, 0],
            [1, 1, 0],
            [1, 2, 0],
        ],
    );

    check(
        tensors::compute_strides(&[("batch", 1), ("x", 2), ("y", 3)]),
        &[
            [0, 0, 0],
            [0, 0, 1],
            [0, 0, 2],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
        ],
    );

    check(
        tensors::compute_strides(&[("x", 2), ("y", 3), ("z", 2)]),
        &[
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [0, 2, 0],
            [0, 2, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
            [1, 2, 0],
            [1, 2, 1],
        ],
    );
}
223
224unsafe impl<T, S, const D: usize, const D2: usize> TensorRef<T, D2> for TensorReshape<T, S, D, D2>
234where
235 S: TensorRef<T, D>,
236{
237 fn get_reference(&self, indexes: [usize; D2]) -> Option<&T> {
238 let one_dimensional_index =
239 tensors::get_index_direct(&indexes, &self.strides, &self.shape)?;
240 self.source
241 .get_reference(unflatten(one_dimensional_index, &self.source_strides))
242 }
243
244 fn view_shape(&self) -> [(Dimension, usize); D2] {
245 self.shape
246 }
247
248 unsafe fn get_reference_unchecked(&self, indexes: [usize; D2]) -> &T {
249 unsafe {
250 let one_dimensional_index =
254 tensors::get_index_direct_unchecked(&indexes, &self.strides);
255 self.source
256 .get_reference_unchecked(unflatten(one_dimensional_index, &self.source_strides))
257 }
258 }
259
260 fn data_layout(&self) -> DataLayout<D2> {
261 DataLayout::Other
265 }
266}
267
268unsafe impl<T, S, const D: usize, const D2: usize> TensorMut<T, D2> for TensorReshape<T, S, D, D2>
272where
273 S: TensorMut<T, D>,
274{
275 fn get_reference_mut(&mut self, indexes: [usize; D2]) -> Option<&mut T> {
276 let one_dimensional_index =
277 tensors::get_index_direct(&indexes, &self.strides, &self.shape)?;
278 self.source
279 .get_reference_mut(unflatten(one_dimensional_index, &self.source_strides))
280 }
281
282 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D2]) -> &mut T {
283 unsafe {
284 let one_dimensional_index =
288 tensors::get_index_direct_unchecked(&indexes, &self.strides);
289 self.source
290 .get_reference_unchecked_mut(unflatten(one_dimensional_index, &self.source_strides))
291 }
292 }
293}