1use crate::tensors::Dimension;
2use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
3use std::marker::PhantomData;
4
/// A view over a source tensor that fixes some of its dimensions to single indexes, exposing
/// the remaining `D - I` dimensions as a lower dimensional tensor.
///
/// For example, fixing `("batch", 1)` on a `[("batch", 2), ("row", 2), ("column", 2)]` tensor
/// gives a `[("row", 2), ("column", 2)]` view of the second batch.
///
/// Create one with [`TensorIndex::from`]. `TensorRef`/`TensorMut` are implemented by the
/// `tensor_index_ref_impl!` macro for each supported `(D, I)` pair rather than generically.
#[derive(Clone, Debug)]
pub struct TensorIndex<T, S, const D: usize, const I: usize> {
    /// The tensor being indexed into.
    source: S,
    /// One entry per source dimension, in the source's dimension order: `Some(index)` if that
    /// dimension was fixed when the view was created, `None` if it is left free in this view.
    provided: [Option<usize>; D],
    /// Carries the element type `T`, which does not otherwise appear in the fields.
    _type: PhantomData<T>,
}
30
31impl<T, S, const D: usize, const I: usize> TensorIndex<T, S, D, I>
32where
33 S: TensorRef<T, D>,
34{
35 #[track_caller]
51 pub fn from(source: S, provided_indexes: [(Dimension, usize); I]) -> TensorIndex<T, S, D, I> {
52 let shape = source.view_shape();
53 if I > D {
54 panic!("D - I must be >= 0, D: {:?}, I: {:?}", D, I);
55 }
56 if crate::tensors::dimensions::has_duplicates(&provided_indexes) {
57 panic!(
58 "Multiple indexes cannot be provided for the same dimension name, provided: {:?}",
59 provided_indexes,
60 );
61 }
62 let mut provided = [None; D];
63 for (name, index) in &provided_indexes {
64 match shape
67 .iter()
68 .enumerate()
69 .find(|(_i, (n, length))| n == name && index < length)
70 {
71 None => panic!(
72 "Provided indexes must all correspond to valid indexes into the source shape, source shape: {:?}, provided: {:?}",
73 shape, provided_indexes,
74 ),
75 Some((i, (_n, _length))) => provided[i] = Some(*index),
77 }
78 }
79 TensorIndex {
80 source,
81 provided,
82 _type: PhantomData,
83 }
84 }
85
86 pub fn source(self) -> S {
90 self.source
91 }
92
93 pub fn source_ref(&self) -> &S {
104 &self.source
105 }
106}
107
/// Implements `TensorRef` and `TensorMut` for `TensorIndex` with a concrete source
/// dimensionality `$d` and number of fixed indexes `$i`.
///
/// Each supported pair is instantiated separately, presumably because writing `D - I` in a
/// generic const position needs the unstable `generic_const_exprs` feature. `$helper_name`
/// names the private index-mapping method generated for the pair.
macro_rules! tensor_index_ref_impl {
    (unsafe impl TensorRef for TensorIndex $d:literal $i:literal $helper_name:ident) => {
        impl<T, S> TensorIndex<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            /// Merges the indexes for this view's free dimensions with the fixed indexes,
            /// producing a full index for the source in the source's dimension order.
            ///
            /// Returns `None` only if fewer indexes were supplied than there are free
            /// dimensions, which the array lengths should rule out.
            fn $helper_name(&self, indexes: [usize; $d - $i]) -> Option<[usize; $d]> {
                let mut free = indexes.iter().copied();
                let mut source_indexes = [0; $d];
                for (slot, fixed) in source_indexes.iter_mut().zip(self.provided.iter()) {
                    *slot = match *fixed {
                        Some(index) => index,
                        None => free.next()?,
                    };
                }
                Some(source_indexes)
            }
        }

        // SAFETY: (assumes the `TensorRef` contract requires `view_shape` to describe exactly
        // the indexes `get_reference` accepts - confirm against the trait docs.) `view_shape`
        // lists the source dimensions whose `provided` slot is `None`, in source order, and the
        // helper fills those same slots with the caller's indexes in that order, so the two
        // agree. Fixed indexes were validated against the source's shape in
        // `TensorIndex::from`. References are only ever obtained from the source's own
        // implementation.
        unsafe impl<T, S> TensorRef<T, { $d - $i }> for TensorIndex<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            fn get_reference(&self, indexes: [usize; $d - $i]) -> Option<&T> {
                // The helper cannot fail for correctly sized input, so this unwrap is not an
                // input check; out of range indexes are passed through for the source to reject.
                let source_indexes = self.$helper_name(indexes).unwrap();
                self.source.get_reference(source_indexes)
            }

            fn view_shape(&self) -> [(Dimension, usize); $d - $i] {
                let shape = self.source.view_shape();
                // Keep only the dimensions that were not fixed, preserving their order.
                let mut free_dimensions = shape
                    .iter()
                    .zip(self.provided.iter())
                    .filter(|(_, fixed)| fixed.is_none())
                    .map(|(&dimension, _)| dimension);
                std::array::from_fn(|_| free_dimensions.next().unwrap())
            }

            unsafe fn get_reference_unchecked(&self, indexes: [usize; $d - $i]) -> &T {
                let source_indexes = self.$helper_name(indexes).unwrap();
                // SAFETY: the caller guarantees `indexes` are in range for this view's shape.
                // Free positions are forwarded unchanged and fixed positions were validated in
                // `TensorIndex::from`, so `source_indexes` is in range for the source.
                unsafe { self.source.get_reference_unchecked(source_indexes) }
            }

            fn data_layout(&self) -> DataLayout<{ $d - $i }> {
                // This view never reports a linear layout.
                DataLayout::NonLinear
            }
        }

        // SAFETY: uses the same index mapping as the `TensorRef` implementation above. Distinct
        // view indexes differ in some free position, so they map to distinct source indexes.
        // Mutable references are only ever obtained from the source's own `TensorMut`
        // implementation.
        unsafe impl<T, S> TensorMut<T, { $d - $i }> for TensorIndex<T, S, $d, $i>
        where
            S: TensorMut<T, $d>,
        {
            fn get_reference_mut(&mut self, indexes: [usize; $d - $i]) -> Option<&mut T> {
                let source_indexes = self.$helper_name(indexes).unwrap();
                self.source.get_reference_mut(source_indexes)
            }

            unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d - $i]) -> &mut T {
                let source_indexes = self.$helper_name(indexes).unwrap();
                // SAFETY: as for `get_reference_unchecked`, in range view indexes map to in
                // range source indexes.
                unsafe { self.source.get_reference_unchecked_mut(source_indexes) }
            }
        }
    };
}
198
199tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 1 compute_select_indexes_6_1);
200tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 2 compute_select_indexes_6_2);
201tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 3 compute_select_indexes_6_3);
202tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 4 compute_select_indexes_6_4);
203tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 5 compute_select_indexes_6_5);
204tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 6 6 compute_select_indexes_6_6);
205tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 1 compute_select_indexes_5_1);
206tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 2 compute_select_indexes_5_2);
207tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 3 compute_select_indexes_5_3);
208tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 4 compute_select_indexes_5_4);
209tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 5 5 compute_select_indexes_5_5);
210tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 1 compute_select_indexes_4_1);
211tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 2 compute_select_indexes_4_2);
212tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 3 compute_select_indexes_4_3);
213tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 4 4 compute_select_indexes_4_4);
214tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 3 1 compute_select_indexes_3_1);
215tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 3 2 compute_select_indexes_3_2);
216tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 3 3 compute_select_indexes_3_3);
217tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 2 1 compute_select_indexes_2_1);
218tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 2 2 compute_select_indexes_2_2);
219tensor_index_ref_impl!(unsafe impl TensorRef for TensorIndex 1 1 compute_select_indexes_1_1);
220
#[test]
fn dimensionality_reduction() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;

    #[rustfmt::skip]
    let tensor = Tensor::from([("batch", 2), ("row", 2), ("column", 2)], vec![
        0, 1,
        2, 3,

        4, 5,
        6, 7
    ]);

    // Fixing "batch" leaves a row x column matrix. Fixing "column" on that matrix must give
    // the same vector as fixing both dimensions on the original tensor in one step.
    let second_batch = TensorView::from(TensorIndex::from(&tensor, [("batch", 1)]));
    assert_eq!(second_batch.shape(), [("row", 2), ("column", 2)]);
    assert_eq!(
        second_batch,
        Tensor::from([("row", 2), ("column", 2)], vec![4, 5, 6, 7])
    );
    let batch_1_column_0 = Tensor::from([("row", 2)], vec![4, 6]);
    let stepwise = TensorView::from(TensorIndex::from(second_batch.source(), [("column", 0)]));
    assert_eq!(stepwise.shape(), [("row", 2)]);
    assert_eq!(stepwise, batch_1_column_0);
    let direct = TensorView::from(TensorIndex::from(&tensor, [("batch", 1), ("column", 0)]));
    assert_eq!(direct.shape(), [("row", 2)]);
    assert_eq!(direct, batch_1_column_0);

    // Fixing a middle dimension keeps the outer two in order.
    let second_row = TensorView::from(TensorIndex::from(&tensor, [("row", 1)]));
    assert_eq!(second_row.shape(), [("batch", 2), ("column", 2)]);
    assert_eq!(
        second_row,
        Tensor::from([("batch", 2), ("column", 2)], vec![2, 3, 6, 7])
    );

    // Fixing the last dimension, then the first, must again agree with fixing both at once.
    let second_column = TensorView::from(TensorIndex::from(&tensor, [("column", 1)]));
    assert_eq!(second_column.shape(), [("batch", 2), ("row", 2)]);
    assert_eq!(
        second_column,
        Tensor::from([("batch", 2), ("row", 2)], vec![1, 3, 5, 7])
    );
    let batch_0_column_1 = Tensor::from([("row", 2)], vec![1, 3]);
    let stepwise = TensorView::from(TensorIndex::from(second_column.source(), [("batch", 0)]));
    assert_eq!(stepwise.shape(), [("row", 2)]);
    assert_eq!(stepwise, batch_0_column_1);
    let direct = TensorView::from(TensorIndex::from(&tensor, [("batch", 0), ("column", 1)]));
    assert_eq!(direct.shape(), [("row", 2)]);
    assert_eq!(direct, batch_0_column_1);
}
273
/// A view over a source tensor with `I` extra dimensions of length 1 inserted, exposing
/// `D + I` dimensions in total.
///
/// For example, inserting `(0, "batch")` into a `[("row", 2), ("column", 2)]` tensor gives a
/// `[("batch", 1), ("row", 2), ("column", 2)]` view. The only valid index into an inserted
/// dimension is 0.
///
/// Create one with [`TensorExpansion::from`]. `TensorRef`/`TensorMut` are implemented by the
/// `tensor_expansion_ref_impl!` macro for each supported `(D, I)` pair rather than generically.
#[derive(Clone, Debug)]
pub struct TensorExpansion<T, S, const D: usize, const I: usize> {
    /// The tensor being expanded.
    source: S,
    /// `(position, name)` for each inserted dimension, sorted by position when built by
    /// `from`. The position is the index of the source dimension it is inserted before, or
    /// `D` to append it after the last source dimension.
    extra: [(usize, Dimension); I],
    /// Carries the element type `T`, which does not otherwise appear in the fields.
    _type: PhantomData<T>,
}
298
299impl<T, S, const D: usize, const I: usize> TensorExpansion<T, S, D, I>
300where
301 S: TensorRef<T, D>,
302{
303 #[track_caller]
324 pub fn from(
325 source: S,
326 extra_dimension_names: [(usize, Dimension); I],
327 ) -> TensorExpansion<T, S, D, I> {
328 let mut dimensions = extra_dimension_names;
329 if crate::tensors::dimensions::has_duplicates_extra_names(&extra_dimension_names) {
330 panic!("All extra dimension names {:?} must be unique", dimensions,);
331 }
332 let shape = source.view_shape();
333 for &(d, name) in &dimensions {
334 if d > D {
335 panic!(
336 "All extra dimensions {:?} must be inserted in the range 0 <= d <= D of the source shape {:?}",
337 dimensions, shape
338 );
339 }
340 for &(n, _) in &shape {
341 if name == n {
342 panic!(
343 "All extra dimension names {:?} must not be already present in the source shape {:?}",
344 dimensions, shape
345 );
346 }
347 }
348 }
349
350 dimensions.sort_by(|a, b| a.0.cmp(&b.0));
352
353 TensorExpansion {
354 source,
355 extra: dimensions,
356 _type: PhantomData,
357 }
358 }
359
360 pub fn source(self) -> S {
364 self.source
365 }
366
367 pub fn source_ref(&self) -> &S {
378 &self.source
379 }
380}
381
/// Implements `TensorRef` and `TensorMut` for `TensorExpansion` with a concrete source
/// dimensionality `$d` and number of inserted dimensions `$i`.
///
/// Each supported pair is instantiated separately, presumably because writing `D + I` in a
/// generic const position needs the unstable `generic_const_exprs` feature. `$helper_name`
/// names the private index-mapping method generated for the pair.
macro_rules! tensor_expansion_ref_impl {
    (unsafe impl TensorRef for TensorExpansion $d:literal $i:literal $helper_name:ident) => {
        impl<T, S> TensorExpansion<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            /// Strips the indexes for the inserted dimensions, returning the remaining
            /// indexes, in order, for the source.
            ///
            /// Returns `None` if any inserted dimension is indexed with anything other than
            /// 0, since every inserted dimension has length 1.
            fn $helper_name(&self, indexes: [usize; $d + $i]) -> Option<[usize; $d]> {
                let mut source_indexes = [0; $d];
                // How many source indexes have been written, i.e. the next source dimension.
                let mut next_source = 0;
                // `extra` is sorted by position, so inserted dimensions are met in order.
                let mut pending_extras = self.extra.iter().peekable();
                for &index in indexes.iter() {
                    let inserted_here = matches!(
                        pending_extras.peek(),
                        Some(&&(position, _)) if position == next_source
                    );
                    if inserted_here {
                        if index != 0 {
                            return None;
                        }
                        pending_extras.next();
                    } else {
                        source_indexes[next_source] = index;
                        next_source += 1;
                    }
                }
                Some(source_indexes)
            }
        }

        // SAFETY: (assumes the `TensorRef` contract requires `view_shape` to describe exactly
        // the indexes `get_reference` accepts - confirm against the trait docs.) `view_shape`
        // interleaves the source's dimensions with length 1 dimensions at the sorted `extra`
        // positions. The helper walks indexes with the same interleaving, forwarding source
        // positions unchanged and rejecting any nonzero index into an inserted dimension.
        // References are only ever obtained from the source's own implementation.
        unsafe impl<T, S> TensorRef<T, { $d + $i }> for TensorExpansion<T, S, $d, $i>
        where
            S: TensorRef<T, $d>,
        {
            fn get_reference(&self, indexes: [usize; $d + $i]) -> Option<&T> {
                let source_indexes = self.$helper_name(indexes)?;
                self.source.get_reference(source_indexes)
            }

            fn view_shape(&self) -> [(Dimension, usize); $d + $i] {
                let shape = self.source.view_shape();
                let mut source_dimensions = shape.iter().copied();
                let mut pending_extras = self.extra.iter().peekable();
                let mut next_source = 0;
                std::array::from_fn(|_| match pending_extras.peek() {
                    Some(&&(position, name)) if position == next_source => {
                        pending_extras.next();
                        (name, 1)
                    }
                    _ => {
                        next_source += 1;
                        source_dimensions.next().unwrap()
                    }
                })
            }

            unsafe fn get_reference_unchecked(&self, indexes: [usize; $d + $i]) -> &T {
                // A nonzero index into an inserted dimension breaks the caller's contract; the
                // unwrap turns that into a panic instead of forwarding a bogus index.
                let source_indexes = self.$helper_name(indexes).unwrap();
                // SAFETY: the caller guarantees `indexes` are in range for this view's shape,
                // and source positions are forwarded unchanged, so `source_indexes` is in range
                // for the source.
                unsafe { self.source.get_reference_unchecked(source_indexes) }
            }

            fn data_layout(&self) -> DataLayout<{ $d + $i }> {
                // This view never reports a linear layout.
                DataLayout::NonLinear
            }
        }

        // SAFETY: uses the same index mapping as the `TensorRef` implementation above. Valid
        // view indexes are 0 at every inserted position, so distinct valid view indexes differ
        // in some source position and map to distinct source indexes. Mutable references are
        // only ever obtained from the source's own `TensorMut` implementation.
        unsafe impl<T, S> TensorMut<T, { $d + $i }> for TensorExpansion<T, S, $d, $i>
        where
            S: TensorMut<T, $d>,
        {
            fn get_reference_mut(&mut self, indexes: [usize; $d + $i]) -> Option<&mut T> {
                let source_indexes = self.$helper_name(indexes)?;
                self.source.get_reference_mut(source_indexes)
            }

            unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; $d + $i]) -> &mut T {
                let source_indexes = self.$helper_name(indexes).unwrap();
                // SAFETY: as for `get_reference_unchecked`, in range view indexes map to in
                // range source indexes.
                unsafe { self.source.get_reference_unchecked_mut(source_indexes) }
            }
        }
    };
}
492
493tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 1 compute_expansion_indexes_0_1);
494tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 2 compute_expansion_indexes_0_2);
495tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 3 compute_expansion_indexes_0_3);
496tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 4 compute_expansion_indexes_0_4);
497tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 5 compute_expansion_indexes_0_5);
498tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 0 6 compute_expansion_indexes_0_6);
499tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 1 compute_expansion_indexes_1_1);
500tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 2 compute_expansion_indexes_1_2);
501tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 3 compute_expansion_indexes_1_3);
502tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 4 compute_expansion_indexes_1_4);
503tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 1 5 compute_expansion_indexes_1_5);
504tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 1 compute_expansion_indexes_2_1);
505tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 2 compute_expansion_indexes_2_2);
506tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 3 compute_expansion_indexes_2_3);
507tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 2 4 compute_expansion_indexes_2_4);
508tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 3 1 compute_expansion_indexes_3_1);
509tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 3 2 compute_expansion_indexes_3_2);
510tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 3 3 compute_expansion_indexes_3_3);
511tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 4 1 compute_expansion_indexes_4_1);
512tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 4 2 compute_expansion_indexes_4_2);
513tensor_expansion_ref_impl!(unsafe impl TensorRef for TensorExpansion 5 1 compute_expansion_indexes_5_1);
514
#[test]
fn dimensionality_expansion() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;

    let matrix = Tensor::from([("row", 2), ("column", 2)], (0..4).collect());

    // A single leading dimension of length 1.
    let batched = TensorView::from(TensorExpansion::from(&matrix, [(0, "batch")]));
    assert_eq!(batched.shape(), [("batch", 1), ("row", 2), ("column", 2)]);
    assert_eq!(
        batched,
        Tensor::from([("batch", 1), ("row", 2), ("column", 2)], vec![0, 1, 2, 3])
    );

    // Two dimensions appended at the same position appear in the order they were given.
    let vector = Tensor::from([("a", 5)], (0..5).collect());
    let expanded = TensorView::from(TensorExpansion::from(&vector, [(1, "b"), (1, "c")]));
    assert_eq!(
        expanded,
        Tensor::from([("a", 5), ("b", 1), ("c", 1)], (0..5).collect())
    );

    // Positions may be given out of order, including one past the last source dimension.
    let dataset = TensorView::from(TensorExpansion::from(&matrix, [(2, "color"), (0, "batch")]));
    assert_eq!(
        dataset,
        Tensor::from(
            [("batch", 1), ("row", 2), ("column", 2), ("color", 1)],
            (0..4).collect()
        )
    );
}
542
/// Indexing an inserted (length 1) dimension with anything other than 0 must be rejected, even
/// when every index into the source's own dimensions is in range.
#[test]
#[should_panic(
    expected = "Unable to index with [1, 1, 1, 1], Tensor dimensions are [(\"a\", 2), (\"b\", 2), (\"c\", 1), (\"d\", 2)]."
)]
fn dimensionality_expansion_invalid_extra_index() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    let source = Tensor::from([("a", 2), ("b", 2), ("d", 2)], (0..8).collect());
    let tensor = TensorView::from(TensorExpansion::from(&source, [(2, "c")]));
    assert_eq!(tensor.shape(), [("a", 2), ("b", 2), ("c", 1), ("d", 2)]);
    // [1, 1, _, 1] is a valid index into the source, so only the index of 1 for the inserted
    // "c" dimension is at fault.
    tensor.index().get_ref([1, 1, 1, 1]);
}