burn_std/tensor/
shape.rs

1//! Tensor shape definition.
2
3use alloc::vec::Vec;
4use core::{
5    ops::{Deref, DerefMut, Index, IndexMut, Range},
6    slice::{Iter, IterMut, SliceIndex},
7};
8use serde::{Deserialize, Serialize};
9
10use super::indexing::ravel_index;
11use super::{AsIndex, Slice, SliceArg};
12
13/// Shape of a tensor.
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
15pub struct Shape {
16    /// The dimensions of the tensor.
17    pub dims: Vec<usize>,
18}
19
20#[allow(missing_docs)]
21#[derive(Debug, Clone, PartialEq, Eq)]
22/// Error that can occur when attempting to modify shapes.
23pub enum ShapeError {
24    /// The operands have different ranks.
25    RankMismatch { left: usize, right: usize },
26    /// A pair of dimensions are incompatible for broadcasting.
27    IncompatibleDims {
28        left: usize,
29        right: usize,
30        dim: usize,
31    },
32    /// Invalid dimension specified for the rank.
33    OutOfBounds { dim: usize, rank: usize },
34    /// A pair of shapes are incompatible for the operation.
35    IncompatibleShapes { left: Shape, right: Shape },
36    /// Invalid empty shape.
37    Empty,
38}
39
40impl Shape {
41    /// Constructs a new `Shape`.
42    pub fn new<const D: usize>(dims: [usize; D]) -> Self {
43        // For backward compat
44        Self {
45            dims: dims.to_vec(),
46        }
47    }
48
49    /// Returns the total number of elements of a tensor having this shape
50    pub fn num_elements(&self) -> usize {
51        self.dims.iter().product()
52    }
53
54    /// Returns the number of dimensions.
55    ///
56    /// Alias for `Shape::rank()`.
57    pub fn num_dims(&self) -> usize {
58        self.dims.len()
59    }
60
61    /// Returns the rank (the number of dimensions).
62    ///
63    /// Alias for `Shape::num_dims()`.
64    pub fn rank(&self) -> usize {
65        self.num_dims()
66    }
67
68    // For compat with dims: [usize; D]
69    /// Returns the dimensions of the tensor as an array.
70    pub fn dims<const D: usize>(&self) -> [usize; D] {
71        let mut dims = [1; D];
72        dims[..D].copy_from_slice(&self.dims[..D]);
73        dims
74    }
75
76    /// Change the shape to one dimensional with the same number of elements.
77    pub fn flatten(mut self) -> Self {
78        self.dims = [self.num_elements()].into();
79        self
80    }
81
82    /// Compute the ravel index for the given coordinates.
83    ///
84    /// This returns the row-major order raveling:
85    /// * `strides[-1] = 1`
86    /// * `strides[i] = strides[i+1] * dims[i+1]`
87    /// * `dim_strides = coords * strides`
88    /// * `ravel = sum(dim_strides)`
89    ///
90    /// # Arguments
91    /// - `indices`: the index for each dimension; must be the same length as `shape`.
92    ///
93    /// # Returns
94    /// - the ravel offset index.
95    pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
96        ravel_index(indices, &self.dims)
97    }
98
99    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
100    pub fn into_ranges(self) -> Vec<Range<usize>> {
101        self.into_iter().map(|d| 0..d).collect()
102    }
103
104    /// Converts slice arguments into an array of slice specifications for the shape.
105    ///
106    /// This method returns an array of `Slice` objects that can be used for slicing operations.
107    /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
108    /// allows custom slice specifications instead of full ranges.
109    /// For creating complex slice specifications, use the [`s!`] macro.
110    ///
111    /// # Arguments
112    ///
113    /// * `slices` - An array of slice specifications, where each element can be:
114    ///   - A range (e.g., `2..5`)
115    ///   - An index
116    ///   - A `Slice` object
117    ///   - The output of the [`s!`] macro for advanced slicing
118    ///
119    /// # Behavior
120    ///
121    /// - Supports partial and full slicing in any number of dimensions.
122    /// - Missing ranges are treated as full slices if D > D2.
123    /// - Handles negative indices by wrapping around from the end of the dimension.
124    /// - Clamps ranges to the shape's dimensions if they exceed the bounds.
125    ///
126    /// # Returns
127    ///
128    /// An array of `Slice` objects corresponding to the provided slice specifications,
129    /// clamped to the shape's actual dimensions.
130    ///
131    /// # Examples
132    ///
133    /// ```rust
134    /// use burn_std::{Shape, Slice, s};
135    ///
136    /// fn example() {
137    ///     // 1D slicing
138    ///     let slices = Shape::new([4]).into_slices(1..4);
139    ///     assert_eq!(slices[0].to_range(4), 1..3);
140    ///
141    ///     // 2D slicing
142    ///     let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
143    ///     assert_eq!(slices[0].to_range(3), 1..3);
144    ///     assert_eq!(slices[1].to_range(4), 0..2);
145    ///
146    ///     // Using negative indices
147    ///     let slices = Shape::new([3]).into_slices(..-2);
148    ///     assert_eq!(slices[0].to_range(3), 0..1);
149    ///
150    ///     // Using the slice macro to select different ranges
151    ///     let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
152    ///     assert_eq!(slices[0].to_range(2), 0..2);
153    ///     assert_eq!(slices[1].to_range(3), 1..2);
154    /// }
155    /// ```
156    ///
157    /// # See Also
158    ///
159    /// - [`s!`] - The recommended macro for creating slice specifications
160    /// - [`Shape::into_ranges`] - Convert to full covering ranges
161    ///
162    /// [`s!`]: crate::s!
163    pub fn into_slices<const D: usize, S>(self, slices: S) -> [Slice; D]
164    where
165        S: SliceArg<D>,
166    {
167        slices.into_slices(self)
168    }
169
170    /// Construct a vector of the dims.
171    pub fn to_vec(&self) -> Vec<usize> {
172        self.dims.clone()
173    }
174
175    /// Returns an iterator over the shape dimensions.
176    pub fn iter(&self) -> Iter<'_, usize> {
177        self.dims.iter()
178    }
179
180    /// Mutable iterator over the dimensions.
181    pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
182        self.dims.iter_mut()
183    }
184
185    /// Borrow the underlying dimensions slice.
186    pub fn as_slice(&self) -> &[usize] {
187        &self.dims
188    }
189
190    /// Borrow the underlying dimensions slice mutably.
191    pub fn as_mut_slice(&mut self) -> &mut [usize] {
192        &mut self.dims
193    }
194
195    /// Insert a dimension of `size` at position `index`.
196    pub fn insert(&mut self, index: usize, size: usize) {
197        self.dims.insert(index, size);
198    }
199
200    /// Remove and return the dimension at position `index` from the shape.
201    pub fn remove(&mut self, index: usize) -> usize {
202        self.dims.remove(index)
203    }
204
205    /// Appends a dimension of `size` to the back of the shape.
206    pub fn push(&mut self, size: usize) {
207        self.dims.push(size)
208    }
209
210    /// Extend the shape with the content of another shape or iterator.
211    pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
212        self.dims.extend(iter)
213    }
214
215    /// Swap two dimensions in the shape.
216    pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
217        if dim1 > self.rank() {
218            return Err(ShapeError::OutOfBounds {
219                dim: dim1,
220                rank: self.rank(),
221            });
222        }
223        if dim2 > self.rank() {
224            return Err(ShapeError::OutOfBounds {
225                dim: dim2,
226                rank: self.rank(),
227            });
228        }
229        self.dims.swap(dim1, dim2);
230        Ok(self)
231    }
232
233    /// Reorder the shape dimensions according to the permutation of `axes`.
234    pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
235        if axes.len() != self.rank() {
236            return Err(ShapeError::RankMismatch {
237                left: self.rank(),
238                right: axes.len(),
239            });
240        }
241        debug_assert!(axes.iter().all(|i| i < &self.rank()));
242
243        self.dims = axes.iter().map(|&i| self.dims[i]).collect();
244        Ok(self)
245    }
246
247    /// Repeated the specified `dim` a number of `times`.
248    pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> {
249        if dim >= self.rank() {
250            return Err(ShapeError::OutOfBounds {
251                dim,
252                rank: self.rank(),
253            });
254        }
255
256        self.dims[dim] *= times;
257        Ok(self)
258    }
259
260    /// Returns a new shape where the specified `dim` is reduced to size 1.
261    pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> {
262        if dim >= self.rank() {
263            return Err(ShapeError::OutOfBounds {
264                dim,
265                rank: self.rank(),
266            });
267        }
268
269        self.dims[dim] = 1;
270        Ok(self)
271    }
272
273    /// Concatenates all shapes into a new one along the given dimension.
274    pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
275    where
276        I: IntoIterator<Item = &'a Shape>,
277    {
278        let mut iter = shapes.into_iter();
279
280        let first = iter.next().ok_or(ShapeError::Empty)?;
281
282        if dim >= first.rank() {
283            return Err(ShapeError::OutOfBounds {
284                dim,
285                rank: first.rank(),
286            });
287        }
288
289        let mut shape = first.clone();
290
291        for s in iter {
292            if s.rank() != shape.rank() {
293                return Err(ShapeError::RankMismatch {
294                    left: shape.rank(),
295                    right: s.rank(),
296                });
297            }
298
299            if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
300                return Err(ShapeError::IncompatibleShapes {
301                    left: shape.clone(),
302                    right: s.clone(),
303                });
304            }
305
306            shape[dim] += s[dim];
307        }
308
309        Ok(shape)
310    }
311
312    /// Compute the output shape from the given slices.
313    pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
314        if slices.len() > self.rank() {
315            return Err(ShapeError::RankMismatch {
316                left: self.rank(),
317                right: slices.len(),
318            });
319        }
320
321        slices
322            .iter()
323            .zip(self.iter_mut())
324            .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
325
326        Ok(self)
327    }
328
329    /// Compute the output shape for binary operations with broadcasting support.
330    ///
331    /// - Shapes must be of the same rank (missing dimensions are not handled automatically).
332    /// - Two dimensions are compatible if they are equal, or one of them is 1.
333    ///
334    /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]`
335    /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]`
336    /// can *not* be broadcast into `[2, 4]`.
337    pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
338        Self::broadcast_many([self, other])
339    }
340
341    /// Compute the broadcasted output shape across multiple input shapes.
342    ///
343    /// See also [broadcast](Self::broadcast).
344    pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
345    where
346        I: IntoIterator<Item = &'a Shape>,
347    {
348        let mut iter = shapes.into_iter();
349        let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone();
350        let rank = broadcasted.rank();
351
352        for shape in iter {
353            if shape.rank() != rank {
354                return Err(ShapeError::RankMismatch {
355                    left: rank,
356                    right: shape.rank(),
357                });
358            }
359
360            for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
361                match (*d_lhs, d_rhs) {
362                    (a, b) if a == b => {} // same
363                    (1, b) => *d_lhs = b,  // broadcast to rhs
364                    (_a, 1) => {}          // keep existing dimension
365                    _ => {
366                        return Err(ShapeError::IncompatibleDims {
367                            left: *d_lhs,
368                            right: d_rhs,
369                            dim,
370                        });
371                    }
372                }
373            }
374        }
375
376        Ok(broadcasted)
377    }
378
379    /// Expand this shape to match the target shape, following broadcasting rules.
380    pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> {
381        let target_rank = target.rank();
382        if self.rank() > target_rank {
383            return Err(ShapeError::RankMismatch {
384                left: self.rank(),
385                right: target_rank,
386            });
387        }
388
389        for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
390            if dim_self != dim_target && *dim_self != 1 {
391                return Err(ShapeError::IncompatibleDims {
392                    left: *dim_self,
393                    right: *dim_target,
394                    dim: target_rank - i - 1,
395                });
396            }
397        }
398
399        Ok(target)
400    }
401
402    /// Reshape this shape to the target shape.
403    pub fn reshape(&self, target: Shape) -> Result<Shape, ShapeError> {
404        if self.num_elements() != target.num_elements() {
405            return Err(ShapeError::IncompatibleShapes {
406                left: self.clone(),
407                right: target,
408            });
409        }
410        Ok(target)
411    }
412}
413
414/// Compute the output shape for matrix multiplication with broadcasting support.
415///
416/// The last two dimensions are treated as matrices, while preceding dimensions
417/// follow broadcast semantics similar to elementwise operations.
418pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> {
419    let rank = lhs.rank();
420    if rank != rhs.rank() {
421        return Err(ShapeError::RankMismatch {
422            left: rank,
423            right: rhs.rank(),
424        });
425    }
426
427    if lhs[rank - 1] != rhs[rank - 2] {
428        return Err(ShapeError::IncompatibleShapes {
429            left: lhs.clone(),
430            right: rhs.clone(),
431        });
432    }
433
434    let mut shape = if rank > 2 {
435        // Broadcast leading dims
436        Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
437    } else {
438        Shape::new([])
439    };
440    shape.extend([lhs[rank - 2], rhs[rank - 1]]);
441
442    Ok(shape)
443}
444
445impl IntoIterator for Shape {
446    type Item = usize;
447    type IntoIter = alloc::vec::IntoIter<Self::Item>;
448
449    fn into_iter(self) -> Self::IntoIter {
450        self.dims.into_iter()
451    }
452}
453
454impl<Idx> Index<Idx> for Shape
455where
456    Idx: SliceIndex<[usize]>,
457{
458    type Output = Idx::Output;
459
460    fn index(&self, index: Idx) -> &Self::Output {
461        &self.dims[index]
462    }
463}
464
465impl<Idx> IndexMut<Idx> for Shape
466where
467    Idx: SliceIndex<[usize]>,
468{
469    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
470        &mut self.dims[index]
471    }
472}
473
474// Allow `&shape` to behave like a slice `&[usize]` directly
475impl Deref for Shape {
476    type Target = [usize];
477
478    fn deref(&self) -> &Self::Target {
479        &self.dims
480    }
481}
482
483// Allow `&shape` to behave like a mut slice `&mut [usize]` directly
484impl DerefMut for Shape {
485    fn deref_mut(&mut self) -> &mut Self::Target {
486        &mut self.dims
487    }
488}
489
490// Conversion sugar
491impl<const D: usize> From<[usize; D]> for Shape {
492    fn from(dims: [usize; D]) -> Self {
493        Shape::new(dims)
494    }
495}
496
497impl<const D: usize> From<[i64; D]> for Shape {
498    fn from(dims: [i64; D]) -> Self {
499        Shape {
500            dims: dims.into_iter().map(|d| d as usize).collect(),
501        }
502    }
503}
504
505impl<const D: usize> From<[i32; D]> for Shape {
506    fn from(dims: [i32; D]) -> Self {
507        Shape {
508            dims: dims.into_iter().map(|d| d as usize).collect(),
509        }
510    }
511}
512
513impl From<&[usize]> for Shape {
514    fn from(dims: &[usize]) -> Self {
515        Shape { dims: dims.into() }
516    }
517}
518
519impl From<Vec<i64>> for Shape {
520    fn from(shape: Vec<i64>) -> Self {
521        Self {
522            dims: shape.into_iter().map(|d| d as usize).collect(),
523        }
524    }
525}
526
527impl From<Vec<u64>> for Shape {
528    fn from(shape: Vec<u64>) -> Self {
529        Self {
530            dims: shape.into_iter().map(|d| d as usize).collect(),
531        }
532    }
533}
534
535impl From<Vec<usize>> for Shape {
536    fn from(shape: Vec<usize>) -> Self {
537        Self { dims: shape }
538    }
539}
540
541impl From<&Vec<usize>> for Shape {
542    fn from(shape: &Vec<usize>) -> Self {
543        Self {
544            dims: shape.clone(),
545        }
546    }
547}
548
549impl From<Shape> for Vec<usize> {
550    fn from(shape: Shape) -> Self {
551        shape.dims
552    }
553}
554
555#[cfg(test)]
556#[allow(clippy::identity_op, reason = "useful for clarity")]
557mod tests {
558    use super::*;
559    use crate::s;
560    use alloc::vec;
561
562    #[test]
563    fn num_dims_and_rank() {
564        let dims = [2, 3, 4, 5];
565        let shape = Shape::new(dims);
566        assert_eq!(4, shape.num_dims());
567        assert_eq!(4, shape.rank());
568    }
569
570    #[test]
571    fn num_elements() {
572        let dims = [2, 3, 4, 5];
573        let shape = Shape::new(dims);
574        assert_eq!(120, shape.num_elements());
575    }
576
577    #[test]
578    fn test_shape_into_iter() {
579        let dims = [2, 3, 4, 5];
580        let shape = Shape::new(dims);
581
582        assert_eq!(shape.into_iter().sum::<usize>(), 14);
583    }
584
585    #[test]
586    fn test_into_ranges() {
587        let dims = [2, 3, 4, 5];
588        let shape = Shape::new(dims);
589        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
590    }
591
592    #[test]
593    fn test_to_vec() {
594        let dims = [2, 3, 4, 5];
595        let shape = Shape::new(dims);
596        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
597    }
598
599    #[allow(clippy::single_range_in_vec_init)]
600    #[test]
601    fn test_into_slices() {
602        let slices = Shape::new([3]).into_slices(1..4);
603        assert_eq!(slices[0].to_range(3), 1..3);
604
605        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
606        assert_eq!(slices[0].to_range(3), 1..3);
607        assert_eq!(slices[1].to_range(4), 0..2);
608
609        let slices = Shape::new([3]).into_slices(..-2);
610        assert_eq!(slices[0].to_range(3), 0..1);
611
612        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
613        assert_eq!(slices[0].to_range(2), 0..2);
614        assert_eq!(slices[1].to_range(3), 1..2);
615
616        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
617        assert_eq!(slices[0].to_range(2), 0..2);
618        assert_eq!(slices[1].to_range(3), 2..3);
619    }
620
621    #[test]
622    fn test_shape_index() {
623        let shape = Shape::new([2, 3, 4, 5]);
624
625        assert_eq!(shape[0], 2);
626        assert_eq!(shape[1], 3);
627        assert_eq!(shape[2], 4);
628        assert_eq!(shape[3], 5);
629
630        // Works with ranges
631        assert_eq!(shape[1..3], *&[3, 4]);
632        assert_eq!(shape[1..=2], *&[3, 4]);
633        assert_eq!(shape[..], *&[2, 3, 4, 5]);
634    }
635
636    #[test]
637    fn test_shape_slice_methods() {
638        let shape = Shape::new([2, 3, 4, 5]);
639
640        let dim = shape.first();
641        assert_eq!(dim, Some(&2));
642        let dim = shape.last();
643        assert_eq!(dim, Some(&5));
644
645        assert!(!shape.is_empty());
646        let shape = Shape::new([]);
647        assert!(shape.is_empty());
648    }
649
650    #[test]
651    fn test_shape_iter() {
652        let dims = [2, 3, 4, 5];
653        let shape = Shape::new(dims);
654
655        for (d, sd) in dims.iter().zip(shape.iter()) {
656            assert_eq!(d, sd);
657        }
658    }
659
660    #[test]
661    fn test_shape_iter_mut() {
662        let mut shape = Shape::new([2, 3, 4, 5]);
663
664        for d in shape.iter_mut() {
665            *d += 1;
666        }
667
668        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
669    }
670
671    #[test]
672    fn test_shape_as_slice() {
673        let dims = [2, 3, 4, 5];
674        let shape = Shape::new(dims);
675
676        assert_eq!(shape.as_slice(), dims.as_slice());
677
678        // Deref coercion
679        let shape_slice: &[usize] = &shape;
680        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
681    }
682
683    #[test]
684    fn test_shape_as_mut_slice() {
685        let mut dims = [2, 3, 4, 5];
686        let mut shape = Shape::new(dims);
687
688        let shape_mut = shape.as_mut_slice();
689        assert_eq!(shape_mut, dims.as_mut_slice());
690        shape_mut[1] = 6;
691
692        assert_eq!(shape_mut, &[2, 6, 4, 5]);
693
694        let mut shape = Shape::new(dims);
695        let shape = &mut shape[..];
696        shape[1] = 6;
697
698        assert_eq!(shape, shape_mut)
699    }
700
701    #[test]
702    fn test_shape_flatten() {
703        let shape = Shape::new([2, 3, 4, 5]);
704        assert_eq!(shape.num_elements(), 120);
705
706        let shape = shape.flatten();
707        assert_eq!(shape.num_elements(), 120);
708        assert_eq!(&shape.dims, &[120]);
709    }
710
711    #[test]
712    fn test_ravel() {
713        let shape = Shape::new([2, 3, 4, 5]);
714
715        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
716        assert_eq!(
717            shape.ravel_index(&[1, 2, 3, 4]),
718            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
719        );
720    }
721
722    #[test]
723    fn test_shape_insert_remove_push() {
724        let dims = [2, 3, 4, 5];
725        let mut shape = Shape::new(dims);
726        let size = 6;
727        shape.insert(1, size);
728
729        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));
730
731        let removed = shape.remove(1);
732        assert_eq!(removed, size);
733        assert_eq!(shape, Shape::new(dims));
734
735        shape.push(6);
736        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
737    }
738
739    #[test]
740    fn test_shape_swap_permute() {
741        let dims = [2, 3, 4, 5];
742        let shape = Shape::new(dims);
743        let shape = shape.swap(1, 2).unwrap();
744
745        assert_eq!(&shape.dims, &[2, 4, 3, 5]);
746
747        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
748        assert_eq!(shape, Shape::new(dims));
749    }
750
751    #[test]
752    #[should_panic]
753    fn test_shape_swap_out_of_bounds() {
754        let shape = Shape::new([2, 3, 4, 5]);
755
756        shape.swap(0, 4).unwrap();
757    }
758
759    #[test]
760    #[should_panic]
761    fn test_shape_permute_incomplete() {
762        let shape = Shape::new([2, 3, 4, 5]);
763
764        shape.permute(&[0, 2, 1]).unwrap();
765    }
766
767    #[test]
768    fn test_shape_repeat() {
769        let shape = Shape::new([2, 3, 4, 5]);
770
771        let out = shape.repeat(2, 3).unwrap();
772        assert_eq!(out, Shape::new([2, 3, 12, 5]));
773    }
774
775    #[test]
776    fn test_shape_repeat_invalid() {
777        let shape = Shape::new([2, 3, 4, 5]);
778
779        let out = shape.repeat(5, 3);
780        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
781    }
782
783    #[test]
784    fn test_shape_reduce() {
785        let shape = Shape::new([2, 3, 4, 5]);
786
787        let out = shape.reduce(2).unwrap();
788        assert_eq!(out, Shape::new([2, 3, 1, 5]));
789    }
790
791    #[test]
792    fn test_shape_reduce_invalid() {
793        let shape = Shape::new([2, 3, 4, 5]);
794
795        let out = shape.reduce(5);
796        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
797    }
798
799    #[test]
800    fn test_shape_broadcast_binary() {
801        let lhs = Shape::new([1, 1, 2, 4]);
802        let rhs = Shape::new([7, 6, 2, 1]);
803
804        let out = lhs.broadcast(&rhs).unwrap();
805        assert_eq!(out, Shape::new([7, 6, 2, 4]));
806    }
807
808    #[test]
809    fn test_shape_broadcast_rank_mismatch() {
810        let lhs = Shape::new([1, 2, 4]);
811        let rhs = Shape::new([7, 6, 2, 4]);
812
813        let out = lhs.broadcast(&rhs);
814        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
815    }
816
817    #[test]
818    fn test_shape_broadcast_incompatible_dims() {
819        let lhs = Shape::new([1, 2, 2, 4]);
820        let rhs = Shape::new([7, 6, 2, 1]);
821
822        let out = lhs.broadcast(&rhs);
823        assert_eq!(
824            out,
825            Err(ShapeError::IncompatibleDims {
826                left: 2,
827                right: 6,
828                dim: 1
829            })
830        );
831    }
832
833    #[test]
834    fn test_shape_broadcast_many() {
835        let s1 = Shape::new([1, 1, 2, 4]);
836        let s2 = Shape::new([7, 1, 2, 1]);
837        let s3 = Shape::new([7, 6, 1, 1]);
838
839        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
840        assert_eq!(out, Shape::new([7, 6, 2, 4]));
841    }
842
843    #[test]
844    fn test_shape_broadcast_many_rank_mismatch() {
845        let s1 = Shape::new([1, 1, 2, 4]);
846        let s2 = Shape::new([7, 1, 2, 1]);
847        let s3 = Shape::new([1, 6, 1]);
848
849        let out = Shape::broadcast_many([&s1, &s2, &s3]);
850        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
851    }
852
853    #[test]
854    fn test_shape_broadcast_many_incompatible_dims() {
855        let s1 = Shape::new([1, 1, 2, 4]);
856        let s2 = Shape::new([7, 1, 2, 1]);
857        let s3 = Shape::new([4, 6, 1, 1]);
858
859        let out = Shape::broadcast_many([&s1, &s2, &s3]);
860        assert_eq!(
861            out,
862            Err(ShapeError::IncompatibleDims {
863                left: 7,
864                right: 4,
865                dim: 0
866            })
867        );
868    }
869
870    #[test]
871    fn test_shape_broadcast_many_empty() {
872        let out = Shape::broadcast_many(&[]);
873        assert_eq!(out, Err(ShapeError::Empty));
874    }
875
876    #[test]
877    fn test_shape_matmul_2d() {
878        let lhs = Shape::new([2, 4]);
879        let rhs = Shape::new([4, 2]);
880        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
881        assert_eq!(out, Shape::new([2, 2]));
882    }
883
884    #[test]
885    fn test_shape_matmul_4d_broadcasted() {
886        let lhs = Shape::new([1, 3, 2, 4]);
887        let rhs = Shape::new([2, 1, 4, 2]);
888        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
889        assert_eq!(out, Shape::new([2, 3, 2, 2]));
890    }
891
892    #[test]
893    fn test_shape_matmul_invalid_rank() {
894        let lhs = Shape::new([3, 2, 4]);
895        let rhs = Shape::new([2, 1, 4, 2]);
896        let out = calculate_matmul_output(&lhs, &rhs);
897        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
898    }
899
900    #[test]
901    fn test_shape_matmul_invalid_shape() {
902        let lhs = Shape::new([1, 3, 2, 4]);
903        let rhs = Shape::new([2, 1, 3, 2]);
904        let out = calculate_matmul_output(&lhs, &rhs);
905        assert_eq!(
906            out,
907            Err(ShapeError::IncompatibleShapes {
908                left: lhs,
909                right: rhs
910            })
911        );
912    }
913
914    #[test]
915    fn test_shape_matmul_invalid_broadcast() {
916        let lhs = Shape::new([1, 3, 2, 4]);
917        let rhs = Shape::new([2, 2, 4, 2]);
918        let out = calculate_matmul_output(&lhs, &rhs);
919        assert_eq!(
920            out,
921            Err(ShapeError::IncompatibleDims {
922                left: 3,
923                right: 2,
924                dim: 1
925            })
926        );
927    }
928
929    #[test]
930    fn test_shape_cat() {
931        let s1 = Shape::new([2, 3, 4, 5]);
932        let s2 = Shape::new([1, 3, 4, 5]);
933        let s3 = Shape::new([4, 3, 4, 5]);
934
935        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
936        assert_eq!(out, Shape::new([7, 3, 4, 5]));
937
938        let s1 = Shape::new([2, 3, 4, 5]);
939        let s2 = Shape::new([2, 3, 2, 5]);
940        let s3 = Shape::new([2, 3, 1, 5]);
941
942        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
943        assert_eq!(out, Shape::new([2, 3, 7, 5]));
944    }
945
946    #[test]
947    fn test_shape_cat_empty() {
948        let out = Shape::cat(&[], 0);
949        assert_eq!(out, Err(ShapeError::Empty));
950    }
951
952    #[test]
953    fn test_shape_cat_dim_out_of_bounds() {
954        let s1 = Shape::new([2, 3, 4, 5]);
955        let s2 = Shape::new([2, 3, 4, 5]);
956        let out = Shape::cat(&[s1, s2], 4);
957        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
958    }
959
960    #[test]
961    fn test_shape_cat_rank_mismatch() {
962        let s1 = Shape::new([2, 3, 4, 5]);
963        let s2 = Shape::new([2, 3, 4, 5, 6]);
964        let out = Shape::cat(&[s1, s2], 0);
965        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
966    }
967
968    #[test]
969    fn test_shape_cat_incompatible_shapes() {
970        let s1 = Shape::new([2, 3, 4, 5]);
971        let s2 = Shape::new([1, 3, 4, 5]);
972        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);
973
974        assert_eq!(
975            out,
976            Err(ShapeError::IncompatibleShapes {
977                left: s1,
978                right: s2
979            })
980        );
981    }
982
983    #[test]
984    fn test_shape_slice_output_shape_basic() {
985        // Test basic slicing with step=1
986        let slices = [
987            Slice::new(0, Some(5), 1), // 5 elements
988            Slice::new(2, Some(8), 1), // 6 elements
989        ];
990        let original_shape = Shape::new([10, 10, 10]);
991        let result = original_shape.slice(&slices).unwrap();
992        assert_eq!(result, Shape::new([5, 6, 10]));
993    }
994
995    #[test]
996    fn test_shape_slice_output_shape_with_positive_steps() {
997        // Test slicing with various positive steps
998        let slices = [
999            Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
1000            Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements
1001            Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements
1002        ];
1003        let original_shape = Shape::new([20, 20, 20, 30]);
1004        let result = original_shape.slice(&slices).unwrap();
1005        assert_eq!(result, Shape::new([5, 3, 2, 30]));
1006    }
1007
1008    #[test]
1009    fn test_shape_slice_output_shape_with_negative_steps() {
1010        // Test slicing with negative steps (backward iteration)
1011        let slices = [
1012            Slice::new(0, Some(10), -1), // 10 elements traversed backward
1013            Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements
1014        ];
1015        let original_shape = Shape::new([20, 20, 20]);
1016        let result = original_shape.slice(&slices).unwrap();
1017        assert_eq!(result, Shape::new([10, 3, 20]));
1018    }
1019
1020    #[test]
1021    fn test_shape_slice_output_shape_mixed_steps() {
1022        // Test with a mix of positive, negative, and unit steps
1023        let slices = [
1024            Slice::from_range_stepped(1..6, 1),   // 5 elements
1025            Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
1026            Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements
1027        ];
1028        let original_shape = Shape::new([20, 20, 20]);
1029        let result = original_shape.slice(&slices).unwrap();
1030        assert_eq!(result, Shape::new([5, 4, 3]));
1031    }
1032
1033    #[test]
1034    fn test_shape_slice_output_shape_partial_dims() {
1035        // Test when slices has fewer dimensions than original shape
1036        let slices = [
1037            Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
1038        ];
1039        let original_shape = Shape::new([10, 20, 30, 40]);
1040        let result = original_shape.slice(&slices).unwrap();
1041        assert_eq!(result, Shape::new([3, 20, 30, 40]));
1042    }
1043
1044    #[test]
1045    fn test_shape_slice_output_shape_edge_cases() {
1046        // Test edge cases with small ranges and large steps
1047        let slices = [
1048            Slice::from_range_stepped(0..1, 1),    // Single element
1049            Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
1050            Slice::from_range_stepped(5..5, 1),    // Empty range -> 0 elements
1051        ];
1052        let original_shape = Shape::new([10, 20, 30]);
1053        let result = original_shape.slice(&slices).unwrap();
1054        assert_eq!(result, Shape::new([1, 1, 0]));
1055    }
1056
1057    #[test]
1058    fn test_shape_slice_output_shape_empty() {
1059        // Test with no slice infos (should return original shape)
1060        let slices = [];
1061        let original_shape = Shape::new([10, 20, 30]);
1062        let result = original_shape.slice(&slices).unwrap();
1063        assert_eq!(result, Shape::new([10, 20, 30]));
1064    }
1065
1066    #[test]
1067    fn test_shape_slice_output_shape_uneven_division() {
1068        // Test cases where range size doesn't divide evenly by step
1069        let slices = [
1070            Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
1071            Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
1072            Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
1073        ];
1074        let original_shape = Shape::new([20, 20, 20]);
1075        let result = original_shape.slice(&slices).unwrap();
1076        assert_eq!(result, Shape::new([3, 3, 2]));
1077    }
1078
1079    #[test]
1080    fn test_shape_expand() {
1081        let shape = Shape::new([1, 3, 1]);
1082        let expanded = Shape::new([2, 3, 4]);
1083        let out = shape.expand(expanded.clone()).unwrap();
1084        assert_eq!(out, expanded);
1085    }
1086
1087    #[test]
1088    fn test_shape_expand_higher_rank() {
1089        let shape = Shape::new([1, 4]);
1090        let expanded = Shape::new([2, 3, 4]);
1091        let out = shape.expand(expanded.clone()).unwrap();
1092        assert_eq!(out, expanded);
1093    }
1094
1095    #[test]
1096    fn test_shape_expand_invalid_rank() {
1097        let shape = Shape::new([1, 3, 1]);
1098        let expanded = Shape::new([3, 4]);
1099        let out = shape.expand(expanded);
1100        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 2 }));
1101    }
1102
1103    #[test]
1104    fn test_shape_expand_incompatible_dims() {
1105        let shape = Shape::new([1, 3, 2]);
1106        let expanded = Shape::new([2, 3, 4]);
1107        let out = shape.expand(expanded);
1108        assert_eq!(
1109            out,
1110            Err(ShapeError::IncompatibleDims {
1111                left: 2,
1112                right: 4,
1113                dim: 2
1114            })
1115        );
1116    }
1117
1118    #[test]
1119    fn test_shape_reshape() {
1120        let shape = Shape::new([2, 3, 4, 5]);
1121        let reshaped = Shape::new([1, 2, 12, 5]);
1122        let out = shape.reshape(reshaped.clone()).unwrap();
1123        assert_eq!(out, reshaped);
1124    }
1125
1126    #[test]
1127    fn test_shape_reshape_invalid() {
1128        let shape = Shape::new([2, 3, 4, 5]);
1129        let reshaped = Shape::new([2, 2, 12, 5]);
1130        let out = shape.clone().reshape(reshaped.clone());
1131        assert_eq!(
1132            out,
1133            Err(ShapeError::IncompatibleShapes {
1134                left: shape,
1135                right: reshaped
1136            })
1137        );
1138    }
1139}