1use std::{fmt::Display, vec};
2
3use crate::{Error, Result};
4
5#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Shape(pub(crate) Vec<usize>);
7
8impl Shape {
9 pub fn scalar() -> Self {
10 Self(vec![])
11 }
12
13 pub fn is_scalar(&self) -> bool {
14 self.0.is_empty() || (self.0.len() == 1 && self.0[0] == 1)
15 }
16
17 pub fn rank(&self) -> usize {
18 self.0.len()
19 }
20
21 pub fn dims(&self) -> &[usize] {
22 &self.0
23 }
24
25 pub fn into_dims(self) -> Vec<usize> {
26 self.0
27 }
28
29 pub fn dim(&self, dim: impl Dim) -> Result<usize> {
30 let index = dim.to_index(self, "get dim")?;
31 Ok(self.dims()[index])
32 }
33
34 pub fn element_count(&self) -> usize {
35 self.dims().iter().product()
36 }
37
38 pub fn is_contiguous(&self, stride: &[usize]) -> bool {
39 if self.rank() != stride.len() {
40 return false;
41 }
42 let mut acc = 1;
43 for (&stride, &dim) in stride.iter().zip(self.dims().iter()).rev() {
44 if dim > 1 && stride != acc {
45 return false;
46 }
47 acc *= dim;
48 }
49 true
50 }
51
52 pub fn extend(mut self, additional_dims: &[usize]) -> Self {
53 self.0.extend(additional_dims);
54 self
55 }
56
57
58 pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
62 let lhs = self;
63 let lhs_dims = lhs.dims();
64 let rhs_dims = rhs.dims();
65 let lhs_ndims = lhs_dims.len();
66 let rhs_ndims = rhs_dims.len();
67 let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
68 let mut bcast_dims = vec![0; bcast_ndims];
69 for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
70 let rev_idx = bcast_ndims - idx;
71 let l_value = if lhs_ndims < rev_idx {
72 1
73 } else {
74 lhs_dims[lhs_ndims - rev_idx]
75 };
76 let r_value = if rhs_ndims < rev_idx {
77 1
78 } else {
79 rhs_dims[rhs_ndims - rev_idx]
80 };
81 *bcast_value = if l_value == r_value {
82 l_value
83 } else if l_value == 1 {
84 r_value
85 } else if r_value == 1 {
86 l_value
87 } else {
88 Err(Error::ShapeMismatchBinaryOp {
89 lhs: lhs.clone(),
90 rhs: rhs.clone(),
91 op,
92 })?
93 }
94 }
95 Ok(Shape::from(bcast_dims))
96 }
97
98
99 pub fn dim_coordinates(&self) -> DimCoordinates {
108 DimCoordinates::from_shape(self)
109 }
110
111 pub fn dims_coordinates<const N: usize>(&self) -> Result<DimNCoordinates<N>> {
112 DimNCoordinates::<N>::from_shape(self)
113 }
114
115 pub fn dim2_coordinates(&self) -> Result<DimNCoordinates<2>> {
116 DimNCoordinates::<2>::from_shape(self)
117 }
118
119 pub fn dim3_coordinates(&self) -> Result<DimNCoordinates<3>> {
120 DimNCoordinates::<3>::from_shape(self)
121 }
122
123 pub fn dim4_coordinates(&self) -> Result<DimNCoordinates<4>> {
124 DimNCoordinates::<4>::from_shape(self)
125 }
126
127 pub fn dim5_coordinates(&self) -> Result<DimNCoordinates<5>> {
128 DimNCoordinates::<5>::from_shape(self)
129 }
130
131 pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
132 let mut stride = self.dims()
133 .iter()
134 .rev()
135 .scan(1, |prod, u| {
136 let prod_pre_mult = *prod;
137 *prod *= u;
138 Some(prod_pre_mult)
139 })
140 .collect::<Vec<_>>();
141 stride.reverse();
142 stride
143 }
144}
145
146
147pub struct DimCoordinates {
152 shape: Vec<usize>,
153 current: Vec<usize>,
154 done: bool,
155}
156
157impl DimCoordinates {
158 pub fn from_shape(shape: &Shape) -> Self {
159 let rank = shape.rank();
160 Self {
161 shape: shape.dims().to_vec(),
162 current: vec![0; rank],
163 done: shape.is_scalar(),
164 }
165 }
166}
167
168impl Iterator for DimCoordinates {
169 type Item = Vec<usize>;
170
171 fn next(&mut self) -> Option<Self::Item> {
172 if self.done {
173 return None;
174 }
175
176 let result = self.current.clone();
177
178 for i in (0..self.current.len()).rev() {
179 self.current[i] += 1;
180 if self.current[i] < self.shape[i] {
181 break;
182 } else {
183 self.current[i] = 0;
184 if i == 0 {
185 self.done = true;
186 }
187 }
188 }
189
190 Some(result)
191 }
192}
193
194pub struct DimNCoordinates<const N: usize> {
195 shape: [usize; N],
196 current: [usize; N],
197 done: bool,
198}
199
200impl<const N: usize> DimNCoordinates<N> {
201 pub fn from_shape(from_shape: &Shape) -> Result<Self> {
202 if from_shape.rank() == N {
203 let mut shape = [0usize; N];
204 for i in 0..N {
205 shape[i] = from_shape.dims()[i];
206 }
207
208 let current = [0usize; N];
209
210 Ok(Self {
211 shape,
212 current,
213 done: N == 0
214 })
215 } else {
216 Err(Error::UnexpectedNumberOfDims {
217 expected: N,
218 got: from_shape.rank(),
219 shape: Shape::from(from_shape.dims()),
220 })?
221 }
222 }
223}
224
225impl<const N: usize> Iterator for DimNCoordinates<N> {
226 type Item = [usize; N];
227 fn next(&mut self) -> Option<Self::Item> {
228 if self.done {
229 return None;
230 }
231
232 let result = self.current;
233
234 for i in (0..N).rev() {
235 self.current[i] += 1;
236 if self.current[i] < self.shape[i] {
237 break;
238 } else {
239 self.current[i] = 0;
240 if i == 0 {
241 self.done = true;
242 }
243 }
244 }
245
246 Some(result)
247 }
248}
249
250impl<const C: usize> From<&[usize; C]> for Shape {
251 fn from(dims: &[usize; C]) -> Self {
252 Self(dims.to_vec())
253 }
254}
255
256impl From<Vec<usize>> for Shape {
257 fn from(dims: Vec<usize>) -> Self {
258 Self(dims)
259 }
260}
261
262impl From<&Vec<usize>> for Shape {
263 fn from(dims: &Vec<usize>) -> Self {
264 Self(dims.clone())
265 }
266}
267
268impl From<&[usize]> for Shape {
269 fn from(dims: &[usize]) -> Self {
270 Self(dims.to_vec())
271 }
272}
273
274impl From<&Shape> for Shape {
275 fn from(shape: &Shape) -> Self {
276 Self(shape.0.to_vec())
277 }
278}
279
280impl From<usize> for Shape {
281 fn from(d1: usize) -> Self {
282 Self([d1].to_vec())
283 }
284}
285
286impl From<()> for Shape {
287 fn from(_: ()) -> Self {
288 Self(vec![])
289 }
290}
291
292impl std::fmt::Display for Shape {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 write!(f, "(")?;
295 for (i, dim) in self.0.iter().enumerate() {
296 if i > 0 {
297 write!(f, ", ")?;
298 }
299 write!(f, "{}", dim)?;
300 }
301 if self.0.len() == 1 {
302 write!(f, ",")?;
303 }
304 write!(f, ")")
305 }
306}
307
308macro_rules! impl_from_tuple {
309 ($tuple:ty, $($index:tt),+) => {
310 impl From<$tuple> for Shape {
311 fn from(d: $tuple) -> Self {
312 Self([$(d.$index,)+].to_vec())
313 }
314 }
315 };
316}
317
318impl_from_tuple!((usize,), 0);
319impl_from_tuple!((usize, usize), 0, 1);
320impl_from_tuple!((usize, usize, usize), 0, 1, 2);
321impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
322impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
323impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);
324
/// An axis selector that can count from either end of a shape.
#[derive(Debug, Clone, Copy)]
pub enum D {
    /// The last axis (equivalent to `Minus(1)`).
    Minus1,
    /// The second-to-last axis (equivalent to `Minus(2)`).
    Minus2,
    /// The `n`-th axis counting back from the end; `n` must be at least 1.
    Minus(usize),
    /// An absolute, zero-based axis index.
    Index(usize),
}

