1use super::indexing::ravel_index;
4use super::{AsIndex, Slice, SliceArg};
5use alloc::format;
6use alloc::string::{String, ToString};
7use alloc::vec;
8use alloc::vec::Vec;
9use core::fmt::{Debug, Display, Formatter};
10use core::str::FromStr;
11use core::{
12 ops::{Deref, DerefMut, Index, IndexMut, Range},
13 slice::{Iter, IterMut, SliceIndex},
14};
15use serde::{Deserialize, Serialize};
16
17pub use crate::errors::ExpressionError;
18pub use crate::tensor::index_conversion::AsSize;
19
20#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
22pub struct Shape {
23 pub dims: Vec<usize>,
25}
26
27#[allow(missing_docs)]
28#[derive(Debug, Clone, PartialEq, Eq)]
29pub enum ShapeError {
31 RankMismatch { left: usize, right: usize },
33 IncompatibleDims {
35 left: usize,
36 right: usize,
37 dim: usize,
38 },
39 OutOfBounds { dim: usize, rank: usize },
41 IncompatibleShapes { left: Shape, right: Shape },
43 Invalid { reason: String },
45}
46
47impl ShapeError {
48 fn empty() -> Self {
49 Self::Invalid {
50 reason: "Shape is empty.".into(),
51 }
52 }
53}
54
55impl Shape {
56 pub fn new<const D: usize>(dims: [usize; D]) -> Self {
58 Self {
60 dims: dims.to_vec(),
61 }
62 }
63
64 pub fn num_elements(&self) -> usize {
66 self.dims.iter().product()
67 }
68
69 pub fn num_dims(&self) -> usize {
73 self.dims.len()
74 }
75
76 pub fn rank(&self) -> usize {
80 self.num_dims()
81 }
82
83 pub fn dims<const D: usize>(&self) -> [usize; D] {
86 let mut dims = [1; D];
87 dims[..D].copy_from_slice(&self.dims[..D]);
88 dims
89 }
90
91 pub fn flatten(mut self) -> Self {
93 self.dims = [self.num_elements()].into();
94 self
95 }
96
97 pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self {
127 let rank = self.rank();
128 let start = start_dim.expect_dim_index(rank);
129 let end = end_dim.expect_dim_index(rank);
130
131 assert!(
132 start <= end,
133 "start_dim ({start}) must be <= than end_dim ({end})"
134 );
135
136 let existing = self.dims;
137
138 let flattened_size = existing[start..=end].iter().product();
139
140 let new_rank = rank - (end - start);
141 let mut dims = vec![0; new_rank];
142 dims[..start].copy_from_slice(&existing[..start]);
143 dims[start] = flattened_size;
144 dims[start + 1..].copy_from_slice(&existing[end + 1..]);
145
146 Self { dims }
147 }
148
149 pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
163 ravel_index(indices, &self.dims)
164 }
165
166 pub fn into_ranges(self) -> Vec<Range<usize>> {
168 self.iter().map(|&d| 0..d).collect()
169 }
170
171 pub fn into_slices<S>(self, slices: S) -> Vec<Slice>
231 where
232 S: SliceArg,
233 {
234 slices.into_slices(&self)
235 }
236
237 pub fn to_vec(&self) -> Vec<usize> {
239 self.dims.clone()
240 }
241
242 pub fn iter(&self) -> Iter<'_, usize> {
244 self.dims.iter()
245 }
246
247 pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
249 self.dims.iter_mut()
250 }
251
252 pub fn as_slice(&self) -> &[usize] {
254 &self.dims
255 }
256
257 pub fn as_mut_slice(&mut self) -> &mut [usize] {
259 &mut self.dims
260 }
261
262 pub fn insert(&mut self, index: usize, size: usize) {
264 self.dims.insert(index, size);
265 }
266
267 pub fn remove(&mut self, index: usize) -> usize {
269 self.dims.remove(index)
270 }
271
272 pub fn push(&mut self, size: usize) {
274 self.dims.push(size)
275 }
276
277 pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
279 self.dims.extend(iter)
280 }
281
282 pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
284 if dim1 > self.rank() {
285 return Err(ShapeError::OutOfBounds {
286 dim: dim1,
287 rank: self.rank(),
288 });
289 }
290 if dim2 > self.rank() {
291 return Err(ShapeError::OutOfBounds {
292 dim: dim2,
293 rank: self.rank(),
294 });
295 }
296 self.dims.swap(dim1, dim2);
297 Ok(self)
298 }
299
300 pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
302 if axes.len() != self.rank() {
303 return Err(ShapeError::RankMismatch {
304 left: self.rank(),
305 right: axes.len(),
306 });
307 }
308 debug_assert!(axes.iter().all(|i| i < &self.rank()));
309
310 self.dims = axes.iter().map(|&i| self.dims[i]).collect();
311 Ok(self)
312 }
313
314 pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> {
316 if dim >= self.rank() {
317 return Err(ShapeError::OutOfBounds {
318 dim,
319 rank: self.rank(),
320 });
321 }
322
323 self.dims[dim] *= times;
324 Ok(self)
325 }
326
327 pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> {
329 if dim >= self.rank() {
330 return Err(ShapeError::OutOfBounds {
331 dim,
332 rank: self.rank(),
333 });
334 }
335
336 self.dims[dim] = 1;
337 Ok(self)
338 }
339
340 pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
342 where
343 I: IntoIterator<Item = &'a Shape>,
344 {
345 let mut iter = shapes.into_iter();
346
347 let first = iter.next().ok_or(ShapeError::empty())?;
348
349 if dim >= first.rank() {
350 return Err(ShapeError::OutOfBounds {
351 dim,
352 rank: first.rank(),
353 });
354 }
355
356 let mut shape = first.clone();
357
358 for s in iter {
359 if s.rank() != shape.rank() {
360 return Err(ShapeError::RankMismatch {
361 left: shape.rank(),
362 right: s.rank(),
363 });
364 }
365
366 if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
367 return Err(ShapeError::IncompatibleShapes {
368 left: shape.clone(),
369 right: s.clone(),
370 });
371 }
372
373 shape[dim] += s[dim];
374 }
375
376 Ok(shape)
377 }
378
379 pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
381 if slices.len() > self.rank() {
382 return Err(ShapeError::RankMismatch {
383 left: self.rank(),
384 right: slices.len(),
385 });
386 }
387
388 slices
389 .iter()
390 .zip(self.iter_mut())
391 .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
392
393 Ok(self)
394 }
395
396 pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
405 Self::broadcast_many([self, other])
406 }
407
408 pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
412 where
413 I: IntoIterator<Item = &'a Shape>,
414 {
415 let mut iter = shapes.into_iter();
416 let mut broadcasted = iter.next().ok_or(ShapeError::empty())?.clone();
417 let rank = broadcasted.rank();
418
419 for shape in iter {
420 if shape.rank() != rank {
421 return Err(ShapeError::RankMismatch {
422 left: rank,
423 right: shape.rank(),
424 });
425 }
426
427 for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
428 match (*d_lhs, d_rhs) {
429 (a, b) if a == b => {} (1, b) => *d_lhs = b, (_a, 1) => {} _ => {
433 return Err(ShapeError::IncompatibleDims {
434 left: *d_lhs,
435 right: d_rhs,
436 dim,
437 });
438 }
439 }
440 }
441 }
442
443 Ok(broadcasted)
444 }
445
446 pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> {
448 let target_rank = target.rank();
449 if self.rank() > target_rank {
450 return Err(ShapeError::RankMismatch {
451 left: self.rank(),
452 right: target_rank,
453 });
454 }
455
456 for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
457 if dim_self != dim_target && *dim_self != 1 {
458 return Err(ShapeError::IncompatibleDims {
459 left: *dim_self,
460 right: *dim_target,
461 dim: target_rank - i - 1,
462 });
463 }
464 }
465
466 Ok(target)
467 }
468
469 pub fn reshape<A, T>(&self, args: A) -> Result<Shape, ShapeError>
471 where
472 A: AsRef<[T]> + Debug,
473 T: AsIndex,
474 {
475 let args = args.as_ref();
476 let mut infer_index = None;
477 let mut dims = Vec::new();
478
479 let mut new_size = 1;
480
481 for (idx, &s) in args.iter().enumerate() {
482 let s = s.as_index();
483 if s > 0 {
484 let s = s as usize;
485 new_size *= s;
486 dims.push(s);
487 } else if s == 0 {
488 let s = self.dims[idx];
491 new_size *= s;
492 dims.push(s);
493 } else if s == -1 {
494 match infer_index {
495 None => {
496 infer_index = Some(idx);
497 dims.push(1);
499 }
500 Some(_) => {
501 return Err(ShapeError::Invalid {
502 reason: "Repeated -1 in reshape".to_string(),
503 });
504 }
505 }
506 } else {
507 return Err(ShapeError::Invalid {
508 reason: "The given shape cannot contain negative dimensions (other than -1)."
509 .to_string(),
510 });
511 }
512 }
513
514 let source_size = self.num_elements();
515 match infer_index {
516 None => {
517 if source_size != new_size {
518 return Err(ShapeError::Invalid {
519 reason: format!(
520 "The given shape doesn't have the same number of elements as the current shape. Current shape: {self}, target shape: {dims:?}.",
521 ),
522 });
523 }
524 }
525 Some(idx) => {
526 if !source_size.is_multiple_of(new_size) {
527 return Err(ShapeError::Invalid {
528 reason: format!(
529 "Cannot infer a valid target shape. Current shape: {self}, target dimensions: {args:?}."
530 ),
531 });
532 }
533 dims[idx] = source_size / new_size;
534 }
535 }
536
537 Ok(dims.into())
538 }
539}
540
541pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> {
546 let rank = lhs.rank();
547 if rank != rhs.rank() {
548 return Err(ShapeError::RankMismatch {
549 left: rank,
550 right: rhs.rank(),
551 });
552 }
553
554 if lhs[rank - 1] != rhs[rank - 2] {
555 return Err(ShapeError::IncompatibleShapes {
556 left: lhs.clone(),
557 right: rhs.clone(),
558 });
559 }
560
561 let mut shape = if rank > 2 {
562 Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
564 } else {
565 Shape::new([])
566 };
567 shape.extend([lhs[rank - 2], rhs[rank - 1]]);
568
569 Ok(shape)
570}
571
572impl Display for Shape {
573 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
574 self.dims.fmt(f)
575 }
576}
577
578impl FromStr for Shape {
579 type Err = crate::ExpressionError;
580
581 fn from_str(source: &str) -> Result<Self, Self::Err> {
582 let mut s = source.trim();
583
584 const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")];
585
586 for (open, close) in DELIMS {
587 if let Some(p) = s.strip_prefix(open) {
588 if let Some(p) = p.strip_suffix(close) {
589 s = p.trim();
590 break;
591 } else {
592 return Err(crate::ExpressionError::ParseError {
593 message: "Unbalanced delimiters".to_string(),
594 source: source.to_string(),
595 });
596 }
597 }
598 }
599
600 if s.is_empty() {
601 return Ok(Shape::new([]));
602 }
603
604 let dims =
605 s.split(',')
606 .map(|dim_str| {
607 dim_str.trim().parse::<usize>().map_err(|_| {
608 crate::ExpressionError::ParseError {
609 message: "Unable to parse shape".to_string(),
610 source: source.to_string(),
611 }
612 })
613 })
614 .collect::<Result<Vec<usize>, crate::ExpressionError>>()?;
615
616 if dims.is_empty() {
617 unreachable!("Split should have returned at least one element");
618 }
619
620 Ok(Shape { dims })
621 }
622}
623
624impl<Idx> Index<Idx> for Shape
625where
626 Idx: SliceIndex<[usize]>,
627{
628 type Output = Idx::Output;
629
630 fn index(&self, index: Idx) -> &Self::Output {
631 &self.dims[index]
632 }
633}
634
635impl<Idx> IndexMut<Idx> for Shape
636where
637 Idx: SliceIndex<[usize]>,
638{
639 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
640 &mut self.dims[index]
641 }
642}
643
644impl Deref for Shape {
646 type Target = [usize];
647
648 fn deref(&self) -> &Self::Target {
649 &self.dims
650 }
651}
652
653impl DerefMut for Shape {
655 fn deref_mut(&mut self) -> &mut Self::Target {
656 &mut self.dims
657 }
658}
659impl AsRef<[usize]> for Shape {
664 fn as_ref(&self) -> &[usize] {
665 &self.dims
666 }
667}
668
669impl From<Shape> for Vec<usize> {
670 fn from(shape: Shape) -> Self {
671 shape.dims
672 }
673}
674
675impl<T, I> From<T> for Shape
676where
677 T: IntoIterator<Item = I>,
678 I: AsSize,
679{
680 fn from(dims: T) -> Self {
681 Shape {
682 dims: dims.into_iter().map(|d| d.as_size()).collect(),
683 }
684 }
685}
686
// Unit tests for `Shape`: construction, string parsing/formatting, slice-like access, and
// shape arithmetic (flatten, swap/permute, repeat/reduce, broadcast, matmul, cat, slice,
// expand, reshape).
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::s;
    use alloc::string::ToString;
    use alloc::vec;

    // `Display` renders the dimension list.
    #[test]
    fn test_shape_to_str() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.to_string(), "[2, 3, 4, 5]");
    }

    // `FromStr` accepts `[..]`, `(..)` or bare lists and reports precise parse errors.
    #[test]
    fn test_shape_from_str() {
        assert_eq!(
            "[2, 3, 4, 5]".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );
        assert_eq!(
            "(2, 3, 4, 5)".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );
        assert_eq!(
            "2, 3, 4, 5".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );

        assert_eq!("[2]".parse::<Shape>().unwrap(), Shape::new([2]));
        assert_eq!("(2)".parse::<Shape>().unwrap(), Shape::new([2]));
        assert_eq!("2".parse::<Shape>().unwrap(), Shape::new([2]));

        // Empty input (with or without delimiters) yields a rank-0 shape.
        assert_eq!("[]".parse::<Shape>().unwrap(), Shape::new([]));
        assert_eq!("".parse::<Shape>().unwrap(), Shape::new([]));

        assert_eq!(
            "[".parse::<Shape>(),
            Err(crate::ExpressionError::ParseError {
                message: "Unbalanced delimiters".to_string(),
                source: "[".to_string()
            })
        );

        // Only one layer of delimiters is stripped, so nested brackets fail to parse.
        assert_eq!(
            "[[1]".parse::<Shape>(),
            Err(crate::ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "[[1]".to_string()
            })
        );
        assert_eq!(
            "[[1]]".parse::<Shape>(),
            Err(crate::ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "[[1]]".to_string()
            })
        );
        assert_eq!(
            "[1)".parse::<Shape>(),
            Err(crate::ExpressionError::ParseError {
                message: "Unbalanced delimiters".to_string(),
                source: "[1)".to_string()
            })
        );

        // A lone closing delimiter is not recognized as a delimiter at all.
        assert_eq!(
            "]".parse::<Shape>(),
            Err(crate::ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "]".to_string()
            })
        );

        assert_eq!(
            "[a]".parse::<Shape>(),
            Err(crate::ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "[a]".to_string()
            })
        );
    }

    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }

    // `into_iter` resolves through `Deref` to the slice iterator (hence the clippy allow).
    #[test]
    #[allow(clippy::into_iter_on_ref)]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }

    // Out-of-range ends are clamped and negative bounds count from the end of the dimension.
    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);

        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);

        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);

        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);

        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);

        assert_eq!(shape[1..3], *&[3, 4]);
        assert_eq!(shape[1..=2], *&[3, 4]);
        assert_eq!(shape[..], *&[2, 3, 4, 5]);
    }

    // Slice methods are reachable through `Deref<Target = [usize]>`.
    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);

        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));

        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        for d in shape.iter_mut() {
            *d += 1;
        }

        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
    }

    // Mutation through `as_mut_slice` and through `IndexMut` with a full range must agree.
    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut)
    }

    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);

        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(&shape.dims, &[120]);
    }

    // Row-major linearization: each index is weighted by the product of the trailing dims.
    #[test]
    fn test_ravel() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
        assert_eq!(
            shape.ravel_index(&[1, 2, 3, 4]),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    fn test_shape_insert_remove_push() {
        let dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let size = 6;
        shape.insert(1, size);

        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        let removed = shape.remove(1);
        assert_eq!(removed, size);
        assert_eq!(shape, Shape::new(dims));

        shape.push(6);
        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
    }

    // Swapping dims 1 and 2, then permuting them back, restores the original shape.
    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        let shape = shape.swap(1, 2).unwrap();

        assert_eq!(&shape.dims, &[2, 4, 3, 5]);

        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
        assert_eq!(shape, Shape::new(dims));
    }

    // Dimension 4 does not exist in a rank-4 shape, so this must panic.
    #[test]
    #[should_panic]
    fn test_shape_swap_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.swap(0, 4).unwrap();
    }

    // Fewer axes than dimensions is a rank mismatch; `unwrap` turns it into a panic.
    #[test]
    #[should_panic]
    fn test_shape_permute_incomplete() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.permute(&[0, 2, 1]).unwrap();
    }

    #[test]
    fn test_shape_repeat() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(2, 3).unwrap();
        assert_eq!(out, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_repeat_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(5, 3);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_reduce() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 1, 5]));
    }

    #[test]
    fn test_shape_reduce_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(5);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    // Size-1 dims stretch on either side.
    #[test]
    fn test_shape_broadcast_binary() {
        let lhs = Shape::new([1, 1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    // Broadcasting does not implicitly extend ranks.
    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let lhs = Shape::new([1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 4]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let lhs = Shape::new([1, 2, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 2,
                right: 6,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([7, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([1, 6, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
    }

    // `left` reports the size accumulated from the shapes broadcast so far (7 from s2).
    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([4, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 7,
                right: 4,
                dim: 0
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        let out = Shape::broadcast_many(&[]);
        assert_eq!(out, Err(ShapeError::empty()));
    }

    #[test]
    fn test_shape_matmul_2d() {
        let lhs = Shape::new([2, 4]);
        let rhs = Shape::new([4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 2]));
    }

    // Batch dims [1, 3] and [2, 1] broadcast to [2, 3]; matrix dims give [2, 2].
    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let lhs = Shape::new([3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    // Inner dimensions differ (4 vs 3).
    #[test]
    fn test_shape_matmul_invalid_shape() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 3, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: lhs,
                right: rhs
            })
        );
    }

    // Batch dims [1, 3] and [2, 2] cannot broadcast at dim 1.
    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 2, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 3,
                right: 2,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_cat() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let s3 = Shape::new([4, 3, 4, 5]);

        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
        assert_eq!(out, Shape::new([7, 3, 4, 5]));

        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 2, 5]);
        let s3 = Shape::new([2, 3, 1, 5]);

        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 7, 5]));
    }

    #[test]
    fn test_shape_cat_empty() {
        let out = Shape::cat(&[], 0);
        assert_eq!(out, Err(ShapeError::empty()));
    }

    #[test]
    fn test_shape_cat_dim_out_of_bounds() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5]);
        let out = Shape::cat(&[s1, s2], 4);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
    }

    #[test]
    fn test_shape_cat_rank_mismatch() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5, 6]);
        let out = Shape::cat(&[s1, s2], 0);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
    }

    // Concatenating along dim 1 requires dim 0 to match, but 2 != 1.
    #[test]
    fn test_shape_cat_incompatible_shapes() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);

        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: s1,
                right: s2
            })
        );
    }

    // Dimensions without a slice keep their original size.
    #[test]
    fn test_shape_slice_output_shape_basic() {
        let slices = [
            Slice::new(0, Some(5), 1), // 0..5 -> 5 elements
            Slice::new(2, Some(8), 1), // 2..8 -> 6 elements
        ];
        let original_shape = Shape::new([10, 10, 10]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        let slices = [
            Slice::new(0, Some(10), 2), // 0..10 step 2 -> 5 elements
            Slice::new(1, Some(9), 3),  // 1..9 step 3 -> 3 elements
            Slice::new(0, Some(7), 4),  // 0..7 step 4 -> 2 elements
        ];
        let original_shape = Shape::new([20, 20, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        let slices = [
            Slice::new(0, Some(10), -1), // 0..10 step -1 -> 10 elements
            Slice::new(2, Some(8), -2),  // 2..8 step -2 -> 3 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        let slices = [
            Slice::from_range_stepped(1..6, 1),   // 5 elements
            Slice::from_range_stepped(0..10, -3), // 4 elements
            Slice::from_range_stepped(2..14, 4),  // 3 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        let slices = [
            Slice::from_range_stepped(2..7, 2), // 3 elements; other dims untouched
        ];
        let original_shape = Shape::new([10, 20, 30, 40]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        let slices = [
            Slice::from_range_stepped(0..1, 1),    // single element
            Slice::from_range_stepped(0..10, 100), // step larger than range -> 1 element
            Slice::from_range_stepped(5..5, 1),    // empty range -> 0 elements
        ];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        let slices = [];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        let slices = [
            Slice::from_range_stepped(0..7, 3),  // 7 / 3 rounded up -> 3 elements
            Slice::from_range_stepped(0..11, 4), // 11 / 4 rounded up -> 3 elements
            Slice::from_range_stepped(1..10, 5), // 9 / 5 rounded up -> 2 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 3, 2]));
    }

    #[test]
    fn test_shape_expand() {
        let shape = Shape::new([1, 3, 1]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded.clone()).unwrap();
        assert_eq!(out, expanded);
    }

    // Dimensions are aligned from the right, so a lower-rank shape can expand.
    #[test]
    fn test_shape_expand_higher_rank() {
        let shape = Shape::new([1, 4]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded.clone()).unwrap();
        assert_eq!(out, expanded);
    }

    #[test]
    fn test_shape_expand_invalid_rank() {
        let shape = Shape::new([1, 3, 1]);
        let expanded = Shape::new([3, 4]);
        let out = shape.expand(expanded);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 2 }));
    }

    #[test]
    fn test_shape_expand_incompatible_dims() {
        let shape = Shape::new([1, 3, 2]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 2,
                right: 4,
                dim: 2
            })
        );
    }

    #[test]
    fn test_shape_reshape() {
        let shape = Shape::new([2, 3, 4, 5]);
        let reshaped = Shape::new([1, 2, 12, 5]);
        let out = shape.reshape(reshaped.clone()).unwrap();
        assert_eq!(out, reshaped);
    }

    // 2*2*12*5 = 240 != 120 elements.
    #[test]
    fn test_shape_reshape_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);
        let reshaped = Shape::new([2, 2, 12, 5]);
        let out = shape.reshape(reshaped.clone());
        assert_eq!(
            out,
            Err(ShapeError::Invalid {
                reason: "The given shape doesn't have the same number of elements as the current shape. Current shape: [2, 3, 4, 5], target shape: [2, 2, 12, 5].".into(),
            })
        );
    }

    // 8 elements cannot be split into rows of 3.
    #[test]
    fn test_shape_reshape_invalid_inferred() {
        let shape = Shape::new([2, 4]);
        let out = shape.reshape([-1, 3]);
        assert_eq!(
            out,
            Err(ShapeError::Invalid {
                reason: "Cannot infer a valid target shape. Current shape: [2, 4], target dimensions: [-1, 3].".into(),
            })
        );
    }

    // `-2` resolves to dimension 2 for a rank-4 shape, so dims 2..=3 merge into 20.
    #[test]
    fn test_flatten_dims() {
        let shape = Shape::new([2, 3, 4, 5]);
        let flattened = shape.flatten_dims(-2, 3);
        assert_eq!(flattened, Shape::new([2, 3, 20]));
    }
}