1use crate::{Slice, SliceArg};
2use alloc::vec::Vec;
3use core::{
4 ops::{Deref, DerefMut, Index, IndexMut, Range},
5 slice::{Iter, IterMut, SliceIndex},
6};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct Shape {
12 pub dims: Vec<usize>,
14}
15
16#[allow(missing_docs)]
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum ShapeError {
20 RankMismatch { left: usize, right: usize },
22 IncompatibleDims {
24 left: usize,
25 right: usize,
26 dim: usize,
27 },
28 OutOfBounds { dim: usize, rank: usize },
30 IncompatibleShapes { left: Shape, right: Shape },
32 Empty,
34}
35
36impl Shape {
37 pub fn new<const D: usize>(dims: [usize; D]) -> Self {
39 Self {
41 dims: dims.to_vec(),
42 }
43 }
44
45 pub fn num_elements(&self) -> usize {
47 self.dims.iter().product()
48 }
49
50 pub fn num_dims(&self) -> usize {
54 self.dims.len()
55 }
56
57 pub fn rank(&self) -> usize {
61 self.num_dims()
62 }
63
64 pub fn dims<const D: usize>(&self) -> [usize; D] {
67 let mut dims = [1; D];
68 dims[..D].copy_from_slice(&self.dims[..D]);
69 dims
70 }
71
72 pub fn flatten(mut self) -> Self {
74 self.dims = [self.num_elements()].into();
75 self
76 }
77
78 pub fn into_ranges(self) -> Vec<Range<usize>> {
80 self.into_iter().map(|d| 0..d).collect()
81 }
82
83 pub fn into_slices<const D: usize, S>(self, slices: S) -> [Slice; D]
146 where
147 S: SliceArg<D>,
148 {
149 slices.into_slices(self)
150 }
151
152 pub fn to_vec(&self) -> Vec<usize> {
154 self.dims.clone()
155 }
156
157 pub fn iter(&self) -> Iter<'_, usize> {
159 self.dims.iter()
160 }
161
162 pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
164 self.dims.iter_mut()
165 }
166
167 pub fn as_slice(&self) -> &[usize] {
169 &self.dims
170 }
171
172 pub fn as_mut_slice(&mut self) -> &mut [usize] {
174 &mut self.dims
175 }
176
177 pub fn insert(&mut self, index: usize, size: usize) {
179 self.dims.insert(index, size);
180 }
181
182 pub fn remove(&mut self, index: usize) -> usize {
184 self.dims.remove(index)
185 }
186
187 pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
189 self.dims.extend(iter)
190 }
191
192 pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
194 if dim1 > self.rank() {
195 return Err(ShapeError::OutOfBounds {
196 dim: dim1,
197 rank: self.rank(),
198 });
199 }
200 if dim2 > self.rank() {
201 return Err(ShapeError::OutOfBounds {
202 dim: dim2,
203 rank: self.rank(),
204 });
205 }
206 self.dims.swap(dim1, dim2);
207 Ok(self)
208 }
209
210 pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
212 if axes.len() != self.rank() {
213 return Err(ShapeError::RankMismatch {
214 left: self.rank(),
215 right: axes.len(),
216 });
217 }
218 debug_assert!(axes.iter().all(|i| i < &self.rank()));
219
220 self.dims = axes.iter().map(|&i| self.dims[i]).collect();
221 Ok(self)
222 }
223
224 pub fn repeat(mut self, dim: usize, times: usize) -> Self {
226 self.dims[dim] *= times;
227 self
228 }
229
230 pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
232 where
233 I: IntoIterator<Item = &'a Shape>,
234 {
235 let mut iter = shapes.into_iter();
236
237 let first = iter.next().ok_or(ShapeError::Empty)?;
238
239 if dim >= first.rank() {
240 return Err(ShapeError::OutOfBounds {
241 dim,
242 rank: first.rank(),
243 });
244 }
245
246 let mut shape = first.clone();
247
248 for s in iter {
249 if s.rank() != shape.rank() {
250 return Err(ShapeError::RankMismatch {
251 left: shape.rank(),
252 right: s.rank(),
253 });
254 }
255
256 if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
257 return Err(ShapeError::IncompatibleShapes {
258 left: shape.clone(),
259 right: s.clone(),
260 });
261 }
262
263 shape[dim] += s[dim];
264 }
265
266 Ok(shape)
267 }
268
269 pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
271 if slices.len() > self.rank() {
272 return Err(ShapeError::RankMismatch {
273 left: self.rank(),
274 right: slices.len(),
275 });
276 }
277
278 slices
279 .iter()
280 .zip(self.iter_mut())
281 .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
282
283 Ok(self)
284 }
285
286 pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
295 Self::broadcast_many([self, other])
296 }
297
298 pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
302 where
303 I: IntoIterator<Item = &'a Shape>,
304 {
305 let mut iter = shapes.into_iter();
306 let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone();
307 let rank = broadcasted.rank();
308
309 for shape in iter {
310 if shape.rank() != rank {
311 return Err(ShapeError::RankMismatch {
312 left: rank,
313 right: shape.rank(),
314 });
315 }
316
317 for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
318 match (*d_lhs, d_rhs) {
319 (a, b) if a == b => {} (1, b) => *d_lhs = b, (_a, 1) => {} _ => {
323 return Err(ShapeError::IncompatibleDims {
324 left: *d_lhs,
325 right: d_rhs,
326 dim,
327 });
328 }
329 }
330 }
331 }
332
333 Ok(broadcasted)
334 }
335
336 pub fn matmul(lhs: &Self, rhs: &Self) -> Result<Self, ShapeError> {
341 let rank = lhs.rank();
342 if rank != rhs.rank() {
343 return Err(ShapeError::RankMismatch {
344 left: rank,
345 right: rhs.rank(),
346 });
347 }
348
349 if lhs[rank - 1] != rhs[rank - 2] {
350 return Err(ShapeError::IncompatibleShapes {
351 left: lhs.clone(),
352 right: rhs.clone(),
353 });
354 }
355
356 let mut shape = if rank > 2 {
357 Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
359 } else {
360 Shape::new([])
361 };
362 shape.extend([lhs[rank - 2], rhs[rank - 1]]);
363
364 Ok(shape)
365 }
366}
367
368impl IntoIterator for Shape {
369 type Item = usize;
370 type IntoIter = alloc::vec::IntoIter<Self::Item>;
371
372 fn into_iter(self) -> Self::IntoIter {
373 self.dims.into_iter()
374 }
375}
376
377impl<Idx> Index<Idx> for Shape
378where
379 Idx: SliceIndex<[usize]>,
380{
381 type Output = Idx::Output;
382
383 fn index(&self, index: Idx) -> &Self::Output {
384 &self.dims[index]
385 }
386}
387
388impl<Idx> IndexMut<Idx> for Shape
389where
390 Idx: SliceIndex<[usize]>,
391{
392 fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
393 &mut self.dims[index]
394 }
395}
396
397impl Deref for Shape {
399 type Target = [usize];
400
401 fn deref(&self) -> &Self::Target {
402 &self.dims
403 }
404}
405
406impl DerefMut for Shape {
408 fn deref_mut(&mut self) -> &mut Self::Target {
409 &mut self.dims
410 }
411}
412
413impl<const D: usize> From<[usize; D]> for Shape {
415 fn from(dims: [usize; D]) -> Self {
416 Shape::new(dims)
417 }
418}
419
420impl<const D: usize> From<[i64; D]> for Shape {
421 fn from(dims: [i64; D]) -> Self {
422 Shape {
423 dims: dims.into_iter().map(|d| d as usize).collect(),
424 }
425 }
426}
427
428impl<const D: usize> From<[i32; D]> for Shape {
429 fn from(dims: [i32; D]) -> Self {
430 Shape {
431 dims: dims.into_iter().map(|d| d as usize).collect(),
432 }
433 }
434}
435
436impl From<&[usize]> for Shape {
437 fn from(dims: &[usize]) -> Self {
438 Shape { dims: dims.into() }
439 }
440}
441
442impl From<Vec<i64>> for Shape {
443 fn from(shape: Vec<i64>) -> Self {
444 Self {
445 dims: shape.into_iter().map(|d| d as usize).collect(),
446 }
447 }
448}
449
450impl From<Vec<u64>> for Shape {
451 fn from(shape: Vec<u64>) -> Self {
452 Self {
453 dims: shape.into_iter().map(|d| d as usize).collect(),
454 }
455 }
456}
457
458impl From<Vec<usize>> for Shape {
459 fn from(shape: Vec<usize>) -> Self {
460 Self { dims: shape }
461 }
462}
463
464impl From<&Vec<usize>> for Shape {
465 fn from(shape: &Vec<usize>) -> Self {
466 Self {
467 dims: shape.clone(),
468 }
469 }
470}
471
472impl From<Shape> for Vec<usize> {
473 fn from(shape: Shape) -> Self {
474 shape.dims
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::s;
    use alloc::vec;

    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }

    #[test]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }

    // NOTE(review): presumably silences clippy on the single-range slice arguments used
    // below; confirm the allow is still needed.
    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        // The end of the range is clamped to the dimension size.
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);

        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);

        // Negative bounds count from the end of the dimension.
        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);

        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);

        // A single index selects a one-element range.
        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);

        // Range indexing goes through `Index<Idx: SliceIndex<[usize]>>`.
        assert_eq!(shape[1..3], [3, 4]);
        assert_eq!(shape[1..=2], [3, 4]);
        assert_eq!(shape[..], [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);

        // Slice methods are reachable through `Deref<Target = [usize]>`.
        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));

        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        for d in shape.iter_mut() {
            *d += 1;
        }

        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // Deref coercion to a slice.
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        // Same mutation through mutable range indexing.
        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut)
    }

    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);

        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(&shape.dims, &[120]);
    }

    #[test]
    fn test_shape_insert_remove() {
        let dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let size = 6;
        shape.insert(1, size);

        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        let removed = shape.remove(1);
        assert_eq!(removed, size);
        assert_eq!(shape, Shape::new(dims));
    }

    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        let shape = shape.swap(1, 2).unwrap();

        assert_eq!(&shape.dims, &[2, 4, 3, 5]);

        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
        assert_eq!(shape, Shape::new(dims));
    }

    #[test]
    #[should_panic]
    fn test_shape_swap_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.swap(0, 4).unwrap();
    }

    #[test]
    fn test_shape_swap_dim_beyond_rank() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.swap(5, 0);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    #[should_panic]
    fn test_shape_permute_incomplete() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.permute(&[0, 2, 1]).unwrap();
    }

    #[test]
    #[should_panic]
    fn test_shape_permute_axis_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.permute(&[0, 1, 2, 4]).unwrap();
    }

    #[test]
    fn test_shape_repeat() {
        let shape = Shape::new([2, 3, 4, 5]);

        let shape = shape.repeat(2, 3);
        assert_eq!(shape, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_broadcast_binary() {
        let lhs = Shape::new([1, 1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let lhs = Shape::new([1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 4]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let lhs = Shape::new([1, 2, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 2,
                right: 6,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([7, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([1, 6, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
    }

    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([4, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 7,
                right: 4,
                dim: 0
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        let out = Shape::broadcast_many(&[]);
        assert_eq!(out, Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_matmul_2d() {
        let lhs = Shape::new([2, 4]);
        let rhs = Shape::new([4, 2]);
        let out = Shape::matmul(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 2]));
    }

    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = Shape::matmul(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let lhs = Shape::new([3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = Shape::matmul(&lhs, &rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    #[should_panic]
    fn test_shape_matmul_rank_too_small() {
        let lhs = Shape::new([4]);
        let rhs = Shape::new([4]);

        Shape::matmul(&lhs, &rhs).unwrap();
    }

    #[test]
    fn test_shape_matmul_invalid_shape() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 3, 2]);
        let out = Shape::matmul(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: lhs,
                right: rhs
            })
        );
    }

    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 2, 4, 2]);
        let out = Shape::matmul(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 3,
                right: 2,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_cat() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let s3 = Shape::new([4, 3, 4, 5]);

        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
        assert_eq!(out, Shape::new([7, 3, 4, 5]));

        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 2, 5]);
        let s3 = Shape::new([2, 3, 1, 5]);

        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 7, 5]));
    }

    #[test]
    fn test_shape_cat_empty() {
        let out = Shape::cat(&[], 0);
        assert_eq!(out, Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_cat_dim_out_of_bounds() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5]);
        let out = Shape::cat(&[s1, s2], 4);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
    }

    #[test]
    fn test_shape_cat_rank_mismatch() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5, 6]);
        let out = Shape::cat(&[s1, s2], 0);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
    }

    #[test]
    fn test_shape_cat_incompatible_shapes() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);

        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: s1,
                right: s2
            })
        );
    }

    #[test]
    fn test_shape_slice_output_shape_basic() {
        let slices = [
            Slice::new(0, Some(5), 1), // dim 0: 0..5 -> 5
            Slice::new(2, Some(8), 1), // dim 1: 2..8 -> 6
        ];
        let original_shape = Shape::new([10, 10, 10]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        let slices = [
            Slice::new(0, Some(10), 2), // dim 0: 10 elements, every 2nd -> 5
            Slice::new(1, Some(9), 3),  // dim 1: 8 elements, every 3rd -> 3
            Slice::new(0, Some(7), 4),  // dim 2: 7 elements, every 4th -> 2
        ];
        let original_shape = Shape::new([20, 20, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        let slices = [
            Slice::new(0, Some(10), -1), // dim 0: 10 elements reversed -> 10
            Slice::new(2, Some(8), -2),  // dim 1: 6 elements reversed, every 2nd -> 3
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        let slices = [
            Slice::from_range_stepped(1..6, 1),   // 5 elements -> 5
            Slice::from_range_stepped(0..10, -3), // 10 elements, every 3rd -> 4
            Slice::from_range_stepped(2..14, 4),  // 12 elements, every 4th -> 3
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        let slices = [
            Slice::from_range_stepped(2..7, 2), // 5 elements, every 2nd -> 3
        ];
        let original_shape = Shape::new([10, 20, 30, 40]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        let slices = [
            Slice::from_range_stepped(0..1, 1),    // single element -> 1
            Slice::from_range_stepped(0..10, 100), // step larger than range -> 1
            Slice::from_range_stepped(5..5, 1),    // empty range -> 0
        ];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        // No slices: every dimension keeps its size.
        let slices = [];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        let slices = [
            Slice::from_range_stepped(0..7, 3),  // 7 elements, every 3rd -> 3
            Slice::from_range_stepped(0..11, 4), // 11 elements, every 4th -> 3
            Slice::from_range_stepped(1..10, 5), // 9 elements, every 5th -> 2
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 3, 2]));
    }
}