impl Display for D {
    /// Writes the selector as a signed axis number (`-1`, `-2`, `-n` or `i`).
    ///
    /// No trailing newline is written, so the value can be embedded in larger
    /// messages (the previous `writeln!` implementation appended `\n`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Minus1 => f.write_str("-1"),
            Self::Minus2 => f.write_str("-2"),
            Self::Minus(n) => write!(f, "-{}", n),
            Self::Index(n) => write!(f, "{}", n),
        }
    }
}
343
344impl D {
345 fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
346 let dim = match self {
347 Self::Minus1 => -1,
348 Self::Minus2 => -2,
349 Self::Minus(u) => -(*u as i32),
350 Self::Index(u) => *u as i32,
351 };
352 Error::DimOutOfRange {
353 shape: shape.clone(),
354 dim,
355 op,
356 }
357 }
358}
359
360
361macro_rules! extract_dims {
362 ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
363 pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
364 if dims.len() != $cnt {
365 Err(Error::UnexpectedNumberOfDims {
366 expected: $cnt,
367 got: dims.len(),
368 shape: Shape::from(dims),
369 })?
370 } else {
371 Ok($dims(dims))
372 }
373 }
374
375 impl Shape {
376 pub fn $fn_name(&self) -> Result<$out_type> {
377 $fn_name(self.0.as_slice())
378 }
379 }
380
381 impl<T: crate::WithDType> crate::Tensor<T> {
382 pub fn $fn_name(&self) -> Result<$out_type> {
383 self.shape().$fn_name()
384 }
385 }
386
387 impl std::convert::TryInto<$out_type> for Shape {
388 type Error = crate::Error;
389 fn try_into(self) -> crate::Result<$out_type> {
390 self.$fn_name()
391 }
392 }
393 };
394}
395
396extract_dims!(dims0, 0, |_: &[usize]| (), ());
397extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
398extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
399extract_dims!(
400 dims3,
401 3,
402 |d: &[usize]| (d[0], d[1], d[2]),
403 (usize, usize, usize)
404);
405extract_dims!(
406 dims4,
407 4,
408 |d: &[usize]| (d[0], d[1], d[2], d[3]),
409 (usize, usize, usize, usize)
410);
411extract_dims!(
412 dims5,
413 5,
414 |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
415 (usize, usize, usize, usize, usize)
416);
417
418
419pub trait Dim : Copy {
420 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
421 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
422}
423
424impl Dim for usize {
425 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
426 let dim = *self;
427 if dim >= shape.rank() {
428 Err(Error::DimOutOfRange {
429 shape: shape.clone(),
430 dim: dim as i32,
431 op,
432 })?
433 } else {
434 Ok(dim)
435 }
436 }
437
438 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
439 let dim = *self;
440 if dim > shape.rank() {
441 Err(Error::DimOutOfRange {
442 shape: shape.clone(),
443 dim: dim as i32,
444 op,
445 })?
446 } else {
447 Ok(dim)
448 }
449 }
450}
451
452impl Dim for D {
453 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
454 let rank = shape.rank();
455 match self {
456 Self::Minus1 if rank >= 1 => Ok(rank - 1),
457 Self::Minus2 if rank >= 2 => Ok(rank - 2),
458 Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
459 Self::Index(u) => u.to_index(shape, op),
460 _ => Err(self.out_of_range(shape, op))?,
461 }
462 }
463
464 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
465 let rank = shape.rank();
466 match self {
467 Self::Minus1 => Ok(rank),
468 Self::Minus2 if rank >= 1 => Ok(rank - 1),
469 Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
470 Self::Index(u) => u.to_index_plus_one(shape, op),
471 _ => Err(self.out_of_range(shape, op))?,
472 }
473 }
474}
475
476pub trait Dims {
477 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
478 fn check_indexes(dims: &[usize], shape: &Shape, op: &'static str) -> Result<()> {
479 for (i, &dim) in dims.iter().enumerate() {
480 if dims[..i].contains(&dim) {
481 return Err(Error::DuplicateDimIndex {
482 shape: shape.clone(),
483 dims: dims.to_vec(),
484 op,
485 })?;
486 }
487 if dim >= shape.rank() {
488 return Err(Error::DimOutOfRange {
489 shape: shape.clone(),
490 dim: dim as i32,
491 op,
492 })?;
493 }
494 }
495 Ok(())
496 }
497}
498
499impl Dims for Vec<usize> {
500 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
501 Self::check_indexes(&self, shape, op)?;
502 Ok(self)
503 }
504}
505
506impl<const N: usize> Dims for [usize; N] {
507 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
508 Self::check_indexes(&self, shape, op)?;
509 Ok(self.to_vec())
510 }
511}
512
513impl Dims for &[usize] {
514 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
515 Self::check_indexes(&self, shape, op)?;
516 Ok(self.to_vec())
517 }
518}
519
520impl Dims for () {
521 fn to_indexes(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
522 Ok(vec![])
523 }
524}
525
526impl<D: Dim + Sized> Dims for D {
527 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
528 let dim = self.to_index(shape, op)?;
529 Ok([dim].to_vec())
530 }
531}
532
533impl<D: Dim> Dims for (D,) {
534 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
535 let dim = self.0.to_index(shape, op)?;
536 Ok([dim].to_vec())
537 }
538}
539
540impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
541 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
542 let d0 = self.0.to_index(shape, op)?;
543 let d1 = self.1.to_index(shape, op)?;
544 Ok([d0, d1].to_vec())
545 }
546}
547
548impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
549 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
550 let d0 = self.0.to_index(shape, op)?;
551 let d1 = self.1.to_index(shape, op)?;
552 let d2 = self.2.to_index(shape, op)?;
553 Ok([d0, d1, d2].to_vec())
554 }
555}
556
557impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
558 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
559 let d0 = self.0.to_index(shape, op)?;
560 let d1 = self.1.to_index(shape, op)?;
561 let d2 = self.2.to_index(shape, op)?;
562 let d3 = self.3.to_index(shape, op)?;
563 Ok([d0, d1, d2, d3].to_vec())
564 }
565}
566
567impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
568 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
569 let d0 = self.0.to_index(shape, op)?;
570 let d1 = self.1.to_index(shape, op)?;
571 let d2 = self.2.to_index(shape, op)?;
572 let d3 = self.3.to_index(shape, op)?;
573 let d4 = self.4.to_index(shape, op)?;
574 Ok([d0, d1, d2, d3, d4].to_vec())
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains the `Vec`-yielding coordinate iterator for a shape built from `dims`.
    fn coords_of(dims: &[usize]) -> Vec<Vec<usize>> {
        Shape(dims.to_vec()).dim_coordinates().collect()
    }

    #[test]
    fn stride() {
        let cases = [
            (Shape::from(()), Vec::<usize>::new()),
            (Shape::from(42), vec![1]),
            (Shape::from((42, 1337)), vec![1337, 1]),
            (Shape::from((299, 792, 458)), vec![458 * 792, 458, 1]),
        ];
        for (shape, expected) in cases {
            assert_eq!(shape.stride_contiguous(), expected, "shape {}", shape);
        }
    }

    #[test]
    fn test_from_tuple() {
        assert_eq!(Shape::from((2,)).dims(), &[2]);
        assert_eq!(Shape::from((2, 3)).dims(), &[2, 3]);
        assert_eq!(Shape::from((2, 3, 4)).dims(), &[2, 3, 4]);
        assert_eq!(Shape::from((2, 3, 4, 5)).dims(), &[2, 3, 4, 5]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6)).dims(), &[2, 3, 4, 5, 6]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6, 7)).dims(), &[2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_dim_coordinates_2d() {
        assert_eq!(
            coords_of(&[2, 2]),
            vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]
        );
    }

    #[test]
    fn test_dim_coordinates_2d_varied() {
        assert_eq!(coords_of(&[3, 1]), vec![vec![0, 0], vec![1, 0], vec![2, 0]]);
    }

    #[test]
    fn test_dim_coordinates_3d() {
        assert_eq!(
            coords_of(&[2, 2, 2]),
            vec![
                vec![0, 0, 0],
                vec![0, 0, 1],
                vec![0, 1, 0],
                vec![0, 1, 1],
                vec![1, 0, 0],
                vec![1, 0, 1],
                vec![1, 1, 0],
                vec![1, 1, 1],
            ]
        );
    }

    #[test]
    fn test_dim_n_coordinates_2d() {
        let shape = Shape(vec![2, 2]);
        let got: Vec<[usize; 2]> = shape.dim2_coordinates().unwrap().collect();
        assert_eq!(got, vec![[0, 0], [0, 1], [1, 0], [1, 1]]);
    }

    #[test]
    fn test_dim_n_coordinates_3d() {
        let shape = Shape(vec![2, 2, 2]);
        let got: Vec<[usize; 3]> = shape.dim3_coordinates().unwrap().collect();
        assert_eq!(
            got,
            vec![
                [0, 0, 0],
                [0, 0, 1],
                [0, 1, 0],
                [0, 1, 1],
                [1, 0, 0],
                [1, 0, 1],
                [1, 1, 0],
                [1, 1, 1],
            ]
        );
    }

    #[test]
    fn test_dim_n_coordinates_wrong_dim() {
        let shape = Shape(vec![2, 2]);
        assert!(shape.dim3_coordinates().is_err());
        assert!(shape.dims_coordinates::<3>().is_err());
    }

    #[test]
    fn test_dim_n_coordinates_empty_shape() {
        let mut iter = Shape(Vec::new()).dims_coordinates::<0>().unwrap();
        assert_eq!(iter.next(), None);
    }
}