1use alloc::vec::Vec;
4use core::{
5 ops::{Deref, DerefMut, Index, IndexMut, Range},
6 slice::{Iter, IterMut, SliceIndex},
7};
8use serde::{Deserialize, Serialize};
9
10use super::indexing::ravel_index;
11use super::{AsIndex, Slice, SliceArg};
12
/// Dimension sizes of a multi-dimensional array, one entry per dimension.
///
/// The number of entries is the rank. A shape with no entries describes a single element
/// (see `num_elements`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape {
    /// Size of each dimension, in order; `dims.len()` is the rank.
    pub dims: Vec<usize>,
}
19
20#[allow(missing_docs)]
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum ShapeError {
24 RankMismatch { left: usize, right: usize },
26 IncompatibleDims {
28 left: usize,
29 right: usize,
30 dim: usize,
31 },
32 OutOfBounds { dim: usize, rank: usize },
34 IncompatibleShapes { left: Shape, right: Shape },
36 Empty,
38}
39
40impl Shape {
41 pub fn new<const D: usize>(dims: [usize; D]) -> Self {
43 Self {
45 dims: dims.to_vec(),
46 }
47 }
48
49 pub fn num_elements(&self) -> usize {
51 self.dims.iter().product()
52 }
53
54 pub fn num_dims(&self) -> usize {
58 self.dims.len()
59 }
60
61 pub fn rank(&self) -> usize {
65 self.num_dims()
66 }
67
68 pub fn dims<const D: usize>(&self) -> [usize; D] {
71 let mut dims = [1; D];
72 dims[..D].copy_from_slice(&self.dims[..D]);
73 dims
74 }
75
76 pub fn flatten(mut self) -> Self {
78 self.dims = [self.num_elements()].into();
79 self
80 }
81
82 pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
96 ravel_index(indices, &self.dims)
97 }
98
99 pub fn into_ranges(self) -> Vec<Range<usize>> {
101 self.into_iter().map(|d| 0..d).collect()
102 }
103
104 pub fn into_slices<const D: usize, S>(self, slices: S) -> [Slice; D]
164 where
165 S: SliceArg<D>,
166 {
167 slices.into_slices(self)
168 }
169
170 pub fn to_vec(&self) -> Vec<usize> {
172 self.dims.clone()
173 }
174
175 pub fn iter(&self) -> Iter<'_, usize> {
177 self.dims.iter()
178 }
179
180 pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
182 self.dims.iter_mut()
183 }
184
185 pub fn as_slice(&self) -> &[usize] {
187 &self.dims
188 }
189
190 pub fn as_mut_slice(&mut self) -> &mut [usize] {
192 &mut self.dims
193 }
194
195 pub fn insert(&mut self, index: usize, size: usize) {
197 self.dims.insert(index, size);
198 }
199
200 pub fn remove(&mut self, index: usize) -> usize {
202 self.dims.remove(index)
203 }
204
205 pub fn push(&mut self, size: usize) {
207 self.dims.push(size)
208 }
209
210 pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
212 self.dims.extend(iter)
213 }
214
215 pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
217 if dim1 > self.rank() {
218 return Err(ShapeError::OutOfBounds {
219 dim: dim1,
220 rank: self.rank(),
221 });
222 }
223 if dim2 > self.rank() {
224 return Err(ShapeError::OutOfBounds {
225 dim: dim2,
226 rank: self.rank(),
227 });
228 }
229 self.dims.swap(dim1, dim2);
230 Ok(self)
231 }
232
233 pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
235 if axes.len() != self.rank() {
236 return Err(ShapeError::RankMismatch {
237 left: self.rank(),
238 right: axes.len(),
239 });
240 }
241 debug_assert!(axes.iter().all(|i| i < &self.rank()));
242
243 self.dims = axes.iter().map(|&i| self.dims[i]).collect();
244 Ok(self)
245 }
246
247 pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> {
249 if dim >= self.rank() {
250 return Err(ShapeError::OutOfBounds {
251 dim,
252 rank: self.rank(),
253 });
254 }
255
256 self.dims[dim] *= times;
257 Ok(self)
258 }
259
260 pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> {
262 if dim >= self.rank() {
263 return Err(ShapeError::OutOfBounds {
264 dim,
265 rank: self.rank(),
266 });
267 }
268
269 self.dims[dim] = 1;
270 Ok(self)
271 }
272
273 pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
275 where
276 I: IntoIterator<Item = &'a Shape>,
277 {
278 let mut iter = shapes.into_iter();
279
280 let first = iter.next().ok_or(ShapeError::Empty)?;
281
282 if dim >= first.rank() {
283 return Err(ShapeError::OutOfBounds {
284 dim,
285 rank: first.rank(),
286 });
287 }
288
289 let mut shape = first.clone();
290
291 for s in iter {
292 if s.rank() != shape.rank() {
293 return Err(ShapeError::RankMismatch {
294 left: shape.rank(),
295 right: s.rank(),
296 });
297 }
298
299 if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
300 return Err(ShapeError::IncompatibleShapes {
301 left: shape.clone(),
302 right: s.clone(),
303 });
304 }
305
306 shape[dim] += s[dim];
307 }
308
309 Ok(shape)
310 }
311
312 pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
314 if slices.len() > self.rank() {
315 return Err(ShapeError::RankMismatch {
316 left: self.rank(),
317 right: slices.len(),
318 });
319 }
320
321 slices
322 .iter()
323 .zip(self.iter_mut())
324 .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
325
326 Ok(self)
327 }
328
329 pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
338 Self::broadcast_many([self, other])
339 }
340
341 pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
345 where
346 I: IntoIterator<Item = &'a Shape>,
347 {
348 let mut iter = shapes.into_iter();
349 let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone();
350 let rank = broadcasted.rank();
351
352 for shape in iter {
353 if shape.rank() != rank {
354 return Err(ShapeError::RankMismatch {
355 left: rank,
356 right: shape.rank(),
357 });
358 }
359
360 for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
361 match (*d_lhs, d_rhs) {
362 (a, b) if a == b => {} (1, b) => *d_lhs = b, (_a, 1) => {} _ => {
366 return Err(ShapeError::IncompatibleDims {
367 left: *d_lhs,
368 right: d_rhs,
369 dim,
370 });
371 }
372 }
373 }
374 }
375
376 Ok(broadcasted)
377 }
378
379 pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> {
381 let target_rank = target.rank();
382 if self.rank() > target_rank {
383 return Err(ShapeError::RankMismatch {
384 left: self.rank(),
385 right: target_rank,
386 });
387 }
388
389 for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
390 if dim_self != dim_target && *dim_self != 1 {
391 return Err(ShapeError::IncompatibleDims {
392 left: *dim_self,
393 right: *dim_target,
394 dim: target_rank - i - 1,
395 });
396 }
397 }
398
399 Ok(target)
400 }
401
402 pub fn reshape(&self, target: Shape) -> Result<Shape, ShapeError> {
404 if self.num_elements() != target.num_elements() {
405 return Err(ShapeError::IncompatibleShapes {
406 left: self.clone(),
407 right: target,
408 });
409 }
410 Ok(target)
411 }
412}
413
414pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> {
419 let rank = lhs.rank();
420 if rank != rhs.rank() {
421 return Err(ShapeError::RankMismatch {
422 left: rank,
423 right: rhs.rank(),
424 });
425 }
426
427 if lhs[rank - 1] != rhs[rank - 2] {
428 return Err(ShapeError::IncompatibleShapes {
429 left: lhs.clone(),
430 right: rhs.clone(),
431 });
432 }
433
434 let mut shape = if rank > 2 {
435 Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
437 } else {
438 Shape::new([])
439 };
440 shape.extend([lhs[rank - 2], rhs[rank - 1]]);
441
442 Ok(shape)
443}
444
445impl IntoIterator for Shape {
446 type Item = usize;
447 type IntoIter = alloc::vec::IntoIter<Self::Item>;
448
449 fn into_iter(self) -> Self::IntoIter {
450 self.dims.into_iter()
451 }
452}
453
454impl<Idx> Index<Idx> for Shape
455where
456 Idx: SliceIndex<[usize]>,
457{
458 type Output = Idx::Output;
459
460 fn index(&self, index: Idx) -> &Self::Output {
461 &self.dims[index]
462 }
463}
464
465impl<Idx> IndexMut<Idx> for Shape
466where
467 Idx: SliceIndex<[usize]>,
468{
469 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
470 &mut self.dims[index]
471 }
472}
473
474impl Deref for Shape {
476 type Target = [usize];
477
478 fn deref(&self) -> &Self::Target {
479 &self.dims
480 }
481}
482
483impl DerefMut for Shape {
485 fn deref_mut(&mut self) -> &mut Self::Target {
486 &mut self.dims
487 }
488}
489
490impl<const D: usize> From<[usize; D]> for Shape {
492 fn from(dims: [usize; D]) -> Self {
493 Shape::new(dims)
494 }
495}
496
497impl<const D: usize> From<[i64; D]> for Shape {
498 fn from(dims: [i64; D]) -> Self {
499 Shape {
500 dims: dims.into_iter().map(|d| d as usize).collect(),
501 }
502 }
503}
504
505impl<const D: usize> From<[i32; D]> for Shape {
506 fn from(dims: [i32; D]) -> Self {
507 Shape {
508 dims: dims.into_iter().map(|d| d as usize).collect(),
509 }
510 }
511}
512
513impl From<&[usize]> for Shape {
514 fn from(dims: &[usize]) -> Self {
515 Shape { dims: dims.into() }
516 }
517}
518
519impl From<Vec<i64>> for Shape {
520 fn from(shape: Vec<i64>) -> Self {
521 Self {
522 dims: shape.into_iter().map(|d| d as usize).collect(),
523 }
524 }
525}
526
527impl From<Vec<u64>> for Shape {
528 fn from(shape: Vec<u64>) -> Self {
529 Self {
530 dims: shape.into_iter().map(|d| d as usize).collect(),
531 }
532 }
533}
534
535impl From<Vec<usize>> for Shape {
536 fn from(shape: Vec<usize>) -> Self {
537 Self { dims: shape }
538 }
539}
540
541impl From<&Vec<usize>> for Shape {
542 fn from(shape: &Vec<usize>) -> Self {
543 Self {
544 dims: shape.clone(),
545 }
546 }
547}
548
549impl From<Shape> for Vec<usize> {
550 fn from(shape: Shape) -> Self {
551 shape.dims
552 }
553}
554
// Unit tests for `Shape` construction, accessors, slicing, broadcasting, matmul output
// shapes, concatenation, expansion and reshaping.
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::s;
    use alloc::vec;

    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        // 2 * 3 * 4 * 5
        assert_eq!(120, shape.num_elements());
    }

    #[test]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        // Sum of the dimension sizes: 2 + 3 + 4 + 5.
        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }

    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        // An end past the dimension size resolves to the dimension size.
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);

        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);

        // A negative end counts from the end of the dimension.
        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);

        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);

        // A single index selects a one-element range.
        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);

        // Range indexing yields sub-slices of the dimension sizes.
        assert_eq!(shape[1..3], *&[3, 4]);
        assert_eq!(shape[1..=2], *&[3, 4]);
        assert_eq!(shape[..], *&[2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_slice_methods() {
        // Slice methods reached through `Deref<Target = [usize]>`.
        let shape = Shape::new([2, 3, 4, 5]);

        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));

        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        for d in shape.iter_mut() {
            *d += 1;
        }

        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // Deref coercion to `&[usize]`.
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        // The same mutation through `IndexMut` with a full range must give the same result.
        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut)
    }

    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);

        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(&shape.dims, &[120]);
    }

    #[test]
    fn test_ravel() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
        // Each index is weighted by the product of the sizes of the dimensions after it.
        assert_eq!(
            shape.ravel_index(&[1, 2, 3, 4]),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    fn test_shape_insert_remove_push() {
        let dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let size = 6;
        shape.insert(1, size);

        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        let removed = shape.remove(1);
        assert_eq!(removed, size);
        assert_eq!(shape, Shape::new(dims));

        shape.push(6);
        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
    }

    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        let shape = shape.swap(1, 2).unwrap();

        assert_eq!(&shape.dims, &[2, 4, 3, 5]);

        // Permuting with the same transposition restores the original shape.
        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
        assert_eq!(shape, Shape::new(dims));
    }

    // NOTE(review): `should_panic` passes both when `swap` returns an error that `unwrap`
    // turns into a panic and when the panic comes from inside `swap` itself; asserting on the
    // returned `ShapeError::OutOfBounds` value would distinguish the two.
    #[test]
    #[should_panic]
    fn test_shape_swap_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.swap(0, 4).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_shape_permute_incomplete() {
        let shape = Shape::new([2, 3, 4, 5]);

        // Three axes for a rank-4 shape.
        shape.permute(&[0, 2, 1]).unwrap();
    }

    #[test]
    fn test_shape_repeat() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(2, 3).unwrap();
        assert_eq!(out, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_repeat_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(5, 3);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_reduce() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 1, 5]));
    }

    #[test]
    fn test_shape_reduce_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(5);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_broadcast_binary() {
        let lhs = Shape::new([1, 1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let lhs = Shape::new([1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 4]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let lhs = Shape::new([1, 2, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        // Dimension 1: sizes 2 and 6 differ and neither is 1.
        let out = lhs.broadcast(&rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 2,
                right: 6,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([7, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([1, 6, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
    }

    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([4, 6, 1, 1]);

        // `left` is the size accumulated from s1 and s2 (7), `right` is s3's size (4).
        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 7,
                right: 4,
                dim: 0
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        let out = Shape::broadcast_many(&[]);
        assert_eq!(out, Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_matmul_2d() {
        let lhs = Shape::new([2, 4]);
        let rhs = Shape::new([4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 2]));
    }

    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        // Batch dims [1, 3] and [2, 1] broadcast to [2, 3].
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let lhs = Shape::new([3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_matmul_invalid_shape() {
        // Inner dimensions differ: lhs[3] = 4, rhs[2] = 3.
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 3, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: lhs,
                right: rhs
            })
        );
    }

    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        // Batch dimension 1: sizes 3 and 2 cannot be broadcast.
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 2, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 3,
                right: 2,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_cat() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let s3 = Shape::new([4, 3, 4, 5]);

        // Sizes along dim 0 are summed: 2 + 1 + 4.
        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
        assert_eq!(out, Shape::new([7, 3, 4, 5]));

        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 2, 5]);
        let s3 = Shape::new([2, 3, 1, 5]);

        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 7, 5]));
    }

    #[test]
    fn test_shape_cat_empty() {
        let out = Shape::cat(&[], 0);
        assert_eq!(out, Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_cat_dim_out_of_bounds() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5]);
        let out = Shape::cat(&[s1, s2], 4);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
    }

    #[test]
    fn test_shape_cat_rank_mismatch() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5, 6]);
        let out = Shape::cat(&[s1, s2], 0);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
    }

    #[test]
    fn test_shape_cat_incompatible_shapes() {
        // Concatenating along dim 1 requires dim 0 to match, but 2 != 1.
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);

        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: s1,
                right: s2
            })
        );
    }

    #[test]
    fn test_shape_slice_output_shape_basic() {
        let slices = [
            Slice::new(0, Some(5), 1), // 0..5 -> 5 elements
            Slice::new(2, Some(8), 1), // 2..8 -> 6 elements
        ];
        // The third dimension has no slice and keeps its size.
        let original_shape = Shape::new([10, 10, 10]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        let slices = [
            Slice::new(0, Some(10), 2), // 0, 2, 4, 6, 8 -> 5
            Slice::new(1, Some(9), 3),  // 1, 4, 7 -> 3
            Slice::new(0, Some(7), 4),  // 0, 4 -> 2
        ];
        let original_shape = Shape::new([20, 20, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        let slices = [
            Slice::new(0, Some(10), -1), // 10 elements, reversed
            Slice::new(2, Some(8), -2),  // 6 elements, every second one -> 3
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        let slices = [
            Slice::from_range_stepped(1..6, 1),   // 5
            Slice::from_range_stepped(0..10, -3), // ceil(10 / 3) = 4
            Slice::from_range_stepped(2..14, 4),  // 2, 6, 10 -> 3
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        let slices = [
            Slice::from_range_stepped(2..7, 2), // 2, 4, 6 -> 3
        ];
        let original_shape = Shape::new([10, 20, 30, 40]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        let slices = [
            Slice::from_range_stepped(0..1, 1),    // single element
            Slice::from_range_stepped(0..10, 100), // step larger than the range -> 1
            Slice::from_range_stepped(5..5, 1),    // empty range -> 0
        ];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        // No slices: every dimension keeps its size.
        let slices = [];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        let slices = [
            Slice::from_range_stepped(0..7, 3),  // 0, 3, 6 -> 3
            Slice::from_range_stepped(0..11, 4), // 0, 4, 8 -> 3
            Slice::from_range_stepped(1..10, 5), // 1, 6 -> 2
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 3, 2]));
    }

    #[test]
    fn test_shape_expand() {
        let shape = Shape::new([1, 3, 1]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded.clone()).unwrap();
        assert_eq!(out, expanded);
    }

    #[test]
    fn test_shape_expand_higher_rank() {
        // [1, 4] is right-aligned against [2, 3, 4]; the leading 2 is unconstrained.
        let shape = Shape::new([1, 4]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded.clone()).unwrap();
        assert_eq!(out, expanded);
    }

    #[test]
    fn test_shape_expand_invalid_rank() {
        let shape = Shape::new([1, 3, 1]);
        let expanded = Shape::new([3, 4]);
        let out = shape.expand(expanded);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 2 }));
    }

    #[test]
    fn test_shape_expand_incompatible_dims() {
        // Last dimension: 2 cannot expand to 4 (only size 1 can stretch).
        let shape = Shape::new([1, 3, 2]);
        let expanded = Shape::new([2, 3, 4]);
        let out = shape.expand(expanded);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 2,
                right: 4,
                dim: 2
            })
        );
    }

    #[test]
    fn test_shape_reshape() {
        // Both shapes hold 120 elements.
        let shape = Shape::new([2, 3, 4, 5]);
        let reshaped = Shape::new([1, 2, 12, 5]);
        let out = shape.reshape(reshaped.clone()).unwrap();
        assert_eq!(out, reshaped);
    }

    #[test]
    fn test_shape_reshape_invalid() {
        // 120 elements vs. 240 elements.
        let shape = Shape::new([2, 3, 4, 5]);
        let reshaped = Shape::new([2, 2, 12, 5]);
        let out = shape.clone().reshape(reshaped.clone());
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: shape,
                right: reshaped
            })
        );
    }
}