burn_std/tensor/
shape.rs

1//! Tensor shape definition.
2
3use super::indexing::ravel_index;
4use super::{AsIndex, Slice, SliceArg};
5use alloc::format;
6use alloc::string::{String, ToString};
7use alloc::vec;
8use alloc::vec::Vec;
9use core::fmt::{Debug, Display, Formatter};
10use core::str::FromStr;
11use core::{
12    ops::{Deref, DerefMut, Index, IndexMut, Range},
13    slice::{Iter, IterMut, SliceIndex},
14};
15use serde::{Deserialize, Serialize};
16
17pub use crate::errors::ExpressionError;
18pub use crate::tensor::index_conversion::AsSize;
19
/// Shape of a tensor.
///
/// Stores the size of every dimension; the number of stored sizes is the rank.
/// Dimension `0` is the outermost one in the row-major convention used by
/// `Shape::ravel_index` (the last dimension has stride 1).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Shape {
    /// The dimensions of the tensor (one size per axis).
    pub dims: Vec<usize>,
}
26
#[allow(missing_docs)]
#[derive(Debug, Clone, PartialEq, Eq)]
/// Error that can occur when attempting to modify shapes.
pub enum ShapeError {
    /// The operands have different ranks.
    RankMismatch {
        /// Rank of the shape being operated on (the left operand).
        left: usize,
        /// Rank or length it was compared against (e.g. the other shape's rank,
        /// the number of permutation axes, or the number of slices).
        right: usize,
    },
    /// A pair of dimensions are incompatible for broadcasting.
    IncompatibleDims {
        /// Size of the dimension on the left / source side.
        left: usize,
        /// Size of the dimension on the right / target side.
        right: usize,
        /// Index of the offending dimension.
        dim: usize,
    },
    /// Invalid dimension specified for the rank.
    OutOfBounds {
        /// The requested dimension index.
        dim: usize,
        /// The rank the index was checked against.
        rank: usize,
    },
    /// A pair of shapes are incompatible for the operation.
    IncompatibleShapes {
        /// The left shape (e.g. the accumulated shape in `Shape::cat`).
        left: Shape,
        /// The right shape that could not be combined with `left`.
        right: Shape,
    },
    /// Invalid shape.
    Invalid {
        /// Human-readable description of why the shape is invalid.
        reason: String,
    },
}
46
47impl ShapeError {
48    fn empty() -> Self {
49        Self::Invalid {
50            reason: "Shape is empty.".into(),
51        }
52    }
53}
54
55impl Shape {
56    /// Constructs a new `Shape`.
57    pub fn new<const D: usize>(dims: [usize; D]) -> Self {
58        // For backward compat
59        Self {
60            dims: dims.to_vec(),
61        }
62    }
63
64    /// Returns the total number of elements of a tensor having this shape
65    pub fn num_elements(&self) -> usize {
66        self.dims.iter().product()
67    }
68
69    /// Returns the number of dimensions.
70    ///
71    /// Alias for `Shape::rank()`.
72    pub fn num_dims(&self) -> usize {
73        self.dims.len()
74    }
75
76    /// Returns the rank (the number of dimensions).
77    ///
78    /// Alias for `Shape::num_dims()`.
79    pub fn rank(&self) -> usize {
80        self.num_dims()
81    }
82
83    // For compat with dims: [usize; D]
84    /// Returns the dimensions of the tensor as an array.
85    pub fn dims<const D: usize>(&self) -> [usize; D] {
86        let mut dims = [1; D];
87        dims[..D].copy_from_slice(&self.dims[..D]);
88        dims
89    }
90
91    /// Change the shape to one dimensional with the same number of elements.
92    pub fn flatten(mut self) -> Self {
93        self.dims = [self.num_elements()].into();
94        self
95    }
96
97    /// Flatten the shape along a given range of dimensions.
98    ///
99    /// This function collapses the specified range of dimensions into a single dimension,
100    /// effectively flattening the tensor in that range.
101    ///
102    /// # Arguments
103    ///
104    /// - `start_dim`: The starting dimension of the range to be flattened,
105    ///   supports negative indexing.
106    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
107    ///   supports negative indexing.
108    ///
109    /// # Returns
110    ///
111    /// A new `Shape` instance with the specified range of dimensions flattened.
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use burn_std::Shape;
117    ///
118    /// fn example() {
119    ///     let shape = Shape::new([2, 3, 4]);
120    ///
121    ///     let flattened = shape.flatten_dims(1, 2);
122    ///     println!("{flattened}");
123    ///     // [2, 12]
124    /// }
125    /// ```
126    pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self {
127        let rank = self.rank();
128        let start = start_dim.expect_dim_index(rank);
129        let end = end_dim.expect_dim_index(rank);
130
131        assert!(
132            start <= end,
133            "start_dim ({start}) must be <= than end_dim ({end})"
134        );
135
136        let existing = self.dims;
137
138        let flattened_size = existing[start..=end].iter().product();
139
140        let new_rank = rank - (end - start);
141        let mut dims = vec![0; new_rank];
142        dims[..start].copy_from_slice(&existing[..start]);
143        dims[start] = flattened_size;
144        dims[start + 1..].copy_from_slice(&existing[end + 1..]);
145
146        Self { dims }
147    }
148
149    /// Compute the ravel index for the given coordinates.
150    ///
151    /// This returns the row-major order raveling:
152    /// * `strides[-1] = 1`
153    /// * `strides[i] = strides[i+1] * dims[i+1]`
154    /// * `dim_strides = coords * strides`
155    /// * `ravel = sum(dim_strides)`
156    ///
157    /// # Arguments
158    /// - `indices`: the index for each dimension; must be the same length as `shape`.
159    ///
160    /// # Returns
161    /// - the ravel offset index.
162    pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
163        ravel_index(indices, &self.dims)
164    }
165
166    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
167    pub fn into_ranges(self) -> Vec<Range<usize>> {
168        self.iter().map(|&d| 0..d).collect()
169    }
170
171    /// Converts slice arguments into an array of slice specifications for the shape.
172    ///
173    /// This method returns an array of `Slice` objects that can be used for slicing operations.
174    /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
175    /// allows custom slice specifications instead of full ranges.
176    /// For creating complex slice specifications, use the [`s!`] macro.
177    ///
178    /// # Arguments
179    ///
180    /// * `slices` - An array of slice specifications, where each element can be:
181    ///   - A range (e.g., `2..5`)
182    ///   - An index
183    ///   - A `Slice` object
184    ///   - The output of the [`s!`] macro for advanced slicing
185    ///
186    /// # Behavior
187    ///
188    /// - Supports partial and full slicing in any number of dimensions.
189    /// - Missing ranges are treated as full slices if D > D2.
190    /// - Handles negative indices by wrapping around from the end of the dimension.
191    /// - Clamps ranges to the shape's dimensions if they exceed the bounds.
192    ///
193    /// # Returns
194    ///
195    /// An array of `Slice` objects corresponding to the provided slice specifications,
196    /// clamped to the shape's actual dimensions.
197    ///
198    /// # Examples
199    ///
200    /// ```rust
201    /// use burn_std::{Shape, Slice, s};
202    ///
203    /// fn example() {
204    ///     // 1D slicing
205    ///     let slices = Shape::new([4]).into_slices(1..4);
206    ///     assert_eq!(slices[0].to_range(4), 1..3);
207    ///
208    ///     // 2D slicing
209    ///     let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
210    ///     assert_eq!(slices[0].to_range(3), 1..3);
211    ///     assert_eq!(slices[1].to_range(4), 0..2);
212    ///
213    ///     // Using negative indices
214    ///     let slices = Shape::new([3]).into_slices(..-2);
215    ///     assert_eq!(slices[0].to_range(3), 0..1);
216    ///
217    ///     // Using the slice macro to select different ranges
218    ///     let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
219    ///     assert_eq!(slices[0].to_range(2), 0..2);
220    ///     assert_eq!(slices[1].to_range(3), 1..2);
221    /// }
222    /// ```
223    ///
224    /// # See Also
225    ///
226    /// - [`s!`] - The recommended macro for creating slice specifications
227    /// - [`Shape::into_ranges`] - Convert to full covering ranges
228    ///
229    /// [`s!`]: crate::s!
230    pub fn into_slices<S>(self, slices: S) -> Vec<Slice>
231    where
232        S: SliceArg,
233    {
234        slices.into_slices(&self)
235    }
236
237    /// Construct a vector of the dims.
238    pub fn to_vec(&self) -> Vec<usize> {
239        self.dims.clone()
240    }
241
242    /// Returns an iterator over the shape dimensions.
243    pub fn iter(&self) -> Iter<'_, usize> {
244        self.dims.iter()
245    }
246
247    /// Mutable iterator over the dimensions.
248    pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
249        self.dims.iter_mut()
250    }
251
252    /// Borrow the underlying dimensions slice.
253    pub fn as_slice(&self) -> &[usize] {
254        &self.dims
255    }
256
257    /// Borrow the underlying dimensions slice mutably.
258    pub fn as_mut_slice(&mut self) -> &mut [usize] {
259        &mut self.dims
260    }
261
262    /// Insert a dimension of `size` at position `index`.
263    pub fn insert(&mut self, index: usize, size: usize) {
264        self.dims.insert(index, size);
265    }
266
267    /// Remove and return the dimension at position `index` from the shape.
268    pub fn remove(&mut self, index: usize) -> usize {
269        self.dims.remove(index)
270    }
271
272    /// Appends a dimension of `size` to the back of the shape.
273    pub fn push(&mut self, size: usize) {
274        self.dims.push(size)
275    }
276
277    /// Extend the shape with the content of another shape or iterator.
278    pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
279        self.dims.extend(iter)
280    }
281
282    /// Swap two dimensions in the shape.
283    pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
284        if dim1 > self.rank() {
285            return Err(ShapeError::OutOfBounds {
286                dim: dim1,
287                rank: self.rank(),
288            });
289        }
290        if dim2 > self.rank() {
291            return Err(ShapeError::OutOfBounds {
292                dim: dim2,
293                rank: self.rank(),
294            });
295        }
296        self.dims.swap(dim1, dim2);
297        Ok(self)
298    }
299
300    /// Reorder the shape dimensions according to the permutation of `axes`.
301    pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
302        if axes.len() != self.rank() {
303            return Err(ShapeError::RankMismatch {
304                left: self.rank(),
305                right: axes.len(),
306            });
307        }
308        debug_assert!(axes.iter().all(|i| i < &self.rank()));
309
310        self.dims = axes.iter().map(|&i| self.dims[i]).collect();
311        Ok(self)
312    }
313
314    /// Repeated the specified `dim` a number of `times`.
315    pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> {
316        if dim >= self.rank() {
317            return Err(ShapeError::OutOfBounds {
318                dim,
319                rank: self.rank(),
320            });
321        }
322
323        self.dims[dim] *= times;
324        Ok(self)
325    }
326
327    /// Returns a new shape where the specified `dim` is reduced to size 1.
328    pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> {
329        if dim >= self.rank() {
330            return Err(ShapeError::OutOfBounds {
331                dim,
332                rank: self.rank(),
333            });
334        }
335
336        self.dims[dim] = 1;
337        Ok(self)
338    }
339
340    /// Concatenates all shapes into a new one along the given dimension.
341    pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
342    where
343        I: IntoIterator<Item = &'a Shape>,
344    {
345        let mut iter = shapes.into_iter();
346
347        let first = iter.next().ok_or(ShapeError::empty())?;
348
349        if dim >= first.rank() {
350            return Err(ShapeError::OutOfBounds {
351                dim,
352                rank: first.rank(),
353            });
354        }
355
356        let mut shape = first.clone();
357
358        for s in iter {
359            if s.rank() != shape.rank() {
360                return Err(ShapeError::RankMismatch {
361                    left: shape.rank(),
362                    right: s.rank(),
363                });
364            }
365
366            if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
367                return Err(ShapeError::IncompatibleShapes {
368                    left: shape.clone(),
369                    right: s.clone(),
370                });
371            }
372
373            shape[dim] += s[dim];
374        }
375
376        Ok(shape)
377    }
378
379    /// Compute the output shape from the given slices.
380    pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
381        if slices.len() > self.rank() {
382            return Err(ShapeError::RankMismatch {
383                left: self.rank(),
384                right: slices.len(),
385            });
386        }
387
388        slices
389            .iter()
390            .zip(self.iter_mut())
391            .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
392
393        Ok(self)
394    }
395
396    /// Compute the output shape for binary operations with broadcasting support.
397    ///
398    /// - Shapes must be of the same rank (missing dimensions are not handled automatically).
399    /// - Two dimensions are compatible if they are equal, or one of them is 1.
400    ///
401    /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]`
402    /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]`
403    /// can *not* be broadcast into `[2, 4]`.
404    pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
405        Self::broadcast_many([self, other])
406    }
407
408    /// Compute the broadcasted output shape across multiple input shapes.
409    ///
410    /// See also [broadcast](Self::broadcast).
411    pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
412    where
413        I: IntoIterator<Item = &'a Shape>,
414    {
415        let mut iter = shapes.into_iter();
416        let mut broadcasted = iter.next().ok_or(ShapeError::empty())?.clone();
417        let rank = broadcasted.rank();
418
419        for shape in iter {
420            if shape.rank() != rank {
421                return Err(ShapeError::RankMismatch {
422                    left: rank,
423                    right: shape.rank(),
424                });
425            }
426
427            for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
428                match (*d_lhs, d_rhs) {
429                    (a, b) if a == b => {} // same
430                    (1, b) => *d_lhs = b,  // broadcast to rhs
431                    (_a, 1) => {}          // keep existing dimension
432                    _ => {
433                        return Err(ShapeError::IncompatibleDims {
434                            left: *d_lhs,
435                            right: d_rhs,
436                            dim,
437                        });
438                    }
439                }
440            }
441        }
442
443        Ok(broadcasted)
444    }
445
446    /// Expand this shape to match the target shape, following broadcasting rules.
447    pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> {
448        let target_rank = target.rank();
449        if self.rank() > target_rank {
450            return Err(ShapeError::RankMismatch {
451                left: self.rank(),
452                right: target_rank,
453            });
454        }
455
456        for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
457            if dim_self != dim_target && *dim_self != 1 {
458                return Err(ShapeError::IncompatibleDims {
459                    left: *dim_self,
460                    right: *dim_target,
461                    dim: target_rank - i - 1,
462                });
463            }
464        }
465
466        Ok(target)
467    }
468
469    /// Reshape this shape to the target shape.
470    pub fn reshape<A, T>(&self, args: A) -> Result<Shape, ShapeError>
471    where
472        A: AsRef<[T]> + Debug,
473        T: AsIndex,
474    {
475        let args = args.as_ref();
476        let mut infer_index = None;
477        let mut dims = Vec::new();
478
479        let mut new_size = 1;
480
481        for (idx, &s) in args.iter().enumerate() {
482            let s = s.as_index();
483            if s > 0 {
484                let s = s as usize;
485                new_size *= s;
486                dims.push(s);
487            } else if s == 0 {
488                // We need to find the index of the 0 dimensions and
489                // replace them with the actual dimension value.
490                let s = self.dims[idx];
491                new_size *= s;
492                dims.push(s);
493            } else if s == -1 {
494                match infer_index {
495                    None => {
496                        infer_index = Some(idx);
497                        // Used by / Replaced by handling later.
498                        dims.push(1);
499                    }
500                    Some(_) => {
501                        return Err(ShapeError::Invalid {
502                            reason: "Repeated -1 in reshape".to_string(),
503                        });
504                    }
505                }
506            } else {
507                return Err(ShapeError::Invalid {
508                    reason: "The given shape cannot contain negative dimensions (other than -1)."
509                        .to_string(),
510                });
511            }
512        }
513
514        let source_size = self.num_elements();
515        match infer_index {
516            None => {
517                if source_size != new_size {
518                    return Err(ShapeError::Invalid {
519                        reason: format!(
520                            "The given shape doesn't have the same number of elements as the current shape. Current shape: {self}, target shape: {dims:?}.",
521                        ),
522                    });
523                }
524            }
525            Some(idx) => {
526                if !source_size.is_multiple_of(new_size) {
527                    return Err(ShapeError::Invalid {
528                        reason: format!(
529                            "Cannot infer a valid target shape. Current shape: {self}, target dimensions: {args:?}."
530                        ),
531                    });
532                }
533                dims[idx] = source_size / new_size;
534            }
535        }
536
537        Ok(dims.into())
538    }
539}
540
541/// Compute the output shape for matrix multiplication with broadcasting support.
542///
543/// The last two dimensions are treated as matrices, while preceding dimensions
544/// follow broadcast semantics similar to elementwise operations.
545pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> {
546    let rank = lhs.rank();
547    if rank != rhs.rank() {
548        return Err(ShapeError::RankMismatch {
549            left: rank,
550            right: rhs.rank(),
551        });
552    }
553
554    if lhs[rank - 1] != rhs[rank - 2] {
555        return Err(ShapeError::IncompatibleShapes {
556            left: lhs.clone(),
557            right: rhs.clone(),
558        });
559    }
560
561    let mut shape = if rank > 2 {
562        // Broadcast leading dims
563        Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
564    } else {
565        Shape::new([])
566    };
567    shape.extend([lhs[rank - 2], rhs[rank - 1]]);
568
569    Ok(shape)
570}
571
572impl Display for Shape {
573    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
574        self.dims.fmt(f)
575    }
576}
577
578impl FromStr for Shape {
579    type Err = crate::ExpressionError;
580
581    fn from_str(source: &str) -> Result<Self, Self::Err> {
582        let mut s = source.trim();
583
584        const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")];
585
586        for (open, close) in DELIMS {
587            if let Some(p) = s.strip_prefix(open) {
588                if let Some(p) = p.strip_suffix(close) {
589                    s = p.trim();
590                    break;
591                } else {
592                    return Err(crate::ExpressionError::ParseError {
593                        message: "Unbalanced delimiters".to_string(),
594                        source: source.to_string(),
595                    });
596                }
597            }
598        }
599
600        if s.is_empty() {
601            return Ok(Shape::new([]));
602        }
603
604        let dims =
605            s.split(',')
606                .map(|dim_str| {
607                    dim_str.trim().parse::<usize>().map_err(|_| {
608                        crate::ExpressionError::ParseError {
609                            message: "Unable to parse shape".to_string(),
610                            source: source.to_string(),
611                        }
612                    })
613                })
614                .collect::<Result<Vec<usize>, crate::ExpressionError>>()?;
615
616        if dims.is_empty() {
617            unreachable!("Split should have returned at least one element");
618        }
619
620        Ok(Shape { dims })
621    }
622}
623
624impl<Idx> Index<Idx> for Shape
625where
626    Idx: SliceIndex<[usize]>,
627{
628    type Output = Idx::Output;
629
630    fn index(&self, index: Idx) -> &Self::Output {
631        &self.dims[index]
632    }
633}
634
635impl<Idx> IndexMut<Idx> for Shape
636where
637    Idx: SliceIndex<[usize]>,
638{
639    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
640        &mut self.dims[index]
641    }
642}
643
644// Allow `&shape` to behave like a slice `&[usize]` directly
645impl Deref for Shape {
646    type Target = [usize];
647
648    fn deref(&self) -> &Self::Target {
649        &self.dims
650    }
651}
652
653// Allow `&shape` to behave like a mut slice `&mut [usize]` directly
654impl DerefMut for Shape {
655    fn deref_mut(&mut self) -> &mut Self::Target {
656        &mut self.dims
657    }
658}
659// Allow `shape.reshape(other_shape)`.
660//
661// By implementing `AsRef<[usize]>`, `Shape` behaves like a slice of dimensions,
662// similar to how `Vec<T>` can be passed to functions expecting a slice.
663impl AsRef<[usize]> for Shape {
664    fn as_ref(&self) -> &[usize] {
665        &self.dims
666    }
667}
668
669impl From<Shape> for Vec<usize> {
670    fn from(shape: Shape) -> Self {
671        shape.dims
672    }
673}
674
675impl<T, I> From<T> for Shape
676where
677    T: IntoIterator<Item = I>,
678    I: AsSize,
679{
680    fn from(dims: T) -> Self {
681        Shape {
682            dims: dims.into_iter().map(|d| d.as_size()).collect(),
683        }
684    }
685}
686
687#[cfg(test)]
688#[allow(clippy::identity_op, reason = "useful for clarity")]
689mod tests {
690    use super::*;
691    use crate::s;
692    use alloc::string::ToString;
693    use alloc::vec;
694
695    #[test]
696    fn test_shape_to_str() {
697        let shape = Shape::new([2, 3, 4, 5]);
698        assert_eq!(shape.to_string(), "[2, 3, 4, 5]");
699    }
700
701    #[test]
702    fn test_shape_from_str() {
703        assert_eq!(
704            "[2, 3, 4, 5]".parse::<Shape>().unwrap(),
705            Shape::new([2, 3, 4, 5])
706        );
707        assert_eq!(
708            "(2, 3, 4, 5)".parse::<Shape>().unwrap(),
709            Shape::new([2, 3, 4, 5])
710        );
711        assert_eq!(
712            "2, 3, 4, 5".parse::<Shape>().unwrap(),
713            Shape::new([2, 3, 4, 5])
714        );
715
716        assert_eq!("[2]".parse::<Shape>().unwrap(), Shape::new([2]));
717        assert_eq!("(2)".parse::<Shape>().unwrap(), Shape::new([2]));
718        assert_eq!("2".parse::<Shape>().unwrap(), Shape::new([2]));
719
720        assert_eq!("[]".parse::<Shape>().unwrap(), Shape::new([]));
721        assert_eq!("".parse::<Shape>().unwrap(), Shape::new([]));
722
723        assert_eq!(
724            "[".parse::<Shape>(),
725            Err(crate::ExpressionError::ParseError {
726                message: "Unbalanced delimiters".to_string(),
727                source: "[".to_string()
728            })
729        );
730
731        assert_eq!(
732            "[[1]".parse::<Shape>(),
733            Err(crate::ExpressionError::ParseError {
734                message: "Unable to parse shape".to_string(),
735                source: "[[1]".to_string()
736            })
737        );
738        assert_eq!(
739            "[[1]]".parse::<Shape>(),
740            Err(crate::ExpressionError::ParseError {
741                message: "Unable to parse shape".to_string(),
742                source: "[[1]]".to_string()
743            })
744        );
745        assert_eq!(
746            "[1)".parse::<Shape>(),
747            Err(crate::ExpressionError::ParseError {
748                message: "Unbalanced delimiters".to_string(),
749                source: "[1)".to_string()
750            })
751        );
752
753        assert_eq!(
754            "]".parse::<Shape>(),
755            Err(crate::ExpressionError::ParseError {
756                message: "Unable to parse shape".to_string(),
757                source: "]".to_string()
758            })
759        );
760
761        assert_eq!(
762            "[a]".parse::<Shape>(),
763            Err(crate::ExpressionError::ParseError {
764                message: "Unable to parse shape".to_string(),
765                source: "[a]".to_string()
766            })
767        );
768    }
769
770    #[test]
771    fn num_dims_and_rank() {
772        let dims = [2, 3, 4, 5];
773        let shape = Shape::new(dims);
774        assert_eq!(4, shape.num_dims());
775        assert_eq!(4, shape.rank());
776    }
777
778    #[test]
779    fn num_elements() {
780        let dims = [2, 3, 4, 5];
781        let shape = Shape::new(dims);
782        assert_eq!(120, shape.num_elements());
783    }
784
785    #[test]
786    #[allow(clippy::into_iter_on_ref)]
787    fn test_shape_into_iter() {
788        let dims = [2, 3, 4, 5];
789        let shape = Shape::new(dims);
790
791        assert_eq!(shape.into_iter().sum::<usize>(), 14);
792    }
793
794    #[test]
795    fn test_into_ranges() {
796        let dims = [2, 3, 4, 5];
797        let shape = Shape::new(dims);
798        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
799    }
800
801    #[test]
802    fn test_to_vec() {
803        let dims = [2, 3, 4, 5];
804        let shape = Shape::new(dims);
805        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
806    }
807
808    #[allow(clippy::single_range_in_vec_init)]
809    #[test]
810    fn test_into_slices() {
811        let slices = Shape::new([3]).into_slices(1..4);
812        assert_eq!(slices[0].to_range(3), 1..3);
813
814        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
815        assert_eq!(slices[0].to_range(3), 1..3);
816        assert_eq!(slices[1].to_range(4), 0..2);
817
818        let slices = Shape::new([3]).into_slices(..-2);
819        assert_eq!(slices[0].to_range(3), 0..1);
820
821        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
822        assert_eq!(slices[0].to_range(2), 0..2);
823        assert_eq!(slices[1].to_range(3), 1..2);
824
825        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
826        assert_eq!(slices[0].to_range(2), 0..2);
827        assert_eq!(slices[1].to_range(3), 2..3);
828    }
829
830    #[test]
831    fn test_shape_index() {
832        let shape = Shape::new([2, 3, 4, 5]);
833
834        assert_eq!(shape[0], 2);
835        assert_eq!(shape[1], 3);
836        assert_eq!(shape[2], 4);
837        assert_eq!(shape[3], 5);
838
839        // Works with ranges
840        assert_eq!(shape[1..3], *&[3, 4]);
841        assert_eq!(shape[1..=2], *&[3, 4]);
842        assert_eq!(shape[..], *&[2, 3, 4, 5]);
843    }
844
845    #[test]
846    fn test_shape_slice_methods() {
847        let shape = Shape::new([2, 3, 4, 5]);
848
849        let dim = shape.first();
850        assert_eq!(dim, Some(&2));
851        let dim = shape.last();
852        assert_eq!(dim, Some(&5));
853
854        assert!(!shape.is_empty());
855        let shape = Shape::new([]);
856        assert!(shape.is_empty());
857    }
858
859    #[test]
860    fn test_shape_iter() {
861        let dims = [2, 3, 4, 5];
862        let shape = Shape::new(dims);
863
864        for (d, sd) in dims.iter().zip(shape.iter()) {
865            assert_eq!(d, sd);
866        }
867    }
868
869    #[test]
870    fn test_shape_iter_mut() {
871        let mut shape = Shape::new([2, 3, 4, 5]);
872
873        for d in shape.iter_mut() {
874            *d += 1;
875        }
876
877        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
878    }
879
880    #[test]
881    fn test_shape_as_slice() {
882        let dims = [2, 3, 4, 5];
883        let shape = Shape::new(dims);
884
885        assert_eq!(shape.as_slice(), dims.as_slice());
886
887        // Deref coercion
888        let shape_slice: &[usize] = &shape;
889        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
890    }
891
892    #[test]
893    fn test_shape_as_mut_slice() {
894        let mut dims = [2, 3, 4, 5];
895        let mut shape = Shape::new(dims);
896
897        let shape_mut = shape.as_mut_slice();
898        assert_eq!(shape_mut, dims.as_mut_slice());
899        shape_mut[1] = 6;
900
901        assert_eq!(shape_mut, &[2, 6, 4, 5]);
902
903        let mut shape = Shape::new(dims);
904        let shape = &mut shape[..];
905        shape[1] = 6;
906
907        assert_eq!(shape, shape_mut)
908    }
909
910    #[test]
911    fn test_shape_flatten() {
912        let shape = Shape::new([2, 3, 4, 5]);
913        assert_eq!(shape.num_elements(), 120);
914
915        let shape = shape.flatten();
916        assert_eq!(shape.num_elements(), 120);
917        assert_eq!(&shape.dims, &[120]);
918    }
919
920    #[test]
921    fn test_ravel() {
922        let shape = Shape::new([2, 3, 4, 5]);
923
924        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
925        assert_eq!(
926            shape.ravel_index(&[1, 2, 3, 4]),
927            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
928        );
929    }
930
931    #[test]
932    fn test_shape_insert_remove_push() {
933        let dims = [2, 3, 4, 5];
934        let mut shape = Shape::new(dims);
935        let size = 6;
936        shape.insert(1, size);
937
938        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));
939
940        let removed = shape.remove(1);
941        assert_eq!(removed, size);
942        assert_eq!(shape, Shape::new(dims));
943
944        shape.push(6);
945        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
946    }
947
948    #[test]
949    fn test_shape_swap_permute() {
950        let dims = [2, 3, 4, 5];
951        let shape = Shape::new(dims);
952        let shape = shape.swap(1, 2).unwrap();
953
954        assert_eq!(&shape.dims, &[2, 4, 3, 5]);
955
956        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
957        assert_eq!(shape, Shape::new(dims));
958    }
959
960    #[test]
961    #[should_panic]
962    fn test_shape_swap_out_of_bounds() {
963        let shape = Shape::new([2, 3, 4, 5]);
964
965        shape.swap(0, 4).unwrap();
966    }
967
968    #[test]
969    #[should_panic]
970    fn test_shape_permute_incomplete() {
971        let shape = Shape::new([2, 3, 4, 5]);
972
973        shape.permute(&[0, 2, 1]).unwrap();
974    }
975
976    #[test]
977    fn test_shape_repeat() {
978        let shape = Shape::new([2, 3, 4, 5]);
979
980        let out = shape.repeat(2, 3).unwrap();
981        assert_eq!(out, Shape::new([2, 3, 12, 5]));
982    }
983
984    #[test]
985    fn test_shape_repeat_invalid() {
986        let shape = Shape::new([2, 3, 4, 5]);
987
988        let out = shape.repeat(5, 3);
989        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
990    }
991
992    #[test]
993    fn test_shape_reduce() {
994        let shape = Shape::new([2, 3, 4, 5]);
995
996        let out = shape.reduce(2).unwrap();
997        assert_eq!(out, Shape::new([2, 3, 1, 5]));
998    }
999
1000    #[test]
1001    fn test_shape_reduce_invalid() {
1002        let shape = Shape::new([2, 3, 4, 5]);
1003
1004        let out = shape.reduce(5);
1005        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
1006    }
1007
1008    #[test]
1009    fn test_shape_broadcast_binary() {
1010        let lhs = Shape::new([1, 1, 2, 4]);
1011        let rhs = Shape::new([7, 6, 2, 1]);
1012
1013        let out = lhs.broadcast(&rhs).unwrap();
1014        assert_eq!(out, Shape::new([7, 6, 2, 4]));
1015    }
1016
1017    #[test]
1018    fn test_shape_broadcast_rank_mismatch() {
1019        let lhs = Shape::new([1, 2, 4]);
1020        let rhs = Shape::new([7, 6, 2, 4]);
1021
1022        let out = lhs.broadcast(&rhs);
1023        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
1024    }
1025
1026    #[test]
1027    fn test_shape_broadcast_incompatible_dims() {
1028        let lhs = Shape::new([1, 2, 2, 4]);
1029        let rhs = Shape::new([7, 6, 2, 1]);
1030
1031        let out = lhs.broadcast(&rhs);
1032        assert_eq!(
1033            out,
1034            Err(ShapeError::IncompatibleDims {
1035                left: 2,
1036                right: 6,
1037                dim: 1
1038            })
1039        );
1040    }
1041
1042    #[test]
1043    fn test_shape_broadcast_many() {
1044        let s1 = Shape::new([1, 1, 2, 4]);
1045        let s2 = Shape::new([7, 1, 2, 1]);
1046        let s3 = Shape::new([7, 6, 1, 1]);
1047
1048        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
1049        assert_eq!(out, Shape::new([7, 6, 2, 4]));
1050    }
1051
1052    #[test]
1053    fn test_shape_broadcast_many_rank_mismatch() {
1054        let s1 = Shape::new([1, 1, 2, 4]);
1055        let s2 = Shape::new([7, 1, 2, 1]);
1056        let s3 = Shape::new([1, 6, 1]);
1057
1058        let out = Shape::broadcast_many([&s1, &s2, &s3]);
1059        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
1060    }
1061
1062    #[test]
1063    fn test_shape_broadcast_many_incompatible_dims() {
1064        let s1 = Shape::new([1, 1, 2, 4]);
1065        let s2 = Shape::new([7, 1, 2, 1]);
1066        let s3 = Shape::new([4, 6, 1, 1]);
1067
1068        let out = Shape::broadcast_many([&s1, &s2, &s3]);
1069        assert_eq!(
1070            out,
1071            Err(ShapeError::IncompatibleDims {
1072                left: 7,
1073                right: 4,
1074                dim: 0
1075            })
1076        );
1077    }
1078
1079    #[test]
1080    fn test_shape_broadcast_many_empty() {
1081        let out = Shape::broadcast_many(&[]);
1082        assert_eq!(out, Err(ShapeError::empty()));
1083    }
1084
1085    #[test]
1086    fn test_shape_matmul_2d() {
1087        let lhs = Shape::new([2, 4]);
1088        let rhs = Shape::new([4, 2]);
1089        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
1090        assert_eq!(out, Shape::new([2, 2]));
1091    }
1092
1093    #[test]
1094    fn test_shape_matmul_4d_broadcasted() {
1095        let lhs = Shape::new([1, 3, 2, 4]);
1096        let rhs = Shape::new([2, 1, 4, 2]);
1097        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
1098        assert_eq!(out, Shape::new([2, 3, 2, 2]));
1099    }
1100
1101    #[test]
1102    fn test_shape_matmul_invalid_rank() {
1103        let lhs = Shape::new([3, 2, 4]);
1104        let rhs = Shape::new([2, 1, 4, 2]);
1105        let out = calculate_matmul_output(&lhs, &rhs);
1106        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
1107    }
1108
1109    #[test]
1110    fn test_shape_matmul_invalid_shape() {
1111        let lhs = Shape::new([1, 3, 2, 4]);
1112        let rhs = Shape::new([2, 1, 3, 2]);
1113        let out = calculate_matmul_output(&lhs, &rhs);
1114        assert_eq!(
1115            out,
1116            Err(ShapeError::IncompatibleShapes {
1117                left: lhs,
1118                right: rhs
1119            })
1120        );
1121    }
1122
1123    #[test]
1124    fn test_shape_matmul_invalid_broadcast() {
1125        let lhs = Shape::new([1, 3, 2, 4]);
1126        let rhs = Shape::new([2, 2, 4, 2]);
1127        let out = calculate_matmul_output(&lhs, &rhs);
1128        assert_eq!(
1129            out,
1130            Err(ShapeError::IncompatibleDims {
1131                left: 3,
1132                right: 2,
1133                dim: 1
1134            })
1135        );
1136    }
1137
1138    #[test]
1139    fn test_shape_cat() {
1140        let s1 = Shape::new([2, 3, 4, 5]);
1141        let s2 = Shape::new([1, 3, 4, 5]);
1142        let s3 = Shape::new([4, 3, 4, 5]);
1143
1144        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
1145        assert_eq!(out, Shape::new([7, 3, 4, 5]));
1146
1147        let s1 = Shape::new([2, 3, 4, 5]);
1148        let s2 = Shape::new([2, 3, 2, 5]);
1149        let s3 = Shape::new([2, 3, 1, 5]);
1150
1151        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
1152        assert_eq!(out, Shape::new([2, 3, 7, 5]));
1153    }
1154
1155    #[test]
1156    fn test_shape_cat_empty() {
1157        let out = Shape::cat(&[], 0);
1158        assert_eq!(out, Err(ShapeError::empty()));
1159    }
1160
1161    #[test]
1162    fn test_shape_cat_dim_out_of_bounds() {
1163        let s1 = Shape::new([2, 3, 4, 5]);
1164        let s2 = Shape::new([2, 3, 4, 5]);
1165        let out = Shape::cat(&[s1, s2], 4);
1166        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
1167    }
1168
1169    #[test]
1170    fn test_shape_cat_rank_mismatch() {
1171        let s1 = Shape::new([2, 3, 4, 5]);
1172        let s2 = Shape::new([2, 3, 4, 5, 6]);
1173        let out = Shape::cat(&[s1, s2], 0);
1174        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
1175    }
1176
1177    #[test]
1178    fn test_shape_cat_incompatible_shapes() {
1179        let s1 = Shape::new([2, 3, 4, 5]);
1180        let s2 = Shape::new([1, 3, 4, 5]);
1181        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);
1182
1183        assert_eq!(
1184            out,
1185            Err(ShapeError::IncompatibleShapes {
1186                left: s1,
1187                right: s2
1188            })
1189        );
1190    }
1191
1192    #[test]
1193    fn test_shape_slice_output_shape_basic() {
1194        // Test basic slicing with step=1
1195        let slices = [
1196            Slice::new(0, Some(5), 1), // 5 elements
1197            Slice::new(2, Some(8), 1), // 6 elements
1198        ];
1199        let original_shape = Shape::new([10, 10, 10]);
1200        let result = original_shape.slice(&slices).unwrap();
1201        assert_eq!(result, Shape::new([5, 6, 10]));
1202    }
1203
1204    #[test]
1205    fn test_shape_slice_output_shape_with_positive_steps() {
1206        // Test slicing with various positive steps
1207        let slices = [
1208            Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
1209            Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements
1210            Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements
1211        ];
1212        let original_shape = Shape::new([20, 20, 20, 30]);
1213        let result = original_shape.slice(&slices).unwrap();
1214        assert_eq!(result, Shape::new([5, 3, 2, 30]));
1215    }
1216
1217    #[test]
1218    fn test_shape_slice_output_shape_with_negative_steps() {
1219        // Test slicing with negative steps (backward iteration)
1220        let slices = [
1221            Slice::new(0, Some(10), -1), // 10 elements traversed backward
1222            Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements
1223        ];
1224        let original_shape = Shape::new([20, 20, 20]);
1225        let result = original_shape.slice(&slices).unwrap();
1226        assert_eq!(result, Shape::new([10, 3, 20]));
1227    }
1228
1229    #[test]
1230    fn test_shape_slice_output_shape_mixed_steps() {
1231        // Test with a mix of positive, negative, and unit steps
1232        let slices = [
1233            Slice::from_range_stepped(1..6, 1),   // 5 elements
1234            Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
1235            Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements
1236        ];
1237        let original_shape = Shape::new([20, 20, 20]);
1238        let result = original_shape.slice(&slices).unwrap();
1239        assert_eq!(result, Shape::new([5, 4, 3]));
1240    }
1241
1242    #[test]
1243    fn test_shape_slice_output_shape_partial_dims() {
1244        // Test when slices has fewer dimensions than original shape
1245        let slices = [
1246            Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
1247        ];
1248        let original_shape = Shape::new([10, 20, 30, 40]);
1249        let result = original_shape.slice(&slices).unwrap();
1250        assert_eq!(result, Shape::new([3, 20, 30, 40]));
1251    }
1252
1253    #[test]
1254    fn test_shape_slice_output_shape_edge_cases() {
1255        // Test edge cases with small ranges and large steps
1256        let slices = [
1257            Slice::from_range_stepped(0..1, 1),    // Single element
1258            Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
1259            Slice::from_range_stepped(5..5, 1),    // Empty range -> 0 elements
1260        ];
1261        let original_shape = Shape::new([10, 20, 30]);
1262        let result = original_shape.slice(&slices).unwrap();
1263        assert_eq!(result, Shape::new([1, 1, 0]));
1264    }
1265
1266    #[test]
1267    fn test_shape_slice_output_shape_empty() {
1268        // Test with no slice infos (should return original shape)
1269        let slices = [];
1270        let original_shape = Shape::new([10, 20, 30]);
1271        let result = original_shape.slice(&slices).unwrap();
1272        assert_eq!(result, Shape::new([10, 20, 30]));
1273    }
1274
1275    #[test]
1276    fn test_shape_slice_output_shape_uneven_division() {
1277        // Test cases where range size doesn't divide evenly by step
1278        let slices = [
1279            Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
1280            Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
1281            Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
1282        ];
1283        let original_shape = Shape::new([20, 20, 20]);
1284        let result = original_shape.slice(&slices).unwrap();
1285        assert_eq!(result, Shape::new([3, 3, 2]));
1286    }
1287
1288    #[test]
1289    fn test_shape_expand() {
1290        let shape = Shape::new([1, 3, 1]);
1291        let expanded = Shape::new([2, 3, 4]);
1292        let out = shape.expand(expanded.clone()).unwrap();
1293        assert_eq!(out, expanded);
1294    }
1295
1296    #[test]
1297    fn test_shape_expand_higher_rank() {
1298        let shape = Shape::new([1, 4]);
1299        let expanded = Shape::new([2, 3, 4]);
1300        let out = shape.expand(expanded.clone()).unwrap();
1301        assert_eq!(out, expanded);
1302    }
1303
1304    #[test]
1305    fn test_shape_expand_invalid_rank() {
1306        let shape = Shape::new([1, 3, 1]);
1307        let expanded = Shape::new([3, 4]);
1308        let out = shape.expand(expanded);
1309        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 2 }));
1310    }
1311
1312    #[test]
1313    fn test_shape_expand_incompatible_dims() {
1314        let shape = Shape::new([1, 3, 2]);
1315        let expanded = Shape::new([2, 3, 4]);
1316        let out = shape.expand(expanded);
1317        assert_eq!(
1318            out,
1319            Err(ShapeError::IncompatibleDims {
1320                left: 2,
1321                right: 4,
1322                dim: 2
1323            })
1324        );
1325    }
1326
1327    #[test]
1328    fn test_shape_reshape() {
1329        let shape = Shape::new([2, 3, 4, 5]);
1330        let reshaped = Shape::new([1, 2, 12, 5]);
1331        let out = shape.reshape(reshaped.clone()).unwrap();
1332        assert_eq!(out, reshaped);
1333    }
1334
1335    #[test]
1336    fn test_shape_reshape_invalid() {
1337        let shape = Shape::new([2, 3, 4, 5]);
1338        let reshaped = Shape::new([2, 2, 12, 5]);
1339        let out = shape.reshape(reshaped.clone());
1340        assert_eq!(
1341            out,
1342            Err(ShapeError::Invalid {
1343                reason: "The given shape doesn't have the same number of elements as the current shape. Current shape: [2, 3, 4, 5], target shape: [2, 2, 12, 5].".into(),
1344            })
1345        );
1346    }
1347
1348    #[test]
1349    fn test_shape_reshape_invalid_inferred() {
1350        let shape = Shape::new([2, 4]);
1351        let out = shape.reshape([-1, 3]);
1352        assert_eq!(
1353            out,
1354            Err(ShapeError::Invalid {
1355                reason: "Cannot infer a valid target shape. Current shape: [2, 4], target dimensions: [-1, 3].".into(),
1356            })
1357        );
1358    }
1359
1360    #[test]
1361    fn test_flatten_dims() {
1362        let shape = Shape::new([2, 3, 4, 5]);
1363        let flattened = shape.flatten_dims(-2, 3);
1364        assert_eq!(flattened, Shape::new([2, 3, 20]));
1365    }
1366}