burn_tensor/tensor/
shape.rs

1use crate::{Slice, SliceArg};
2use alloc::vec::Vec;
3use core::{
4    ops::{Deref, DerefMut, Index, IndexMut, Range},
5    slice::{Iter, IterMut, SliceIndex},
6};
7use serde::{Deserialize, Serialize};
8
9/// Shape of a tensor.
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct Shape {
12    /// The dimensions of the tensor.
13    pub dims: Vec<usize>,
14}
15
16#[allow(missing_docs)]
17#[derive(Debug, Clone, PartialEq, Eq)]
18/// Error that can occur when attempting to modify shapes.
19pub enum ShapeError {
20    /// The operands have different ranks.
21    RankMismatch { left: usize, right: usize },
22    /// A pair of dimensions are incompatible for broadcasting.
23    IncompatibleDims {
24        left: usize,
25        right: usize,
26        dim: usize,
27    },
28    /// Invalid dimension specified for the rank.
29    OutOfBounds { dim: usize, rank: usize },
30    /// A pair of shapes are incompatible for the operation.
31    IncompatibleShapes { left: Shape, right: Shape },
32    /// Invalid empty shape.
33    Empty,
34}
35
36impl Shape {
37    /// Constructs a new `Shape`.
38    pub fn new<const D: usize>(dims: [usize; D]) -> Self {
39        // For backward compat
40        Self {
41            dims: dims.to_vec(),
42        }
43    }
44
45    /// Returns the total number of elements of a tensor having this shape
46    pub fn num_elements(&self) -> usize {
47        self.dims.iter().product()
48    }
49
50    /// Returns the number of dimensions.
51    ///
52    /// Alias for `Shape::rank()`.
53    pub fn num_dims(&self) -> usize {
54        self.dims.len()
55    }
56
57    /// Returns the rank (the number of dimensions).
58    ///
59    /// Alias for `Shape::num_dims()`.
60    pub fn rank(&self) -> usize {
61        self.num_dims()
62    }
63
64    // For compat with dims: [usize; D]
65    /// Returns the dimensions of the tensor as an array.
66    pub fn dims<const D: usize>(&self) -> [usize; D] {
67        let mut dims = [1; D];
68        dims[..D].copy_from_slice(&self.dims[..D]);
69        dims
70    }
71
72    /// Change the shape to one dimensional with the same number of elements.
73    pub fn flatten(mut self) -> Self {
74        self.dims = [self.num_elements()].into();
75        self
76    }
77
78    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
79    pub fn into_ranges(self) -> Vec<Range<usize>> {
80        self.into_iter().map(|d| 0..d).collect()
81    }
82
83    /// Converts slice arguments into an array of slice specifications for the shape.
84    ///
85    /// This method returns an array of `Slice` objects that can be used for slicing operations.
86    /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
87    /// allows custom slice specifications instead of full ranges.
88    /// For creating complex slice specifications, use the [`s!`] macro.
89    ///
90    /// # Arguments
91    ///
92    /// * `slices` - An array of slice specifications, where each element can be:
93    ///   - A range (e.g., `2..5`)
94    ///   - An index
95    ///   - A `Slice` object
96    ///   - The output of the [`s!`] macro for advanced slicing
97    ///
98    /// # Behavior
99    ///
100    /// - Supports partial and full slicing in any number of dimensions.
101    /// - Missing ranges are treated as full slices if D > D2.
102    /// - Handles negative indices by wrapping around from the end of the dimension.
103    /// - Clamps ranges to the shape's dimensions if they exceed the bounds.
104    ///
105    /// # Returns
106    ///
107    /// An array of `Slice` objects corresponding to the provided slice specifications,
108    /// clamped to the shape's actual dimensions.
109    ///
110    /// # Examples
111    ///
112    /// ```rust
113    /// use burn_tensor::backend::Backend;
114    /// use burn_tensor::{Tensor, Shape, Slice, s};
115    ///
116    /// fn example<B: Backend>() {
117    ///     // 1D slicing
118    ///     let slices = Shape::new([4]).into_slices(1..4);
119    ///     assert_eq!(slices[0].to_range(4), 1..3);
120    ///
121    ///     // 2D slicing
122    ///     let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
123    ///     assert_eq!(slices[0].to_range(3), 1..3);
124    ///     assert_eq!(slices[1].to_range(4), 0..2);
125    ///
126    ///     // Using negative indices
127    ///     let slices = Shape::new([3]).into_slices(..-2);
128    ///     assert_eq!(slices[0].to_range(3), 0..1);
129    ///
130    ///     // Using the slice macro to select different ranges
131    ///     let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
132    ///     assert_eq!(slices[0].to_range(2), 0..2);
133    ///     assert_eq!(slices[1].to_range(3), 1..2);
134    /// }
135    /// ```
136    ///
137    /// # See Also
138    ///
139    /// - [`s!`] - The recommended macro for creating slice specifications
140    /// - [`Tensor::slice`] - Apply slicing to a tensor
141    /// - [`Shape::into_ranges`] - Convert to full covering ranges
142    ///
143    /// [`s!`]: crate::s!
144    /// [`Tensor::slice`]: crate::Tensor::slice
145    pub fn into_slices<const D: usize, S>(self, slices: S) -> [Slice; D]
146    where
147        S: SliceArg<D>,
148    {
149        slices.into_slices(self)
150    }
151
152    /// Construct a vector of the dims.
153    pub fn to_vec(&self) -> Vec<usize> {
154        self.dims.clone()
155    }
156
157    /// Returns an iterator over the shape dimensions.
158    pub fn iter(&self) -> Iter<'_, usize> {
159        self.dims.iter()
160    }
161
162    /// Mutable iterator over the dimensions.
163    pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
164        self.dims.iter_mut()
165    }
166
167    /// Borrow the underlying dimensions slice.
168    pub fn as_slice(&self) -> &[usize] {
169        &self.dims
170    }
171
172    /// Borrow the underlying dimensions slice mutably.
173    pub fn as_mut_slice(&mut self) -> &mut [usize] {
174        &mut self.dims
175    }
176
177    /// Insert a dimension of `size` at position `index`.
178    pub fn insert(&mut self, index: usize, size: usize) {
179        self.dims.insert(index, size);
180    }
181
182    /// Remove and return the dimension at position `index` from the shape.
183    pub fn remove(&mut self, index: usize) -> usize {
184        self.dims.remove(index)
185    }
186
187    /// Extend the shape with the content of another shape or iterator.
188    pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
189        self.dims.extend(iter)
190    }
191
192    /// Swap two dimensions in the shape.
193    pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
194        if dim1 > self.rank() {
195            return Err(ShapeError::OutOfBounds {
196                dim: dim1,
197                rank: self.rank(),
198            });
199        }
200        if dim2 > self.rank() {
201            return Err(ShapeError::OutOfBounds {
202                dim: dim2,
203                rank: self.rank(),
204            });
205        }
206        self.dims.swap(dim1, dim2);
207        Ok(self)
208    }
209
210    /// Reorder the shape dimensions according to the permutation of `axes`.
211    pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
212        if axes.len() != self.rank() {
213            return Err(ShapeError::RankMismatch {
214                left: self.rank(),
215                right: axes.len(),
216            });
217        }
218        debug_assert!(axes.iter().all(|i| i < &self.rank()));
219
220        self.dims = axes.iter().map(|&i| self.dims[i]).collect();
221        Ok(self)
222    }
223
224    /// Repeated the specified `dim` a number of `times`.
225    pub fn repeat(mut self, dim: usize, times: usize) -> Self {
226        self.dims[dim] *= times;
227        self
228    }
229
230    /// Concatenates all shapes into a new one along the given dimension.
231    pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
232    where
233        I: IntoIterator<Item = &'a Shape>,
234    {
235        let mut iter = shapes.into_iter();
236
237        let first = iter.next().ok_or(ShapeError::Empty)?;
238
239        if dim >= first.rank() {
240            return Err(ShapeError::OutOfBounds {
241                dim,
242                rank: first.rank(),
243            });
244        }
245
246        let mut shape = first.clone();
247
248        for s in iter {
249            if s.rank() != shape.rank() {
250                return Err(ShapeError::RankMismatch {
251                    left: shape.rank(),
252                    right: s.rank(),
253                });
254            }
255
256            if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
257                return Err(ShapeError::IncompatibleShapes {
258                    left: shape.clone(),
259                    right: s.clone(),
260                });
261            }
262
263            shape[dim] += s[dim];
264        }
265
266        Ok(shape)
267    }
268
269    /// Compute the output shape from the given slices.
270    pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
271        if slices.len() > self.rank() {
272            return Err(ShapeError::RankMismatch {
273                left: self.rank(),
274                right: slices.len(),
275            });
276        }
277
278        slices
279            .iter()
280            .zip(self.iter_mut())
281            .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
282
283        Ok(self)
284    }
285
286    /// Compute the output shape for binary operations with broadcasting support.
287    ///
288    /// - Shapes must be of the same rank (missing dimensions are not handled automatically).
289    /// - Two dimensions are compatible if they are equal, or one of them is 1.
290    ///
291    /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]`
292    /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]`
293    /// can *not* be broadcast into `[2, 4]`.
294    pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
295        Self::broadcast_many([self, other])
296    }
297
298    /// Compute the broadcasted output shape across multiple input shapes.
299    ///
300    /// See also [broadcast](Self::broadcast).
301    pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
302    where
303        I: IntoIterator<Item = &'a Shape>,
304    {
305        let mut iter = shapes.into_iter();
306        let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone();
307        let rank = broadcasted.rank();
308
309        for shape in iter {
310            if shape.rank() != rank {
311                return Err(ShapeError::RankMismatch {
312                    left: rank,
313                    right: shape.rank(),
314                });
315            }
316
317            for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
318                match (*d_lhs, d_rhs) {
319                    (a, b) if a == b => {} // same
320                    (1, b) => *d_lhs = b,  // broadcast to rhs
321                    (_a, 1) => {}          // keep existing dimension
322                    _ => {
323                        return Err(ShapeError::IncompatibleDims {
324                            left: *d_lhs,
325                            right: d_rhs,
326                            dim,
327                        });
328                    }
329                }
330            }
331        }
332
333        Ok(broadcasted)
334    }
335
336    /// Compute the output shape for matrix multiplication with broadcasting support.
337    ///
338    /// The last two dimensions are treated as matrices, while preceding dimensions
339    /// follow broadcast semantics similar to elementwise operations.
340    pub fn matmul(lhs: &Self, rhs: &Self) -> Result<Self, ShapeError> {
341        let rank = lhs.rank();
342        if rank != rhs.rank() {
343            return Err(ShapeError::RankMismatch {
344                left: rank,
345                right: rhs.rank(),
346            });
347        }
348
349        if lhs[rank - 1] != rhs[rank - 2] {
350            return Err(ShapeError::IncompatibleShapes {
351                left: lhs.clone(),
352                right: rhs.clone(),
353            });
354        }
355
356        let mut shape = if rank > 2 {
357            // Broadcast leading dims
358            Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
359        } else {
360            Shape::new([])
361        };
362        shape.extend([lhs[rank - 2], rhs[rank - 1]]);
363
364        Ok(shape)
365    }
366}
367
368impl IntoIterator for Shape {
369    type Item = usize;
370    type IntoIter = alloc::vec::IntoIter<Self::Item>;
371
372    fn into_iter(self) -> Self::IntoIter {
373        self.dims.into_iter()
374    }
375}
376
377impl<Idx> Index<Idx> for Shape
378where
379    Idx: SliceIndex<[usize]>,
380{
381    type Output = Idx::Output;
382
383    fn index(&self, index: Idx) -> &Self::Output {
384        &self.dims[index]
385    }
386}
387
388impl<Idx> IndexMut<Idx> for Shape
389where
390    Idx: SliceIndex<[usize]>,
391{
392    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
393        &mut self.dims[index]
394    }
395}
396
397// Allow `&shape` to behave like a slice `&[usize]` directly
398impl Deref for Shape {
399    type Target = [usize];
400
401    fn deref(&self) -> &Self::Target {
402        &self.dims
403    }
404}
405
406// Allow `&shape` to behave like a mut slice `&mut [usize]` directly
407impl DerefMut for Shape {
408    fn deref_mut(&mut self) -> &mut Self::Target {
409        &mut self.dims
410    }
411}
412
413// Conversion sugar
414impl<const D: usize> From<[usize; D]> for Shape {
415    fn from(dims: [usize; D]) -> Self {
416        Shape::new(dims)
417    }
418}
419
420impl<const D: usize> From<[i64; D]> for Shape {
421    fn from(dims: [i64; D]) -> Self {
422        Shape {
423            dims: dims.into_iter().map(|d| d as usize).collect(),
424        }
425    }
426}
427
428impl<const D: usize> From<[i32; D]> for Shape {
429    fn from(dims: [i32; D]) -> Self {
430        Shape {
431            dims: dims.into_iter().map(|d| d as usize).collect(),
432        }
433    }
434}
435
436impl From<&[usize]> for Shape {
437    fn from(dims: &[usize]) -> Self {
438        Shape { dims: dims.into() }
439    }
440}
441
442impl From<Vec<i64>> for Shape {
443    fn from(shape: Vec<i64>) -> Self {
444        Self {
445            dims: shape.into_iter().map(|d| d as usize).collect(),
446        }
447    }
448}
449
450impl From<Vec<u64>> for Shape {
451    fn from(shape: Vec<u64>) -> Self {
452        Self {
453            dims: shape.into_iter().map(|d| d as usize).collect(),
454        }
455    }
456}
457
458impl From<Vec<usize>> for Shape {
459    fn from(shape: Vec<usize>) -> Self {
460        Self { dims: shape }
461    }
462}
463
464impl From<&Vec<usize>> for Shape {
465    fn from(shape: &Vec<usize>) -> Self {
466        Self {
467            dims: shape.clone(),
468        }
469    }
470}
471
472impl From<Shape> for Vec<usize> {
473    fn from(shape: Shape) -> Self {
474        shape.dims
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::s;
    use alloc::vec;

    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }

    #[test]
    fn num_elements_of_rank_zero_shape_is_one() {
        // The empty product is 1.
        assert_eq!(Shape::new([]).num_elements(), 1);
    }

    #[test]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_dims_array() {
        let shape = Shape::new([2, 3, 4]);
        assert_eq!(shape.dims::<3>(), [2, 3, 4]);
        // Requesting fewer dimensions than the rank returns the leading ones.
        assert_eq!(shape.dims::<2>(), [2, 3]);
    }

    #[allow(clippy::single_range_in_vec_init)]
    #[test]
    fn test_into_slices() {
        let slices = Shape::new([3]).into_slices(1..4);
        assert_eq!(slices[0].to_range(3), 1..3);

        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
        assert_eq!(slices[0].to_range(3), 1..3);
        assert_eq!(slices[1].to_range(4), 0..2);

        let slices = Shape::new([3]).into_slices(..-2);
        assert_eq!(slices[0].to_range(3), 0..1);

        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 1..2);

        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
        assert_eq!(slices[0].to_range(2), 0..2);
        assert_eq!(slices[1].to_range(3), 2..3);
    }

    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);

        // Works with ranges
        assert_eq!(shape[1..3], [3, 4]);
        assert_eq!(shape[1..=2], [3, 4]);
        assert_eq!(shape[..], [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);

        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));

        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        for d in shape.iter_mut() {
            *d += 1;
        }

        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // Deref coercion
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, [2, 3, 4, 5]);
    }

    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut)
    }

    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);

        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(&shape.dims, &[120]);
    }

    #[test]
    fn test_shape_insert_remove() {
        let dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let size = 6;
        shape.insert(1, size);

        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        let removed = shape.remove(1);
        assert_eq!(removed, size);
        assert_eq!(shape, Shape::new(dims));
    }

    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        let shape = shape.swap(1, 2).unwrap();

        assert_eq!(&shape.dims, &[2, 4, 3, 5]);

        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
        assert_eq!(shape, Shape::new(dims));
    }

    #[test]
    #[should_panic]
    fn test_shape_swap_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.swap(0, 4).unwrap();
    }

    #[test]
    fn test_shape_swap_dim_greater_than_rank() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(
            shape.swap(5, 0),
            Err(ShapeError::OutOfBounds { dim: 5, rank: 4 })
        );
    }

    #[test]
    #[should_panic]
    fn test_shape_permute_incomplete() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.permute(&[0, 2, 1]).unwrap();
    }

    #[test]
    fn test_shape_repeat() {
        let shape = Shape::new([2, 3, 4, 5]);

        let shape = shape.repeat(2, 3);
        assert_eq!(shape, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_from_conversions() {
        let expected = Shape::new([2, 3, 4]);
        let dims_vec = vec![2usize, 3, 4];

        assert_eq!(Shape::from([2usize, 3, 4]), expected);
        assert_eq!(Shape::from([2i64, 3, 4]), expected);
        assert_eq!(Shape::from([2i32, 3, 4]), expected);
        assert_eq!(Shape::from(&dims_vec[..]), expected);
        assert_eq!(Shape::from(&dims_vec), expected);
        assert_eq!(Shape::from(vec![2i64, 3, 4]), expected);
        assert_eq!(Shape::from(vec![2u64, 3, 4]), expected);
        assert_eq!(Shape::from(dims_vec), expected);
        assert_eq!(Vec::<usize>::from(expected), vec![2, 3, 4]);
    }

    #[test]
    fn test_shape_broadcast_binary() {
        let lhs = Shape::new([1, 1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_zero_sized_dim() {
        let lhs = Shape::new([0, 3]);
        let rhs = Shape::new([1, 3]);

        assert_eq!(lhs.broadcast(&rhs).unwrap(), Shape::new([0, 3]));
        assert_eq!(rhs.broadcast(&lhs).unwrap(), Shape::new([0, 3]));
    }

    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let lhs = Shape::new([1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 4]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let lhs = Shape::new([1, 2, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 2,
                right: 6,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([7, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([1, 6, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
    }

    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([4, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 7,
                right: 4,
                dim: 0
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        let out = Shape::broadcast_many(&[]);
        assert_eq!(out, Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_matmul_2d() {
        let lhs = Shape::new([2, 4]);
        let rhs = Shape::new([4, 2]);
        let out = Shape::matmul(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 2]));
    }

    #[test]
    fn test_shape_matmul_3d_broadcasted() {
        let lhs = Shape::new([5, 2, 4]);
        let rhs = Shape::new([1, 4, 3]);
        let out = Shape::matmul(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([5, 2, 3]));
    }

    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = Shape::matmul(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let lhs = Shape::new([3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = Shape::matmul(&lhs, &rhs);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_matmul_invalid_shape() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 3, 2]);
        let out = Shape::matmul(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: lhs,
                right: rhs
            })
        );
    }

    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 2, 4, 2]);
        let out = Shape::matmul(&lhs, &rhs);
        assert_eq!(
            out,
            Err(ShapeError::IncompatibleDims {
                left: 3,
                right: 2,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_cat() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let s3 = Shape::new([4, 3, 4, 5]);

        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
        assert_eq!(out, Shape::new([7, 3, 4, 5]));

        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 2, 5]);
        let s3 = Shape::new([2, 3, 1, 5]);

        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 7, 5]));
    }

    #[test]
    fn test_shape_cat_single() {
        let s1 = Shape::new([2, 3]);
        let out = Shape::cat(&[s1.clone()], 1).unwrap();
        assert_eq!(out, s1);
    }

    #[test]
    fn test_shape_cat_empty() {
        let out = Shape::cat(&[], 0);
        assert_eq!(out, Err(ShapeError::Empty));
    }

    #[test]
    fn test_shape_cat_dim_out_of_bounds() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5]);
        let out = Shape::cat(&[s1, s2], 4);
        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
    }

    #[test]
    fn test_shape_cat_rank_mismatch() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([2, 3, 4, 5, 6]);
        let out = Shape::cat(&[s1, s2], 0);
        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
    }

    #[test]
    fn test_shape_cat_incompatible_shapes() {
        let s1 = Shape::new([2, 3, 4, 5]);
        let s2 = Shape::new([1, 3, 4, 5]);
        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);

        assert_eq!(
            out,
            Err(ShapeError::IncompatibleShapes {
                left: s1,
                right: s2
            })
        );
    }

    #[test]
    fn test_shape_slice_output_shape_basic() {
        // Test basic slicing with step=1
        let slices = [
            Slice::new(0, Some(5), 1), // 5 elements
            Slice::new(2, Some(8), 1), // 6 elements
        ];
        let original_shape = Shape::new([10, 10, 10]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 6, 10]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_positive_steps() {
        // Test slicing with various positive steps
        let slices = [
            Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
            Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements
            Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements
        ];
        let original_shape = Shape::new([20, 20, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 3, 2, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_with_negative_steps() {
        // Test slicing with negative steps (backward iteration)
        let slices = [
            Slice::new(0, Some(10), -1), // 10 elements traversed backward
            Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 3, 20]));
    }

    #[test]
    fn test_shape_slice_output_shape_mixed_steps() {
        // Test with a mix of positive, negative, and unit steps
        let slices = [
            Slice::from_range_stepped(1..6, 1),   // 5 elements
            Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
            Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([5, 4, 3]));
    }

    #[test]
    fn test_shape_slice_output_shape_partial_dims() {
        // Test when slices has fewer dimensions than original shape
        let slices = [
            Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
        ];
        let original_shape = Shape::new([10, 20, 30, 40]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 20, 30, 40]));
    }

    #[test]
    fn test_shape_slice_output_shape_edge_cases() {
        // Test edge cases with small ranges and large steps
        let slices = [
            Slice::from_range_stepped(0..1, 1),    // Single element
            Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
            Slice::from_range_stepped(5..5, 1),    // Empty range -> 0 elements
        ];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([1, 1, 0]));
    }

    #[test]
    fn test_shape_slice_output_shape_empty() {
        // Test with no slice infos (should return original shape)
        let slices = [];
        let original_shape = Shape::new([10, 20, 30]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([10, 20, 30]));
    }

    #[test]
    fn test_shape_slice_output_shape_uneven_division() {
        // Test cases where range size doesn't divide evenly by step
        let slices = [
            Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
            Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
            Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
        ];
        let original_shape = Shape::new([20, 20, 20]);
        let result = original_shape.slice(&slices).unwrap();
        assert_eq!(result, Shape::new([3, 3, 2]));
    }
}