1use super::indexing::ravel_index;
4use super::{AsIndex, Slice, SliceArg};
5use alloc::string::ToString;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::str::FromStr;
10use core::{
11 ops::{Deref, DerefMut, Index, IndexMut, Range},
12 slice::{Iter, IterMut, SliceIndex},
13};
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub struct Shape {
19 pub dims: Vec<usize>,
21}
22
23#[allow(missing_docs)]
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum ShapeError {
27 RankMismatch { left: usize, right: usize },
29 IncompatibleDims {
31 left: usize,
32 right: usize,
33 dim: usize,
34 },
35 OutOfBounds { dim: usize, rank: usize },
37 IncompatibleShapes { left: Shape, right: Shape },
39 Empty,
41}
42
43impl Shape {
44 pub fn new<const D: usize>(dims: [usize; D]) -> Self {
46 Self {
48 dims: dims.to_vec(),
49 }
50 }
51
52 pub fn num_elements(&self) -> usize {
54 self.dims.iter().product()
55 }
56
57 pub fn num_dims(&self) -> usize {
61 self.dims.len()
62 }
63
64 pub fn rank(&self) -> usize {
68 self.num_dims()
69 }
70
71 pub fn dims<const D: usize>(&self) -> [usize; D] {
74 let mut dims = [1; D];
75 dims[..D].copy_from_slice(&self.dims[..D]);
76 dims
77 }
78
79 pub fn flatten(mut self) -> Self {
81 self.dims = [self.num_elements()].into();
82 self
83 }
84
85 pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self {
115 let rank = self.rank();
116 let start = start_dim.expect_dim_index(rank);
117 let end = end_dim.expect_dim_index(rank);
118
119 assert!(
120 start <= end,
121 "start_dim ({start}) must be <= than end_dim ({end})"
122 );
123
124 let existing = self.dims;
125
126 let flattened_size = existing[start..=end].iter().product();
127
128 let new_rank = rank - (end - start);
129 let mut dims = vec![0; new_rank];
130 dims[..start].copy_from_slice(&existing[..start]);
131 dims[start] = flattened_size;
132 dims[start + 1..].copy_from_slice(&existing[end + 1..]);
133
134 Self { dims }
135 }
136
137 pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
151 ravel_index(indices, &self.dims)
152 }
153
154 pub fn into_ranges(self) -> Vec<Range<usize>> {
156 self.into_iter().map(|d| 0..d).collect()
157 }
158
159 pub fn into_slices<S>(self, slices: S) -> Vec<Slice>
219 where
220 S: SliceArg,
221 {
222 slices.into_slices(&self)
223 }
224
225 pub fn to_vec(&self) -> Vec<usize> {
227 self.dims.clone()
228 }
229
230 pub fn iter(&self) -> Iter<'_, usize> {
232 self.dims.iter()
233 }
234
235 pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
237 self.dims.iter_mut()
238 }
239
240 pub fn as_slice(&self) -> &[usize] {
242 &self.dims
243 }
244
245 pub fn as_mut_slice(&mut self) -> &mut [usize] {
247 &mut self.dims
248 }
249
250 pub fn insert(&mut self, index: usize, size: usize) {
252 self.dims.insert(index, size);
253 }
254
255 pub fn remove(&mut self, index: usize) -> usize {
257 self.dims.remove(index)
258 }
259
260 pub fn push(&mut self, size: usize) {
262 self.dims.push(size)
263 }
264
265 pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
267 self.dims.extend(iter)
268 }
269
270 pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
272 if dim1 > self.rank() {
273 return Err(ShapeError::OutOfBounds {
274 dim: dim1,
275 rank: self.rank(),
276 });
277 }
278 if dim2 > self.rank() {
279 return Err(ShapeError::OutOfBounds {
280 dim: dim2,
281 rank: self.rank(),
282 });
283 }
284 self.dims.swap(dim1, dim2);
285 Ok(self)
286 }
287
288 pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
290 if axes.len() != self.rank() {
291 return Err(ShapeError::RankMismatch {
292 left: self.rank(),
293 right: axes.len(),
294 });
295 }
296 debug_assert!(axes.iter().all(|i| i < &self.rank()));
297
298 self.dims = axes.iter().map(|&i| self.dims[i]).collect();
299 Ok(self)
300 }
301
302 pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> {
304 if dim >= self.rank() {
305 return Err(ShapeError::OutOfBounds {
306 dim,
307 rank: self.rank(),
308 });
309 }
310
311 self.dims[dim] *= times;
312 Ok(self)
313 }
314
315 pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> {
317 if dim >= self.rank() {
318 return Err(ShapeError::OutOfBounds {
319 dim,
320 rank: self.rank(),
321 });
322 }
323
324 self.dims[dim] = 1;
325 Ok(self)
326 }
327
328 pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
330 where
331 I: IntoIterator<Item = &'a Shape>,
332 {
333 let mut iter = shapes.into_iter();
334
335 let first = iter.next().ok_or(ShapeError::Empty)?;
336
337 if dim >= first.rank() {
338 return Err(ShapeError::OutOfBounds {
339 dim,
340 rank: first.rank(),
341 });
342 }
343
344 let mut shape = first.clone();
345
346 for s in iter {
347 if s.rank() != shape.rank() {
348 return Err(ShapeError::RankMismatch {
349 left: shape.rank(),
350 right: s.rank(),
351 });
352 }
353
354 if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
355 return Err(ShapeError::IncompatibleShapes {
356 left: shape.clone(),
357 right: s.clone(),
358 });
359 }
360
361 shape[dim] += s[dim];
362 }
363
364 Ok(shape)
365 }
366
367 pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
369 if slices.len() > self.rank() {
370 return Err(ShapeError::RankMismatch {
371 left: self.rank(),
372 right: slices.len(),
373 });
374 }
375
376 slices
377 .iter()
378 .zip(self.iter_mut())
379 .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
380
381 Ok(self)
382 }
383
384 pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
393 Self::broadcast_many([self, other])
394 }
395
396 pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
400 where
401 I: IntoIterator<Item = &'a Shape>,
402 {
403 let mut iter = shapes.into_iter();
404 let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone();
405 let rank = broadcasted.rank();
406
407 for shape in iter {
408 if shape.rank() != rank {
409 return Err(ShapeError::RankMismatch {
410 left: rank,
411 right: shape.rank(),
412 });
413 }
414
415 for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
416 match (*d_lhs, d_rhs) {
417 (a, b) if a == b => {} (1, b) => *d_lhs = b, (_a, 1) => {} _ => {
421 return Err(ShapeError::IncompatibleDims {
422 left: *d_lhs,
423 right: d_rhs,
424 dim,
425 });
426 }
427 }
428 }
429 }
430
431 Ok(broadcasted)
432 }
433
434 pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> {
436 let target_rank = target.rank();
437 if self.rank() > target_rank {
438 return Err(ShapeError::RankMismatch {
439 left: self.rank(),
440 right: target_rank,
441 });
442 }
443
444 for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
445 if dim_self != dim_target && *dim_self != 1 {
446 return Err(ShapeError::IncompatibleDims {
447 left: *dim_self,
448 right: *dim_target,
449 dim: target_rank - i - 1,
450 });
451 }
452 }
453
454 Ok(target)
455 }
456
457 pub fn reshape(&self, target: Shape) -> Result<Shape, ShapeError> {
459 if self.num_elements() != target.num_elements() {
460 return Err(ShapeError::IncompatibleShapes {
461 left: self.clone(),
462 right: target,
463 });
464 }
465 Ok(target)
466 }
467}
468
469pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> {
474 let rank = lhs.rank();
475 if rank != rhs.rank() {
476 return Err(ShapeError::RankMismatch {
477 left: rank,
478 right: rhs.rank(),
479 });
480 }
481
482 if lhs[rank - 1] != rhs[rank - 2] {
483 return Err(ShapeError::IncompatibleShapes {
484 left: lhs.clone(),
485 right: rhs.clone(),
486 });
487 }
488
489 let mut shape = if rank > 2 {
490 Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
492 } else {
493 Shape::new([])
494 };
495 shape.extend([lhs[rank - 2], rhs[rank - 1]]);
496
497 Ok(shape)
498}
499
500impl Display for Shape {
501 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
502 self.dims.fmt(f)
503 }
504}
505
506impl FromStr for Shape {
507 type Err = crate::ExpressionError;
508
509 fn from_str(source: &str) -> Result<Self, Self::Err> {
510 let mut s = source.trim();
511
512 const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")];
513
514 for (open, close) in DELIMS {
515 if let Some(p) = s.strip_prefix(open) {
516 if let Some(p) = p.strip_suffix(close) {
517 s = p.trim();
518 break;
519 } else {
520 return Err(crate::ExpressionError::ParseError {
521 message: "Unbalanced delimiters".to_string(),
522 source: source.to_string(),
523 });
524 }
525 }
526 }
527
528 if s.is_empty() {
529 return Ok(Shape::new([]));
530 }
531
532 let dims =
533 s.split(',')
534 .map(|dim_str| {
535 dim_str.trim().parse::<usize>().map_err(|_| {
536 crate::ExpressionError::ParseError {
537 message: "Unable to parse shape".to_string(),
538 source: source.to_string(),
539 }
540 })
541 })
542 .collect::<Result<Vec<usize>, crate::ExpressionError>>()?;
543
544 if dims.is_empty() {
545 unreachable!("Split should have returned at least one element");
546 }
547
548 Ok(Shape { dims })
549 }
550}
551
552impl IntoIterator for Shape {
553 type Item = usize;
554 type IntoIter = alloc::vec::IntoIter<Self::Item>;
555
556 fn into_iter(self) -> Self::IntoIter {
557 self.dims.into_iter()
558 }
559}
560
561impl<Idx> Index<Idx> for Shape
562where
563 Idx: SliceIndex<[usize]>,
564{
565 type Output = Idx::Output;
566
567 fn index(&self, index: Idx) -> &Self::Output {
568 &self.dims[index]
569 }
570}
571
572impl<Idx> IndexMut<Idx> for Shape
573where
574 Idx: SliceIndex<[usize]>,
575{
576 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
577 &mut self.dims[index]
578 }
579}
580
581impl Deref for Shape {
583 type Target = [usize];
584
585 fn deref(&self) -> &Self::Target {
586 &self.dims
587 }
588}
589
590impl DerefMut for Shape {
592 fn deref_mut(&mut self) -> &mut Self::Target {
593 &mut self.dims
594 }
595}
596
597impl<const D: usize> From<[usize; D]> for Shape {
599 fn from(dims: [usize; D]) -> Self {
600 Shape::new(dims)
601 }
602}
603
604impl<const D: usize> From<[i64; D]> for Shape {
605 fn from(dims: [i64; D]) -> Self {
606 Shape {
607 dims: dims.into_iter().map(|d| d as usize).collect(),
608 }
609 }
610}
611
612impl<const D: usize> From<[i32; D]> for Shape {
613 fn from(dims: [i32; D]) -> Self {
614 Shape {
615 dims: dims.into_iter().map(|d| d as usize).collect(),
616 }
617 }
618}
619
620impl From<&[usize]> for Shape {
621 fn from(dims: &[usize]) -> Self {
622 Shape { dims: dims.into() }
623 }
624}
625
626impl From<Vec<i64>> for Shape {
627 fn from(shape: Vec<i64>) -> Self {
628 Self {
629 dims: shape.into_iter().map(|d| d as usize).collect(),
630 }
631 }
632}
633
634impl From<Vec<u64>> for Shape {
635 fn from(shape: Vec<u64>) -> Self {
636 Self {
637 dims: shape.into_iter().map(|d| d as usize).collect(),
638 }
639 }
640}
641
642impl From<Vec<usize>> for Shape {
643 fn from(shape: Vec<usize>) -> Self {
644 Self { dims: shape }
645 }
646}
647
648impl From<&Vec<usize>> for Shape {
649 fn from(shape: &Vec<usize>) -> Self {
650 Self {
651 dims: shape.clone(),
652 }
653 }
654}
655
656impl From<Shape> for Vec<usize> {
657 fn from(shape: Shape) -> Self {
658 shape.dims
659 }
660}
661
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::s;
    use alloc::string::ToString;
    use alloc::vec;

    /// Parses `text` as a [`Shape`], panicking on failure.
    fn parse_ok(text: &str) -> Shape {
        text.parse::<Shape>().unwrap()
    }

    /// Builds the parse error expected for input `source` with `message`.
    fn parse_err(message: &str, source: &str) -> crate::ExpressionError {
        crate::ExpressionError::ParseError {
            message: message.to_string(),
            source: source.to_string(),
        }
    }

    #[test]
    fn test_shape_to_str() {
        let rendered = Shape::new([2, 3, 4, 5]).to_string();
        assert_eq!(rendered, "[2, 3, 4, 5]");
    }

    #[test]
    fn test_shape_from_str() {
        let four_d = Shape::new([2, 3, 4, 5]);
        for text in ["[2, 3, 4, 5]", "(2, 3, 4, 5)", "2, 3, 4, 5"] {
            assert_eq!(parse_ok(text), four_d);
        }

        for text in ["[2]", "(2)", "2"] {
            assert_eq!(parse_ok(text), Shape::new([2]));
        }

        for text in ["[]", ""] {
            assert_eq!(parse_ok(text), Shape::new([]));
        }

        for text in ["[", "[1)"] {
            assert_eq!(
                text.parse::<Shape>(),
                Err(parse_err("Unbalanced delimiters", text))
            );
        }

        for text in ["[[1]", "[[1]]", "]", "[a]"] {
            assert_eq!(
                text.parse::<Shape>(),
                Err(parse_err("Unable to parse shape", text))
            );
        }
    }

    #[test]
    fn num_dims_and_rank() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_dims(), 4);
        assert_eq!(shape.rank(), 4);
    }

    #[test]
    fn num_elements() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);
    }

    #[test]
    fn test_shape_into_iter() {
        let total: usize = Shape::new([2, 3, 4, 5]).into_iter().sum();
        assert_eq!(total, 14);
    }

    #[test]
    fn test_into_ranges() {
        let ranges = Shape::new([2, 3, 4, 5]).into_ranges();
        assert_eq!(ranges, vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let owned = Shape::new([2, 3, 4, 5]).to_vec();
        assert_eq!(owned, vec![2, 3, 4, 5]);
    }

    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        let resolved = Shape::new([3]).into_slices(1..4);
        assert_eq!(resolved[0].to_range(3), 1..3);

        let resolved = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(resolved[0].to_range(3), 1..3);
        assert_eq!(resolved[1].to_range(4), 0..2);

        let resolved = Shape::new([3]).into_slices(..-2);
        assert_eq!(resolved[0].to_range(3), 0..1);

        let resolved = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(resolved[0].to_range(2), 0..2);
        assert_eq!(resolved[1].to_range(3), 1..2);

        let resolved = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(resolved[0].to_range(2), 0..2);
        assert_eq!(resolved[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        for (position, &expected) in [2, 3, 4, 5].iter().enumerate() {
            assert_eq!(shape[position], expected);
        }

        assert_eq!(shape[1..3], [3, 4]);
        assert_eq!(shape[1..=2], [3, 4]);
        assert_eq!(shape[..], [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape.first(), Some(&2));
        assert_eq!(shape.last(), Some(&5));
        assert!(!shape.is_empty());

        let scalar = Shape::new([]);
        assert!(scalar.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        dims.iter()
            .zip(shape.iter())
            .for_each(|(want, got)| assert_eq!(want, got));
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        shape.iter_mut().for_each(|size| *size += 1);

        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        let via_deref: &[usize] = &shape;
        assert_eq!(via_deref, [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut first = Shape::new(dims);

        let first_view = first.as_mut_slice();
        assert_eq!(first_view, dims.as_mut_slice());
        first_view[1] = 6;
        assert_eq!(first_view, &[2, 6, 4, 5]);

        let mut second = Shape::new(dims);
        let second_view = &mut second[..];
        second_view[1] = 6;

        assert_eq!(second_view, first_view)
    }

    #[test]
    fn test_shape_flatten() {
        let original = Shape::new([2, 3, 4, 5]);
        assert_eq!(original.num_elements(), 120);

        let flat = original.flatten();
        assert_eq!(flat.num_elements(), 120);
        assert_eq!(&flat.dims, &[120]);
    }

    #[test]
    fn test_ravel() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);

        let expected = 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4;
        assert_eq!(shape.ravel_index(&[1, 2, 3, 4]), expected);
    }

    #[test]
    fn test_shape_insert_remove_push() {
        let dims = [2, 3, 4, 5];
        let extra = 6;
        let mut shape = Shape::new(dims);

        shape.insert(1, extra);
        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        assert_eq!(shape.remove(1), extra);
        assert_eq!(shape, Shape::new(dims));

        shape.push(6);
        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
    }

    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];

        let swapped = Shape::new(dims).swap(1, 2).unwrap();
        assert_eq!(&swapped.dims, &[2, 4, 3, 5]);

        let restored = swapped.permute(&[0, 2, 1, 3]).unwrap();
        assert_eq!(restored, Shape::new(dims));
    }

    #[test]
    #[should_panic]
    fn test_shape_swap_out_of_bounds() {
        Shape::new([2, 3, 4, 5]).swap(0, 4).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_shape_permute_incomplete() {
        Shape::new([2, 3, 4, 5]).permute(&[0, 2, 1]).unwrap();
    }

    #[test]
    fn test_shape_repeat() {
        let repeated = Shape::new([2, 3, 4, 5]).repeat(2, 3).unwrap();
        assert_eq!(repeated, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_repeat_invalid() {
        let outcome = Shape::new([2, 3, 4, 5]).repeat(5, 3);
        assert_eq!(outcome, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_reduce() {
        let reduced = Shape::new([2, 3, 4, 5]).reduce(2).unwrap();
        assert_eq!(reduced, Shape::new([2, 3, 1, 5]));
    }

    #[test]
    fn test_shape_reduce_invalid() {
        let outcome = Shape::new([2, 3, 4, 5]).reduce(5);
        assert_eq!(outcome, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_broadcast_binary() {
        let left = Shape::new([1, 1, 2, 4]);
        let right = Shape::new([7, 6, 2, 1]);

        assert_eq!(left.broadcast(&right).unwrap(), Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let left = Shape::new([1, 2, 4]);
        let right = Shape::new([7, 6, 2, 4]);

        assert_eq!(
            left.broadcast(&right),
            Err(ShapeError::RankMismatch { left: 3, right: 4 })
        );
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let left = Shape::new([1, 2, 2, 4]);
        let right = Shape::new([7, 6, 2, 1]);

        let expected = ShapeError::IncompatibleDims {
            left: 2,
            right: 6,
            dim: 1,
        };
        assert_eq!(left.broadcast(&right), Err(expected));
    }

    #[test]
    fn test_shape_broadcast_many() {
        let a = Shape::new([1, 1, 2, 4]);
        let b = Shape::new([7, 1, 2, 1]);
        let c = Shape::new([7, 6, 1, 1]);

        let merged = Shape::broadcast_many([&a, &b, &c]).unwrap();
        assert_eq!(merged, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let a = Shape::new([1, 1, 2, 4]);
        let b = Shape::new([7, 1, 2, 1]);
        let c = Shape::new([1, 6, 1]);

        let merged = Shape::broadcast_many([&a, &b, &c]);
        assert_eq!(merged, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
    }

    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let a = Shape::new([1, 1, 2, 4]);
        let b = Shape::new([7, 1, 2, 1]);
        let c = Shape::new([4, 6, 1, 1]);

        let expected = ShapeError::IncompatibleDims {
            left: 7,
            right: 4,
            dim: 0,
        };
        assert_eq!(Shape::broadcast_many([&a, &b, &c]), Err(expected));
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        assert_eq!(Shape::broadcast_many(&[]), Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_matmul_2d() {
        let left = Shape::new([2, 4]);
        let right = Shape::new([4, 2]);

        let product = calculate_matmul_output(&left, &right).unwrap();
        assert_eq!(product, Shape::new([2, 2]));
    }

    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        let left = Shape::new([1, 3, 2, 4]);
        let right = Shape::new([2, 1, 4, 2]);

        let product = calculate_matmul_output(&left, &right).unwrap();
        assert_eq!(product, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let left = Shape::new([3, 2, 4]);
        let right = Shape::new([2, 1, 4, 2]);

        assert_eq!(
            calculate_matmul_output(&left, &right),
            Err(ShapeError::RankMismatch { left: 3, right: 4 })
        );
    }

    #[test]
    fn test_shape_matmul_invalid_shape() {
        let left = Shape::new([1, 3, 2, 4]);
        let right = Shape::new([2, 1, 3, 2]);

        let outcome = calculate_matmul_output(&left, &right);
        assert_eq!(
            outcome,
            Err(ShapeError::IncompatibleShapes { left, right })
        );
    }

    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        let left = Shape::new([1, 3, 2, 4]);
        let right = Shape::new([2, 2, 4, 2]);

        let expected = ShapeError::IncompatibleDims {
            left: 3,
            right: 2,
            dim: 1,
        };
        assert_eq!(calculate_matmul_output(&left, &right), Err(expected));
    }

    #[test]
    fn test_shape_cat() {
        let parts = [
            Shape::new([2, 3, 4, 5]),
            Shape::new([1, 3, 4, 5]),
            Shape::new([4, 3, 4, 5]),
        ];
        assert_eq!(Shape::cat(&parts, 0).unwrap(), Shape::new([7, 3, 4, 5]));

        let parts = [
            Shape::new([2, 3, 4, 5]),
            Shape::new([2, 3, 2, 5]),
            Shape::new([2, 3, 1, 5]),
        ];
        assert_eq!(Shape::cat(&parts, 2).unwrap(), Shape::new([2, 3, 7, 5]));
    }

    #[test]
    fn test_shape_cat_empty() {
        assert_eq!(Shape::cat(&[], 0), Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_cat_dim_out_of_bounds() {
        let parts = [Shape::new([2, 3, 4, 5]), Shape::new([2, 3, 4, 5])];
        assert_eq!(
            Shape::cat(&parts, 4),
            Err(ShapeError::OutOfBounds { dim: 4, rank: 4 })
        );
    }

    #[test]
    fn test_shape_cat_rank_mismatch() {
        let parts = [Shape::new([2, 3, 4, 5]), Shape::new([2, 3, 4, 5, 6])];
        assert_eq!(
            Shape::cat(&parts, 0),
            Err(ShapeError::RankMismatch { left: 4, right: 5 })
        );
    }

    #[test]
    fn test_shape_cat_incompatible_shapes() {
        let head = Shape::new([2, 3, 4, 5]);
        let tail = Shape::new([1, 3, 4, 5]);

        let outcome = Shape::cat(&[head.clone(), tail.clone()], 1);
        assert_eq!(
            outcome,
            Err(ShapeError::IncompatibleShapes {
                left: head,
                right: tail
            })
        );
    }

    #[test]
    fn test_shape_slice_output_shape_basic() {
        let slices = [Slice::new(0, Some(5), 1), Slice::new(2, Some(8), 1)];

        let sliced = Shape::new([10, 10, 10]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        let slices = [
            Slice::new(0, Some(10), 2),
            Slice::new(1, Some(9), 3),
            Slice::new(0, Some(7), 4),
        ];

        let sliced = Shape::new([20, 20, 20, 30]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        let slices = [Slice::new(0, Some(10), -1), Slice::new(2, Some(8), -2)];

        let sliced = Shape::new([20, 20, 20]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        let slices = [
            Slice::from_range_stepped(1..6, 1),
            Slice::from_range_stepped(0..10, -3),
            Slice::from_range_stepped(2..14, 4),
        ];

        let sliced = Shape::new([20, 20, 20]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        let slices = [Slice::from_range_stepped(2..7, 2)];

        let sliced = Shape::new([10, 20, 30, 40]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        let slices = [
            Slice::from_range_stepped(0..1, 1),
            Slice::from_range_stepped(0..10, 100),
            Slice::from_range_stepped(5..5, 1),
        ];

        let sliced = Shape::new([10, 20, 30]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        let slices: [Slice; 0] = [];

        let sliced = Shape::new([10, 20, 30]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        let slices = [
            Slice::from_range_stepped(0..7, 3),
            Slice::from_range_stepped(0..11, 4),
            Slice::from_range_stepped(1..10, 5),
        ];

        let sliced = Shape::new([20, 20, 20]).slice(&slices).unwrap();
        assert_eq!(sliced, Shape::new([3, 3, 2]));
    }

    #[test]
    fn test_shape_expand() {
        let target = Shape::new([2, 3, 4]);
        let expanded = Shape::new([1, 3, 1]).expand(target.clone()).unwrap();
        assert_eq!(expanded, target);
    }

    #[test]
    fn test_shape_expand_higher_rank() {
        let target = Shape::new([2, 3, 4]);
        let expanded = Shape::new([1, 4]).expand(target.clone()).unwrap();
        assert_eq!(expanded, target);
    }

    #[test]
    fn test_shape_expand_invalid_rank() {
        let outcome = Shape::new([1, 3, 1]).expand(Shape::new([3, 4]));
        assert_eq!(outcome, Err(ShapeError::RankMismatch { left: 3, right: 2 }));
    }

    #[test]
    fn test_shape_expand_incompatible_dims() {
        let outcome = Shape::new([1, 3, 2]).expand(Shape::new([2, 3, 4]));

        let expected = ShapeError::IncompatibleDims {
            left: 2,
            right: 4,
            dim: 2,
        };
        assert_eq!(outcome, Err(expected));
    }

    #[test]
    fn test_shape_reshape() {
        let target = Shape::new([1, 2, 12, 5]);
        let reshaped = Shape::new([2, 3, 4, 5]).reshape(target.clone()).unwrap();
        assert_eq!(reshaped, target);
    }

    #[test]
    fn test_shape_reshape_invalid() {
        let source = Shape::new([2, 3, 4, 5]);
        let target = Shape::new([2, 2, 12, 5]);

        let outcome = source.clone().reshape(target.clone());
        assert_eq!(
            outcome,
            Err(ShapeError::IncompatibleShapes {
                left: source,
                right: target
            })
        );
    }

    #[test]
    fn test_flatten_dims() {
        let flattened = Shape::new([2, 3, 4, 5]).flatten_dims(-2, 3);
        assert_eq!(flattened, Shape::new([2, 3, 20]));
    }
}