1use crate::tensors::Dimension;
2use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
3use std::marker::PhantomData;
4
/// A view over a source tensor with `I` of its `D` dimensions pinned to fixed indexes.
///
/// The dimensions that were not pinned keep the order they have in the source shape and
/// form the shape of this view, which is `D - I` dimensional for the `D`/`I` combinations
/// this module implements `TensorRef` for.
///
/// For example, pinning `("batch", 1)` of a `[("batch", 2), ("row", 2), ("column", 2)]`
/// tensor gives a `[("row", 2), ("column", 2)]` view of the second batch.
#[derive(Clone, Debug)]
pub struct TensorIndex<T, S, const D: usize, const I: usize> {
    // The tensor being viewed.
    source: S,
    // One entry per source dimension: `Some(index)` if that dimension is pinned, `None` if
    // it is still indexed by callers of this view.
    provided: [Option<usize>; D],
    _type: PhantomData<T>,
}
30
31impl<T, S, const D: usize, const I: usize> TensorIndex<T, S, D, I>
32where
33 S: TensorRef<T, D>,
34{
35 #[track_caller]
51 pub fn from(source: S, provided_indexes: [(Dimension, usize); I]) -> TensorIndex<T, S, D, I> {
52 let shape = source.view_shape();
53 if I > D {
54 panic!("D - I must be >= 0, D: {:?}, I: {:?}", D, I);
55 }
56 if crate::tensors::dimensions::has_duplicates(&provided_indexes) {
57 panic!(
58 "Multiple indexes cannot be provided for the same dimension name, provided: {:?}",
59 provided_indexes,
60 );
61 }
62 let mut provided = [None; D];
63 for (name, index) in &provided_indexes {
64 match shape
67 .iter()
68 .enumerate()
69 .find(|(_i, (n, length))| n == name && index < length)
70 {
71 None => panic!(
72 "Provided indexes must all correspond to valid indexes into the source shape, source shape: {:?}, provided: {:?}",
73 shape, provided_indexes,
74 ),
75 Some((i, (_n, _length))) => provided[i] = Some(*index),
77 }
78 }
79 TensorIndex {
80 source,
81 provided,
82 _type: PhantomData,
83 }
84 }
85
86 pub fn source(self) -> S {
90 self.source
91 }
92
93 pub fn source_ref(&self) -> &S {
104 &self.source
105 }
106}
107
/// Implements `TensorRef`, and `TensorMut` when the source supports it, for a `TensorIndex`
/// with the literal `D` (`$d`) and `I` (`$i`) given, exposing a `$d - $i` dimensional view.
///
/// Also generates a private inherent helper named `$helper_name` that merges indexes into the
/// view with the pinned indexes to produce indexes into the source. Each combination is
/// spelled out separately because generic const expressions such as `{ D - I }` are not
/// available on stable Rust.
macro_rules! tensor_index_ref_impl {
    (unsafe impl TensorRef for TensorIndex $d:literal $i:literal $helper_name:ident) => {
        impl<T, S> TensorIndex<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            /// Builds the `$d` source indexes by walking the source's dimensions in order. It
            /// takes the pinned index where one was provided, and the next of `indexes`
            /// otherwise.
            ///
            /// Returns `None` only if `indexes` runs out before every unpinned dimension is
            /// filled. `from` pins `$i` dimensions with distinct names, leaving `$d - $i`
            /// unpinned to match the length of `indexes`, so this is not expected to happen.
            fn $helper_name(&self, indexes: [usize; $d - $i]) -> Option<[usize; $d]> {
                let mut unpinned_indexes = indexes.iter();
                let mut source_indexes = [0; $d];
                for (slot, pinned) in source_indexes.iter_mut().zip(self.provided.iter()) {
                    *slot = match pinned {
                        Some(index) => *index,
                        None => *unpinned_indexes.next()?,
                    };
                }
                Some(source_indexes)
            }
        }

        // SAFETY: this view only forwards to the source's own `TensorRef` implementation with
        // remapped indexes. The pinned indexes were bounds checked against the source's shape
        // in `from`, and `view_shape` reports the source's shape with the pinned dimensions
        // removed. NOTE(review): the full trait contract is documented with `TensorRef`.
        unsafe impl<T, S> TensorRef<T, { $d - $i }> for TensorIndex<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            fn get_reference(&self, indexes: [usize; $d - $i]) -> Option<&T> {
                // The merge cannot fail; see the helper's documentation.
                let source_indexes = self.$helper_name(indexes).unwrap();
                self.source.get_reference(source_indexes)
            }

            fn view_shape(&self) -> [(Dimension, usize); $d - $i] {
                let shape = self.source.view_shape();
                // Keep only the source dimensions that were not pinned, in source order.
                let mut unpinned = shape
                    .iter()
                    .zip(self.provided.iter())
                    .filter(|(_, pinned)| pinned.is_none())
                    .map(|(dimension, _)| *dimension);
                std::array::from_fn(|_| unpinned.next().unwrap())
            }

            unsafe fn get_reference_unchecked(&self, indexes: [usize; $d - $i]) -> &T {
                // SAFETY: the merge always returns `Some` (see the helper's documentation).
                // Assuming the caller guarantees `indexes` are within `view_shape` (the
                // `TensorRef` contract, TODO confirm), the merged indexes are within the
                // source's shape because the pinned ones were bounds checked in `from`.
                unsafe {
                    let source_indexes = self.$helper_name(indexes).unwrap_unchecked();
                    self.source.get_reference_unchecked(source_indexes)
                }
            }

            fn data_layout(&self) -> DataLayout<{ $d - $i }> {
                // Always reported as NonLinear, regardless of the source's layout.
                DataLayout::NonLinear
            }
        }

        // SAFETY: as for the `TensorRef` implementation above, indexes are only remapped
        // before being forwarded to the source's `TensorMut` implementation.
        unsafe impl<T, S> TensorMut<T, { $d - $i }> for TensorIndex<T, S, $d, $i>
        where
            S: TensorMut<T, $d>,
        {
            fn get_reference_mut(&mut self, indexes: [usize; $d - $i]) -> Option<&mut T> {
                // The merge cannot fail; see the helper's documentation.
                let source_indexes = self.$helper_name(indexes).unwrap();
                self.source.get_reference_mut(source_indexes)
            }

            unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d - $i]) -> &mut T {
                // SAFETY: same reasoning as `get_reference_unchecked`.
                unsafe {
                    let source_indexes = self.$helper_name(indexes).unwrap_unchecked();
                    self.source.get_reference_unchecked_mut(source_indexes)
                }
            }
        }
    };
}
200
201tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 1 compute_select_indexes_6_1);
202tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 2 compute_select_indexes_6_2);
203tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 3 compute_select_indexes_6_3);
204tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 4 compute_select_indexes_6_4);
205tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 5 compute_select_indexes_6_5);
206tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 6 compute_select_indexes_6_6);
207tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 1 compute_select_indexes_5_1);
208tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 2 compute_select_indexes_5_2);
209tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 3 compute_select_indexes_5_3);
210tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 4 compute_select_indexes_5_4);
211tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 5 compute_select_indexes_5_5);
212tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 1 compute_select_indexes_4_1);
213tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 2 compute_select_indexes_4_2);
214tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 3 compute_select_indexes_4_3);
215tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 4 compute_select_indexes_4_4);
216tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 3 1 compute_select_indexes_3_1);
217tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 3 2 compute_select_indexes_3_2);
218tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 3 3 compute_select_indexes_3_3);
219tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 2 1 compute_select_indexes_2_1);
220tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 2 2 compute_select_indexes_2_2);
221tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 1 1 compute_select_indexes_1_1);
222
/// Pinning one or two dimensions of a 3D tensor, either in one step or by chaining
/// `TensorIndex` views, leaves the expected lower dimensional tensor.
#[test]
fn dimensionality_reduction() {
    use crate::tensors::views::TensorView;
    use crate::tensors::Tensor;
    #[rustfmt::skip]
    let cube = Tensor::from([("batch", 2), ("row", 2), ("column", 2)], vec![
        0, 1,
        2, 3,

        4, 5,
        6, 7
    ]);

    // Second batch, then the first column of that batch, then the same in a single step.
    let second_batch = TensorView::from(TensorIndex::from(&cube, [("batch", 1)]));
    assert_eq!(second_batch.shape(), [("row", 2), ("column", 2)]);
    assert_eq!(
        second_batch,
        Tensor::from([("row", 2), ("column", 2)], vec![4, 5, 6, 7])
    );
    let chained = TensorView::from(TensorIndex::from(second_batch.source(), [("column", 0)]));
    assert_eq!(chained.shape(), [("row", 2)]);
    assert_eq!(chained, Tensor::from([("row", 2)], vec![4, 6]));
    let direct = TensorView::from(TensorIndex::from(&cube, [("batch", 1), ("column", 0)]));
    assert_eq!(direct.shape(), [("row", 2)]);
    assert_eq!(direct, Tensor::from([("row", 2)], vec![4, 6]));

    // Pinning the middle dimension keeps the outer two.
    let second_rows = TensorView::from(TensorIndex::from(&cube, [("row", 1)]));
    assert_eq!(second_rows.shape(), [("batch", 2), ("column", 2)]);
    assert_eq!(
        second_rows,
        Tensor::from([("batch", 2), ("column", 2)], vec![2, 3, 6, 7])
    );

    // Second column, then the first batch of that, then the same in a single step.
    let second_columns = TensorView::from(TensorIndex::from(&cube, [("column", 1)]));
    assert_eq!(second_columns.shape(), [("batch", 2), ("row", 2)]);
    assert_eq!(
        second_columns,
        Tensor::from([("batch", 2), ("row", 2)], vec![1, 3, 5, 7])
    );
    let chained = TensorView::from(TensorIndex::from(second_columns.source(), [("batch", 0)]));
    assert_eq!(chained.shape(), [("row", 2)]);
    assert_eq!(chained, Tensor::from([("row", 2)], vec![1, 3]));
    let direct = TensorView::from(TensorIndex::from(&cube, [("batch", 0), ("column", 1)]));
    assert_eq!(direct.shape(), [("row", 2)]);
    assert_eq!(direct, Tensor::from([("row", 2)], vec![1, 3]));
}
275
/// A view over a source tensor with `I` extra dimensions of length 1 inserted into its shape.
///
/// Each extra dimension is described by `(position, name)`, where the position is relative to
/// the source's `D` dimensions: `0` inserts before the first source dimension and `D` after
/// the last. The only valid index into an extra dimension is 0.
///
/// For example, inserting `(0, "batch")` into a `[("row", 2), ("column", 2)]` tensor gives a
/// `[("batch", 1), ("row", 2), ("column", 2)]` view.
#[derive(Clone, Debug)]
pub struct TensorExpansion<T, S, const D: usize, const I: usize> {
    // The tensor being viewed.
    source: S,
    // The extra dimensions as `(position, name)`, sorted by position by `from`; extras at the
    // same position keep the order they were provided in (stable sort).
    extra: [(usize, Dimension); I],
    _type: PhantomData<T>,
}
300
301impl<T, S, const D: usize, const I: usize> TensorExpansion<T, S, D, I>
302where
303 S: TensorRef<T, D>,
304{
305 #[track_caller]
326 pub fn from(
327 source: S,
328 extra_dimension_names: [(usize, Dimension); I],
329 ) -> TensorExpansion<T, S, D, I> {
330 let mut dimensions = extra_dimension_names;
331 if crate::tensors::dimensions::has_duplicates_extra_names(&extra_dimension_names) {
332 panic!("All extra dimension names {:?} must be unique", dimensions,);
333 }
334 let shape = source.view_shape();
335 for &(d, name) in &dimensions {
336 if d > D {
337 panic!(
338 "All extra dimensions {:?} must be inserted in the range 0 <= d <= D of the source shape {:?}",
339 dimensions, shape
340 );
341 }
342 for &(n, _) in &shape {
343 if name == n {
344 panic!(
345 "All extra dimension names {:?} must not be already present in the source shape {:?}",
346 dimensions, shape
347 );
348 }
349 }
350 }
351
352 dimensions.sort_by(|a, b| a.0.cmp(&b.0));
354
355 TensorExpansion {
356 source,
357 extra: dimensions,
358 _type: PhantomData,
359 }
360 }
361
362 pub fn source(self) -> S {
366 self.source
367 }
368
369 pub fn source_ref(&self) -> &S {
380 &self.source
381 }
382}
383
/// Implements `TensorRef`, and `TensorMut` when the source supports it, for a
/// `TensorExpansion` with the literal `D` (`$d`) and `I` (`$i`) given, exposing a `$d + $i`
/// dimensional view.
///
/// Also generates a private inherent helper named `$helper_name` that converts indexes into
/// the view back into indexes into the source by dropping the indexes of the extra dimensions.
/// Each combination is spelled out separately because generic const expressions such as
/// `{ D + I }` are not available on stable Rust.
macro_rules! tensor_expansion_ref_impl {
    (unsafe impl TensorRef for TensorExpansion $d:literal $i:literal $helper_name:ident) => {
        impl<T, S> TensorExpansion<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            /// Splits `indexes` into the `$d` source indexes, skipping over each extra
            /// dimension at its insertion position.
            ///
            /// Returns `None` if any index into an extra dimension is not 0, because extra
            /// dimensions have a length of 1.
            fn $helper_name(&self, indexes: [usize; $d + $i]) -> Option<[usize; $d]> {
                let mut source_indexes = [0; $d];
                let mut next_source = 0;
                // `extra` is sorted by position, so the next extra dimension to reach is
                // always at the front of this iterator.
                let mut extras = self.extra.iter().peekable();
                for &index in indexes.iter() {
                    let at_extra = matches!(
                        extras.peek(),
                        Some(&&(position, _)) if position == next_source
                    );
                    if at_extra {
                        if index != 0 {
                            return None;
                        }
                        extras.next();
                    } else {
                        source_indexes[next_source] = index;
                        next_source += 1;
                    }
                }
                Some(source_indexes)
            }
        }

        // SAFETY: this view only forwards to the source's own `TensorRef` implementation,
        // after checking that every extra dimension is indexed at 0 and dropping those
        // indexes. `view_shape` reports the source's shape with length 1 dimensions inserted.
        // NOTE(review): the full trait contract is documented with `TensorRef`.
        unsafe impl<T, S> TensorRef<T, { $d + $i }> for TensorExpansion<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            fn get_reference(&self, indexes: [usize; $d + $i]) -> Option<&T> {
                let source_indexes = self.$helper_name(indexes)?;
                self.source.get_reference(source_indexes)
            }

            fn view_shape(&self) -> [(Dimension, usize); $d + $i] {
                let shape = self.source.view_shape();
                let mut expanded = [("", 0); $d + $i];
                let mut next_source = 0;
                // `extra` is sorted by position; see the index helper.
                let mut extras = self.extra.iter().peekable();
                for slot in expanded.iter_mut() {
                    match extras.peek() {
                        Some(&&(position, name)) if position == next_source => {
                            *slot = (name, 1);
                            extras.next();
                        }
                        _ => {
                            *slot = shape[next_source];
                            next_source += 1;
                        }
                    }
                }
                expanded
            }

            unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + $i]) -> &T {
                // SAFETY: assuming the caller guarantees `indexes` are within `view_shape` (the
                // `TensorRef` contract, TODO confirm), every extra dimension is indexed at 0,
                // so the helper returns `Some`. The remaining indexes are then within the
                // source's shape.
                unsafe {
                    let source_indexes = self.$helper_name(indexes).unwrap_unchecked();
                    self.source.get_reference_unchecked(source_indexes)
                }
            }

            fn data_layout(&self) -> DataLayout<{ $d + $i }> {
                // Always reported as NonLinear, regardless of the source's layout.
                DataLayout::NonLinear
            }
        }

        // SAFETY: as for the `TensorRef` implementation above, indexes are only validated and
        // remapped before being forwarded to the source's `TensorMut` implementation.
        unsafe impl<T, S> TensorMut<T, { $d + $i }> for TensorExpansion<T, S, $d, $i>
        where
            S: TensorMut<T, $d>,
        {
            fn get_reference_mut(&mut self, indexes: [usize; $d + $i]) -> Option<&mut T> {
                let source_indexes = self.$helper_name(indexes)?;
                self.source.get_reference_mut(source_indexes)
            }

            unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + $i]) -> &mut T {
                // SAFETY: same reasoning as `get_reference_unchecked`.
                unsafe {
                    let source_indexes = self.$helper_name(indexes).unwrap_unchecked();
                    self.source.get_reference_unchecked_mut(source_indexes)
                }
            }
        }
    };
}
496
497tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 1 compute_expansion_indexes_0_1);
498tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 2 compute_expansion_indexes_0_2);
499tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 3 compute_expansion_indexes_0_3);
500tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 4 compute_expansion_indexes_0_4);
501tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 5 compute_expansion_indexes_0_5);
502tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 6 compute_expansion_indexes_0_6);
503tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 1 compute_expansion_indexes_1_1);
504tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 2 compute_expansion_indexes_1_2);
505tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 3 compute_expansion_indexes_1_3);
506tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 4 compute_expansion_indexes_1_4);
507tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 5 compute_expansion_indexes_1_5);
508tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 1 compute_expansion_indexes_2_1);
509tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 2 compute_expansion_indexes_2_2);
510tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 3 compute_expansion_indexes_2_3);
511tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 4 compute_expansion_indexes_2_4);
512tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 3 1 compute_expansion_indexes_3_1);
513tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 3 2 compute_expansion_indexes_3_2);
514tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 3 3 compute_expansion_indexes_3_3);
515tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 4 1 compute_expansion_indexes_4_1);
516tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 4 2 compute_expansion_indexes_4_2);
517tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 5 1 compute_expansion_indexes_5_1);
518
/// Inserting length 1 dimensions before, between, or after the source's dimensions gives the
/// expected shapes without changing the order of the data.
#[test]
fn dimensionality_expansion() {
    use crate::tensors::views::TensorView;
    use crate::tensors::Tensor;

    // A single extra dimension in front of a matrix.
    let matrix = Tensor::from([("row", 2), ("column", 2)], (0..4).collect());
    let batched = TensorView::from(TensorExpansion::from(&matrix, [(0, "batch")]));
    assert_eq!(batched.shape(), [("batch", 1), ("row", 2), ("column", 2)]);
    assert_eq!(
        batched,
        Tensor::from([("batch", 1), ("row", 2), ("column", 2)], vec![0, 1, 2, 3])
    );

    // Two extra dimensions at the same position keep the order they were given in.
    let vector = Tensor::from([("a", 5)], (0..5).collect());
    let expanded = TensorView::from(TensorExpansion::from(&vector, [(1, "b"), (1, "c")]));
    assert_eq!(
        expanded,
        Tensor::from([("a", 5), ("b", 1), ("c", 1)], (0..5).collect())
    );

    // Extra dimensions given out of position order are placed by position.
    let image = Tensor::from([("row", 2), ("column", 2)], (0..4).collect());
    let dataset = TensorView::from(TensorExpansion::from(&image, [(2, "color"), (0, "batch")]));
    assert_eq!(
        dataset,
        Tensor::from(
            [("batch", 1), ("row", 2), ("column", 2), ("color", 1)],
            (0..4).collect()
        )
    );
}
546
/// Despite its name, this exercises `TensorExpansion`. Indexing the expanded view outside its
/// shape, including a non-zero index into the inserted length 1 `c` dimension, panics with
/// the view's shape in the message.
#[test]
#[should_panic(
    expected = "Unable to index with [2, 2, 2, 2], Tensor dimensions are [(\"a\", 2), (\"b\", 2), (\"c\", 1), (\"d\", 2)]."
)]
fn dimensionality_reduction_invalid_extra_index() {
    use crate::tensors::views::TensorView;
    use crate::tensors::Tensor;
    let source = Tensor::from([("a", 2), ("b", 2), ("d", 2)], (0..8).collect());
    let expanded = TensorView::from(TensorExpansion::from(&source, [(2, "c")]));
    assert_eq!(expanded.shape(), [("a", 2), ("b", 2), ("c", 1), ("d", 2)]);
    expanded.index().get_ref([2, 2, 2, 2]);
}