burn_std/tensor/shape.rs

1//! Tensor shape definition.
2
3use super::indexing::ravel_index;
4use super::{AsIndex, Slice, SliceArg};
5use alloc::string::ToString;
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::str::FromStr;
10use core::{
11    ops::{Deref, DerefMut, Index, IndexMut, Range},
12    slice::{Iter, IterMut, SliceIndex},
13};
14use serde::{Deserialize, Serialize};
15
16/// Shape of a tensor.
17#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub struct Shape {
19    /// The dimensions of the tensor.
20    pub dims: Vec<usize>,
21}
22
23#[allow(missing_docs)]
24#[derive(Debug, Clone, PartialEq, Eq)]
25/// Error that can occur when attempting to modify shapes.
26pub enum ShapeError {
27    /// The operands have different ranks.
28    RankMismatch { left: usize, right: usize },
29    /// A pair of dimensions are incompatible for broadcasting.
30    IncompatibleDims {
31        left: usize,
32        right: usize,
33        dim: usize,
34    },
35    /// Invalid dimension specified for the rank.
36    OutOfBounds { dim: usize, rank: usize },
37    /// A pair of shapes are incompatible for the operation.
38    IncompatibleShapes { left: Shape, right: Shape },
39    /// Invalid empty shape.
40    Empty,
41}
42
43impl Shape {
44    /// Constructs a new `Shape`.
45    pub fn new<const D: usize>(dims: [usize; D]) -> Self {
46        // For backward compat
47        Self {
48            dims: dims.to_vec(),
49        }
50    }
51
52    /// Returns the total number of elements of a tensor having this shape
53    pub fn num_elements(&self) -> usize {
54        self.dims.iter().product()
55    }
56
57    /// Returns the number of dimensions.
58    ///
59    /// Alias for `Shape::rank()`.
60    pub fn num_dims(&self) -> usize {
61        self.dims.len()
62    }
63
64    /// Returns the rank (the number of dimensions).
65    ///
66    /// Alias for `Shape::num_dims()`.
67    pub fn rank(&self) -> usize {
68        self.num_dims()
69    }
70
71    // For compat with dims: [usize; D]
72    /// Returns the dimensions of the tensor as an array.
73    pub fn dims<const D: usize>(&self) -> [usize; D] {
74        let mut dims = [1; D];
75        dims[..D].copy_from_slice(&self.dims[..D]);
76        dims
77    }
78
79    /// Change the shape to one dimensional with the same number of elements.
80    pub fn flatten(mut self) -> Self {
81        self.dims = [self.num_elements()].into();
82        self
83    }
84
85    /// Flatten the shape along a given range of dimensions.
86    ///
87    /// This function collapses the specified range of dimensions into a single dimension,
88    /// effectively flattening the tensor in that range.
89    ///
90    /// # Arguments
91    ///
92    /// - `start_dim`: The starting dimension of the range to be flattened,
93    ///   supports negative indexing.
94    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
95    ///   supports negative indexing.
96    ///
97    /// # Returns
98    ///
99    /// A new `Shape` instance with the specified range of dimensions flattened.
100    ///
101    /// # Example
102    ///
103    /// ```rust
104    /// use burn_std::Shape;
105    ///
106    /// fn example() {
107    ///     let shape = Shape::new([2, 3, 4]);
108    ///
109    ///     let flattened = shape.flatten_dims(1, 2);
110    ///     println!("{flattened}");
111    ///     // [2, 12]
112    /// }
113    /// ```
114    pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self {
115        let rank = self.rank();
116        let start = start_dim.expect_dim_index(rank);
117        let end = end_dim.expect_dim_index(rank);
118
119        assert!(
120            start <= end,
121            "start_dim ({start}) must be <= than end_dim ({end})"
122        );
123
124        let existing = self.dims;
125
126        let flattened_size = existing[start..=end].iter().product();
127
128        let new_rank = rank - (end - start);
129        let mut dims = vec![0; new_rank];
130        dims[..start].copy_from_slice(&existing[..start]);
131        dims[start] = flattened_size;
132        dims[start + 1..].copy_from_slice(&existing[end + 1..]);
133
134        Self { dims }
135    }
136
137    /// Compute the ravel index for the given coordinates.
138    ///
139    /// This returns the row-major order raveling:
140    /// * `strides[-1] = 1`
141    /// * `strides[i] = strides[i+1] * dims[i+1]`
142    /// * `dim_strides = coords * strides`
143    /// * `ravel = sum(dim_strides)`
144    ///
145    /// # Arguments
146    /// - `indices`: the index for each dimension; must be the same length as `shape`.
147    ///
148    /// # Returns
149    /// - the ravel offset index.
150    pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
151        ravel_index(indices, &self.dims)
152    }
153
154    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
155    pub fn into_ranges(self) -> Vec<Range<usize>> {
156        self.into_iter().map(|d| 0..d).collect()
157    }
158
159    /// Converts slice arguments into an array of slice specifications for the shape.
160    ///
161    /// This method returns an array of `Slice` objects that can be used for slicing operations.
162    /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but
163    /// allows custom slice specifications instead of full ranges.
164    /// For creating complex slice specifications, use the [`s!`] macro.
165    ///
166    /// # Arguments
167    ///
168    /// * `slices` - An array of slice specifications, where each element can be:
169    ///   - A range (e.g., `2..5`)
170    ///   - An index
171    ///   - A `Slice` object
172    ///   - The output of the [`s!`] macro for advanced slicing
173    ///
174    /// # Behavior
175    ///
176    /// - Supports partial and full slicing in any number of dimensions.
177    /// - Missing ranges are treated as full slices if D > D2.
178    /// - Handles negative indices by wrapping around from the end of the dimension.
179    /// - Clamps ranges to the shape's dimensions if they exceed the bounds.
180    ///
181    /// # Returns
182    ///
183    /// An array of `Slice` objects corresponding to the provided slice specifications,
184    /// clamped to the shape's actual dimensions.
185    ///
186    /// # Examples
187    ///
188    /// ```rust
189    /// use burn_std::{Shape, Slice, s};
190    ///
191    /// fn example() {
192    ///     // 1D slicing
193    ///     let slices = Shape::new([4]).into_slices(1..4);
194    ///     assert_eq!(slices[0].to_range(4), 1..3);
195    ///
196    ///     // 2D slicing
197    ///     let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
198    ///     assert_eq!(slices[0].to_range(3), 1..3);
199    ///     assert_eq!(slices[1].to_range(4), 0..2);
200    ///
201    ///     // Using negative indices
202    ///     let slices = Shape::new([3]).into_slices(..-2);
203    ///     assert_eq!(slices[0].to_range(3), 0..1);
204    ///
205    ///     // Using the slice macro to select different ranges
206    ///     let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
207    ///     assert_eq!(slices[0].to_range(2), 0..2);
208    ///     assert_eq!(slices[1].to_range(3), 1..2);
209    /// }
210    /// ```
211    ///
212    /// # See Also
213    ///
214    /// - [`s!`] - The recommended macro for creating slice specifications
215    /// - [`Shape::into_ranges`] - Convert to full covering ranges
216    ///
217    /// [`s!`]: crate::s!
218    pub fn into_slices<S>(self, slices: S) -> Vec<Slice>
219    where
220        S: SliceArg,
221    {
222        slices.into_slices(&self)
223    }
224
225    /// Construct a vector of the dims.
226    pub fn to_vec(&self) -> Vec<usize> {
227        self.dims.clone()
228    }
229
230    /// Returns an iterator over the shape dimensions.
231    pub fn iter(&self) -> Iter<'_, usize> {
232        self.dims.iter()
233    }
234
235    /// Mutable iterator over the dimensions.
236    pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
237        self.dims.iter_mut()
238    }
239
240    /// Borrow the underlying dimensions slice.
241    pub fn as_slice(&self) -> &[usize] {
242        &self.dims
243    }
244
245    /// Borrow the underlying dimensions slice mutably.
246    pub fn as_mut_slice(&mut self) -> &mut [usize] {
247        &mut self.dims
248    }
249
250    /// Insert a dimension of `size` at position `index`.
251    pub fn insert(&mut self, index: usize, size: usize) {
252        self.dims.insert(index, size);
253    }
254
255    /// Remove and return the dimension at position `index` from the shape.
256    pub fn remove(&mut self, index: usize) -> usize {
257        self.dims.remove(index)
258    }
259
260    /// Appends a dimension of `size` to the back of the shape.
261    pub fn push(&mut self, size: usize) {
262        self.dims.push(size)
263    }
264
265    /// Extend the shape with the content of another shape or iterator.
266    pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
267        self.dims.extend(iter)
268    }
269
270    /// Swap two dimensions in the shape.
271    pub fn swap(mut self, dim1: usize, dim2: usize) -> Result<Self, ShapeError> {
272        if dim1 > self.rank() {
273            return Err(ShapeError::OutOfBounds {
274                dim: dim1,
275                rank: self.rank(),
276            });
277        }
278        if dim2 > self.rank() {
279            return Err(ShapeError::OutOfBounds {
280                dim: dim2,
281                rank: self.rank(),
282            });
283        }
284        self.dims.swap(dim1, dim2);
285        Ok(self)
286    }
287
288    /// Reorder the shape dimensions according to the permutation of `axes`.
289    pub fn permute(mut self, axes: &[usize]) -> Result<Self, ShapeError> {
290        if axes.len() != self.rank() {
291            return Err(ShapeError::RankMismatch {
292                left: self.rank(),
293                right: axes.len(),
294            });
295        }
296        debug_assert!(axes.iter().all(|i| i < &self.rank()));
297
298        self.dims = axes.iter().map(|&i| self.dims[i]).collect();
299        Ok(self)
300    }
301
302    /// Repeated the specified `dim` a number of `times`.
303    pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, ShapeError> {
304        if dim >= self.rank() {
305            return Err(ShapeError::OutOfBounds {
306                dim,
307                rank: self.rank(),
308            });
309        }
310
311        self.dims[dim] *= times;
312        Ok(self)
313    }
314
315    /// Returns a new shape where the specified `dim` is reduced to size 1.
316    pub fn reduce(mut self, dim: usize) -> Result<Shape, ShapeError> {
317        if dim >= self.rank() {
318            return Err(ShapeError::OutOfBounds {
319                dim,
320                rank: self.rank(),
321            });
322        }
323
324        self.dims[dim] = 1;
325        Ok(self)
326    }
327
328    /// Concatenates all shapes into a new one along the given dimension.
329    pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, ShapeError>
330    where
331        I: IntoIterator<Item = &'a Shape>,
332    {
333        let mut iter = shapes.into_iter();
334
335        let first = iter.next().ok_or(ShapeError::Empty)?;
336
337        if dim >= first.rank() {
338            return Err(ShapeError::OutOfBounds {
339                dim,
340                rank: first.rank(),
341            });
342        }
343
344        let mut shape = first.clone();
345
346        for s in iter {
347            if s.rank() != shape.rank() {
348                return Err(ShapeError::RankMismatch {
349                    left: shape.rank(),
350                    right: s.rank(),
351                });
352            }
353
354            if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
355                return Err(ShapeError::IncompatibleShapes {
356                    left: shape.clone(),
357                    right: s.clone(),
358                });
359            }
360
361            shape[dim] += s[dim];
362        }
363
364        Ok(shape)
365    }
366
367    /// Compute the output shape from the given slices.
368    pub fn slice(mut self, slices: &[Slice]) -> Result<Self, ShapeError> {
369        if slices.len() > self.rank() {
370            return Err(ShapeError::RankMismatch {
371                left: self.rank(),
372                right: slices.len(),
373            });
374        }
375
376        slices
377            .iter()
378            .zip(self.iter_mut())
379            .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));
380
381        Ok(self)
382    }
383
384    /// Compute the output shape for binary operations with broadcasting support.
385    ///
386    /// - Shapes must be of the same rank (missing dimensions are not handled automatically).
387    /// - Two dimensions are compatible if they are equal, or one of them is 1.
388    ///
389    /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]`
390    /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]`
391    /// can *not* be broadcast into `[2, 4]`.
392    pub fn broadcast(&self, other: &Self) -> Result<Self, ShapeError> {
393        Self::broadcast_many([self, other])
394    }
395
396    /// Compute the broadcasted output shape across multiple input shapes.
397    ///
398    /// See also [broadcast](Self::broadcast).
399    pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, ShapeError>
400    where
401        I: IntoIterator<Item = &'a Shape>,
402    {
403        let mut iter = shapes.into_iter();
404        let mut broadcasted = iter.next().ok_or(ShapeError::Empty)?.clone();
405        let rank = broadcasted.rank();
406
407        for shape in iter {
408            if shape.rank() != rank {
409                return Err(ShapeError::RankMismatch {
410                    left: rank,
411                    right: shape.rank(),
412                });
413            }
414
415            for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
416                match (*d_lhs, d_rhs) {
417                    (a, b) if a == b => {} // same
418                    (1, b) => *d_lhs = b,  // broadcast to rhs
419                    (_a, 1) => {}          // keep existing dimension
420                    _ => {
421                        return Err(ShapeError::IncompatibleDims {
422                            left: *d_lhs,
423                            right: d_rhs,
424                            dim,
425                        });
426                    }
427                }
428            }
429        }
430
431        Ok(broadcasted)
432    }
433
434    /// Expand this shape to match the target shape, following broadcasting rules.
435    pub fn expand(&self, target: Shape) -> Result<Shape, ShapeError> {
436        let target_rank = target.rank();
437        if self.rank() > target_rank {
438            return Err(ShapeError::RankMismatch {
439                left: self.rank(),
440                right: target_rank,
441            });
442        }
443
444        for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
445            if dim_self != dim_target && *dim_self != 1 {
446                return Err(ShapeError::IncompatibleDims {
447                    left: *dim_self,
448                    right: *dim_target,
449                    dim: target_rank - i - 1,
450                });
451            }
452        }
453
454        Ok(target)
455    }
456
457    /// Reshape this shape to the target shape.
458    pub fn reshape(&self, target: Shape) -> Result<Shape, ShapeError> {
459        if self.num_elements() != target.num_elements() {
460            return Err(ShapeError::IncompatibleShapes {
461                left: self.clone(),
462                right: target,
463            });
464        }
465        Ok(target)
466    }
467}
468
469/// Compute the output shape for matrix multiplication with broadcasting support.
470///
471/// The last two dimensions are treated as matrices, while preceding dimensions
472/// follow broadcast semantics similar to elementwise operations.
473pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, ShapeError> {
474    let rank = lhs.rank();
475    if rank != rhs.rank() {
476        return Err(ShapeError::RankMismatch {
477            left: rank,
478            right: rhs.rank(),
479        });
480    }
481
482    if lhs[rank - 1] != rhs[rank - 2] {
483        return Err(ShapeError::IncompatibleShapes {
484            left: lhs.clone(),
485            right: rhs.clone(),
486        });
487    }
488
489    let mut shape = if rank > 2 {
490        // Broadcast leading dims
491        Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
492    } else {
493        Shape::new([])
494    };
495    shape.extend([lhs[rank - 2], rhs[rank - 1]]);
496
497    Ok(shape)
498}
499
500impl Display for Shape {
501    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
502        self.dims.fmt(f)
503    }
504}
505
506impl FromStr for Shape {
507    type Err = crate::ExpressionError;
508
509    fn from_str(source: &str) -> Result<Self, Self::Err> {
510        let mut s = source.trim();
511
512        const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")];
513
514        for (open, close) in DELIMS {
515            if let Some(p) = s.strip_prefix(open) {
516                if let Some(p) = p.strip_suffix(close) {
517                    s = p.trim();
518                    break;
519                } else {
520                    return Err(crate::ExpressionError::ParseError {
521                        message: "Unbalanced delimiters".to_string(),
522                        source: source.to_string(),
523                    });
524                }
525            }
526        }
527
528        if s.is_empty() {
529            return Ok(Shape::new([]));
530        }
531
532        let dims =
533            s.split(',')
534                .map(|dim_str| {
535                    dim_str.trim().parse::<usize>().map_err(|_| {
536                        crate::ExpressionError::ParseError {
537                            message: "Unable to parse shape".to_string(),
538                            source: source.to_string(),
539                        }
540                    })
541                })
542                .collect::<Result<Vec<usize>, crate::ExpressionError>>()?;
543
544        if dims.is_empty() {
545            unreachable!("Split should have returned at least one element");
546        }
547
548        Ok(Shape { dims })
549    }
550}
551
552impl IntoIterator for Shape {
553    type Item = usize;
554    type IntoIter = alloc::vec::IntoIter<Self::Item>;
555
556    fn into_iter(self) -> Self::IntoIter {
557        self.dims.into_iter()
558    }
559}
560
561impl<Idx> Index<Idx> for Shape
562where
563    Idx: SliceIndex<[usize]>,
564{
565    type Output = Idx::Output;
566
567    fn index(&self, index: Idx) -> &Self::Output {
568        &self.dims[index]
569    }
570}
571
572impl<Idx> IndexMut<Idx> for Shape
573where
574    Idx: SliceIndex<[usize]>,
575{
576    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
577        &mut self.dims[index]
578    }
579}
580
581// Allow `&shape` to behave like a slice `&[usize]` directly
582impl Deref for Shape {
583    type Target = [usize];
584
585    fn deref(&self) -> &Self::Target {
586        &self.dims
587    }
588}
589
590// Allow `&shape` to behave like a mut slice `&mut [usize]` directly
591impl DerefMut for Shape {
592    fn deref_mut(&mut self) -> &mut Self::Target {
593        &mut self.dims
594    }
595}
596
597// Conversion sugar
598impl<const D: usize> From<[usize; D]> for Shape {
599    fn from(dims: [usize; D]) -> Self {
600        Shape::new(dims)
601    }
602}
603
604impl<const D: usize> From<[i64; D]> for Shape {
605    fn from(dims: [i64; D]) -> Self {
606        Shape {
607            dims: dims.into_iter().map(|d| d as usize).collect(),
608        }
609    }
610}
611
612impl<const D: usize> From<[i32; D]> for Shape {
613    fn from(dims: [i32; D]) -> Self {
614        Shape {
615            dims: dims.into_iter().map(|d| d as usize).collect(),
616        }
617    }
618}
619
620impl From<&[usize]> for Shape {
621    fn from(dims: &[usize]) -> Self {
622        Shape { dims: dims.into() }
623    }
624}
625
626impl From<Vec<i64>> for Shape {
627    fn from(shape: Vec<i64>) -> Self {
628        Self {
629            dims: shape.into_iter().map(|d| d as usize).collect(),
630        }
631    }
632}
633
634impl From<Vec<u64>> for Shape {
635    fn from(shape: Vec<u64>) -> Self {
636        Self {
637            dims: shape.into_iter().map(|d| d as usize).collect(),
638        }
639    }
640}
641
642impl From<Vec<usize>> for Shape {
643    fn from(shape: Vec<usize>) -> Self {
644        Self { dims: shape }
645    }
646}
647
648impl From<&Vec<usize>> for Shape {
649    fn from(shape: &Vec<usize>) -> Self {
650        Self {
651            dims: shape.clone(),
652        }
653    }
654}
655
656impl From<Shape> for Vec<usize> {
657    fn from(shape: Shape) -> Self {
658        shape.dims
659    }
660}
661
662#[cfg(test)]
663#[allow(clippy::identity_op, reason = "useful for clarity")]
664mod tests {
665    use super::*;
666    use crate::s;
667    use alloc::string::ToString;
668    use alloc::vec;
669
670    #[test]
671    fn test_shape_to_str() {
672        let shape = Shape::new([2, 3, 4, 5]);
673        assert_eq!(shape.to_string(), "[2, 3, 4, 5]");
674    }
675
676    #[test]
677    fn test_shape_from_str() {
678        assert_eq!(
679            "[2, 3, 4, 5]".parse::<Shape>().unwrap(),
680            Shape::new([2, 3, 4, 5])
681        );
682        assert_eq!(
683            "(2, 3, 4, 5)".parse::<Shape>().unwrap(),
684            Shape::new([2, 3, 4, 5])
685        );
686        assert_eq!(
687            "2, 3, 4, 5".parse::<Shape>().unwrap(),
688            Shape::new([2, 3, 4, 5])
689        );
690
691        assert_eq!("[2]".parse::<Shape>().unwrap(), Shape::new([2]));
692        assert_eq!("(2)".parse::<Shape>().unwrap(), Shape::new([2]));
693        assert_eq!("2".parse::<Shape>().unwrap(), Shape::new([2]));
694
695        assert_eq!("[]".parse::<Shape>().unwrap(), Shape::new([]));
696        assert_eq!("".parse::<Shape>().unwrap(), Shape::new([]));
697
698        assert_eq!(
699            "[".parse::<Shape>(),
700            Err(crate::ExpressionError::ParseError {
701                message: "Unbalanced delimiters".to_string(),
702                source: "[".to_string()
703            })
704        );
705
706        assert_eq!(
707            "[[1]".parse::<Shape>(),
708            Err(crate::ExpressionError::ParseError {
709                message: "Unable to parse shape".to_string(),
710                source: "[[1]".to_string()
711            })
712        );
713        assert_eq!(
714            "[[1]]".parse::<Shape>(),
715            Err(crate::ExpressionError::ParseError {
716                message: "Unable to parse shape".to_string(),
717                source: "[[1]]".to_string()
718            })
719        );
720        assert_eq!(
721            "[1)".parse::<Shape>(),
722            Err(crate::ExpressionError::ParseError {
723                message: "Unbalanced delimiters".to_string(),
724                source: "[1)".to_string()
725            })
726        );
727
728        assert_eq!(
729            "]".parse::<Shape>(),
730            Err(crate::ExpressionError::ParseError {
731                message: "Unable to parse shape".to_string(),
732                source: "]".to_string()
733            })
734        );
735
736        assert_eq!(
737            "[a]".parse::<Shape>(),
738            Err(crate::ExpressionError::ParseError {
739                message: "Unable to parse shape".to_string(),
740                source: "[a]".to_string()
741            })
742        );
743    }
744
745    #[test]
746    fn num_dims_and_rank() {
747        let dims = [2, 3, 4, 5];
748        let shape = Shape::new(dims);
749        assert_eq!(4, shape.num_dims());
750        assert_eq!(4, shape.rank());
751    }
752
753    #[test]
754    fn num_elements() {
755        let dims = [2, 3, 4, 5];
756        let shape = Shape::new(dims);
757        assert_eq!(120, shape.num_elements());
758    }
759
760    #[test]
761    fn test_shape_into_iter() {
762        let dims = [2, 3, 4, 5];
763        let shape = Shape::new(dims);
764
765        assert_eq!(shape.into_iter().sum::<usize>(), 14);
766    }
767
768    #[test]
769    fn test_into_ranges() {
770        let dims = [2, 3, 4, 5];
771        let shape = Shape::new(dims);
772        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
773    }
774
775    #[test]
776    fn test_to_vec() {
777        let dims = [2, 3, 4, 5];
778        let shape = Shape::new(dims);
779        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
780    }
781
782    #[allow(clippy::single_range_in_vec_init)]
783    #[test]
784    fn test_into_slices() {
785        let slices = Shape::new([3]).into_slices(1..4);
786        assert_eq!(slices[0].to_range(3), 1..3);
787
788        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);
789        assert_eq!(slices[0].to_range(3), 1..3);
790        assert_eq!(slices[1].to_range(4), 0..2);
791
792        let slices = Shape::new([3]).into_slices(..-2);
793        assert_eq!(slices[0].to_range(3), 0..1);
794
795        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);
796        assert_eq!(slices[0].to_range(2), 0..2);
797        assert_eq!(slices[1].to_range(3), 1..2);
798
799        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);
800        assert_eq!(slices[0].to_range(2), 0..2);
801        assert_eq!(slices[1].to_range(3), 2..3);
802    }
803
804    #[test]
805    fn test_shape_index() {
806        let shape = Shape::new([2, 3, 4, 5]);
807
808        assert_eq!(shape[0], 2);
809        assert_eq!(shape[1], 3);
810        assert_eq!(shape[2], 4);
811        assert_eq!(shape[3], 5);
812
813        // Works with ranges
814        assert_eq!(shape[1..3], *&[3, 4]);
815        assert_eq!(shape[1..=2], *&[3, 4]);
816        assert_eq!(shape[..], *&[2, 3, 4, 5]);
817    }
818
819    #[test]
820    fn test_shape_slice_methods() {
821        let shape = Shape::new([2, 3, 4, 5]);
822
823        let dim = shape.first();
824        assert_eq!(dim, Some(&2));
825        let dim = shape.last();
826        assert_eq!(dim, Some(&5));
827
828        assert!(!shape.is_empty());
829        let shape = Shape::new([]);
830        assert!(shape.is_empty());
831    }
832
833    #[test]
834    fn test_shape_iter() {
835        let dims = [2, 3, 4, 5];
836        let shape = Shape::new(dims);
837
838        for (d, sd) in dims.iter().zip(shape.iter()) {
839            assert_eq!(d, sd);
840        }
841    }
842
843    #[test]
844    fn test_shape_iter_mut() {
845        let mut shape = Shape::new([2, 3, 4, 5]);
846
847        for d in shape.iter_mut() {
848            *d += 1;
849        }
850
851        assert_eq!(&shape.dims, &[3, 4, 5, 6]);
852    }
853
854    #[test]
855    fn test_shape_as_slice() {
856        let dims = [2, 3, 4, 5];
857        let shape = Shape::new(dims);
858
859        assert_eq!(shape.as_slice(), dims.as_slice());
860
861        // Deref coercion
862        let shape_slice: &[usize] = &shape;
863        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
864    }
865
866    #[test]
867    fn test_shape_as_mut_slice() {
868        let mut dims = [2, 3, 4, 5];
869        let mut shape = Shape::new(dims);
870
871        let shape_mut = shape.as_mut_slice();
872        assert_eq!(shape_mut, dims.as_mut_slice());
873        shape_mut[1] = 6;
874
875        assert_eq!(shape_mut, &[2, 6, 4, 5]);
876
877        let mut shape = Shape::new(dims);
878        let shape = &mut shape[..];
879        shape[1] = 6;
880
881        assert_eq!(shape, shape_mut)
882    }
883
884    #[test]
885    fn test_shape_flatten() {
886        let shape = Shape::new([2, 3, 4, 5]);
887        assert_eq!(shape.num_elements(), 120);
888
889        let shape = shape.flatten();
890        assert_eq!(shape.num_elements(), 120);
891        assert_eq!(&shape.dims, &[120]);
892    }
893
894    #[test]
895    fn test_ravel() {
896        let shape = Shape::new([2, 3, 4, 5]);
897
898        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
899        assert_eq!(
900            shape.ravel_index(&[1, 2, 3, 4]),
901            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
902        );
903    }
904
905    #[test]
906    fn test_shape_insert_remove_push() {
907        let dims = [2, 3, 4, 5];
908        let mut shape = Shape::new(dims);
909        let size = 6;
910        shape.insert(1, size);
911
912        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));
913
914        let removed = shape.remove(1);
915        assert_eq!(removed, size);
916        assert_eq!(shape, Shape::new(dims));
917
918        shape.push(6);
919        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
920    }
921
922    #[test]
923    fn test_shape_swap_permute() {
924        let dims = [2, 3, 4, 5];
925        let shape = Shape::new(dims);
926        let shape = shape.swap(1, 2).unwrap();
927
928        assert_eq!(&shape.dims, &[2, 4, 3, 5]);
929
930        let shape = shape.permute(&[0, 2, 1, 3]).unwrap();
931        assert_eq!(shape, Shape::new(dims));
932    }
933
934    #[test]
935    #[should_panic]
936    fn test_shape_swap_out_of_bounds() {
937        let shape = Shape::new([2, 3, 4, 5]);
938
939        shape.swap(0, 4).unwrap();
940    }
941
942    #[test]
943    #[should_panic]
944    fn test_shape_permute_incomplete() {
945        let shape = Shape::new([2, 3, 4, 5]);
946
947        shape.permute(&[0, 2, 1]).unwrap();
948    }
949
950    #[test]
951    fn test_shape_repeat() {
952        let shape = Shape::new([2, 3, 4, 5]);
953
954        let out = shape.repeat(2, 3).unwrap();
955        assert_eq!(out, Shape::new([2, 3, 12, 5]));
956    }
957
958    #[test]
959    fn test_shape_repeat_invalid() {
960        let shape = Shape::new([2, 3, 4, 5]);
961
962        let out = shape.repeat(5, 3);
963        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
964    }
965
966    #[test]
967    fn test_shape_reduce() {
968        let shape = Shape::new([2, 3, 4, 5]);
969
970        let out = shape.reduce(2).unwrap();
971        assert_eq!(out, Shape::new([2, 3, 1, 5]));
972    }
973
974    #[test]
975    fn test_shape_reduce_invalid() {
976        let shape = Shape::new([2, 3, 4, 5]);
977
978        let out = shape.reduce(5);
979        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 }));
980    }
981
982    #[test]
983    fn test_shape_broadcast_binary() {
984        let lhs = Shape::new([1, 1, 2, 4]);
985        let rhs = Shape::new([7, 6, 2, 1]);
986
987        let out = lhs.broadcast(&rhs).unwrap();
988        assert_eq!(out, Shape::new([7, 6, 2, 4]));
989    }
990
991    #[test]
992    fn test_shape_broadcast_rank_mismatch() {
993        let lhs = Shape::new([1, 2, 4]);
994        let rhs = Shape::new([7, 6, 2, 4]);
995
996        let out = lhs.broadcast(&rhs);
997        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
998    }
999
1000    #[test]
1001    fn test_shape_broadcast_incompatible_dims() {
1002        let lhs = Shape::new([1, 2, 2, 4]);
1003        let rhs = Shape::new([7, 6, 2, 1]);
1004
1005        let out = lhs.broadcast(&rhs);
1006        assert_eq!(
1007            out,
1008            Err(ShapeError::IncompatibleDims {
1009                left: 2,
1010                right: 6,
1011                dim: 1
1012            })
1013        );
1014    }
1015
1016    #[test]
1017    fn test_shape_broadcast_many() {
1018        let s1 = Shape::new([1, 1, 2, 4]);
1019        let s2 = Shape::new([7, 1, 2, 1]);
1020        let s3 = Shape::new([7, 6, 1, 1]);
1021
1022        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
1023        assert_eq!(out, Shape::new([7, 6, 2, 4]));
1024    }
1025
1026    #[test]
1027    fn test_shape_broadcast_many_rank_mismatch() {
1028        let s1 = Shape::new([1, 1, 2, 4]);
1029        let s2 = Shape::new([7, 1, 2, 1]);
1030        let s3 = Shape::new([1, 6, 1]);
1031
1032        let out = Shape::broadcast_many([&s1, &s2, &s3]);
1033        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 }));
1034    }
1035
1036    #[test]
1037    fn test_shape_broadcast_many_incompatible_dims() {
1038        let s1 = Shape::new([1, 1, 2, 4]);
1039        let s2 = Shape::new([7, 1, 2, 1]);
1040        let s3 = Shape::new([4, 6, 1, 1]);
1041
1042        let out = Shape::broadcast_many([&s1, &s2, &s3]);
1043        assert_eq!(
1044            out,
1045            Err(ShapeError::IncompatibleDims {
1046                left: 7,
1047                right: 4,
1048                dim: 0
1049            })
1050        );
1051    }
1052
1053    #[test]
1054    fn test_shape_broadcast_many_empty() {
1055        let out = Shape::broadcast_many(&[]);
1056        assert_eq!(out, Err(ShapeError::Empty));
1057    }
1058
1059    #[test]
1060    fn test_shape_matmul_2d() {
1061        let lhs = Shape::new([2, 4]);
1062        let rhs = Shape::new([4, 2]);
1063        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
1064        assert_eq!(out, Shape::new([2, 2]));
1065    }
1066
1067    #[test]
1068    fn test_shape_matmul_4d_broadcasted() {
1069        let lhs = Shape::new([1, 3, 2, 4]);
1070        let rhs = Shape::new([2, 1, 4, 2]);
1071        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
1072        assert_eq!(out, Shape::new([2, 3, 2, 2]));
1073    }
1074
1075    #[test]
1076    fn test_shape_matmul_invalid_rank() {
1077        let lhs = Shape::new([3, 2, 4]);
1078        let rhs = Shape::new([2, 1, 4, 2]);
1079        let out = calculate_matmul_output(&lhs, &rhs);
1080        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 }));
1081    }
1082
1083    #[test]
1084    fn test_shape_matmul_invalid_shape() {
1085        let lhs = Shape::new([1, 3, 2, 4]);
1086        let rhs = Shape::new([2, 1, 3, 2]);
1087        let out = calculate_matmul_output(&lhs, &rhs);
1088        assert_eq!(
1089            out,
1090            Err(ShapeError::IncompatibleShapes {
1091                left: lhs,
1092                right: rhs
1093            })
1094        );
1095    }
1096
1097    #[test]
1098    fn test_shape_matmul_invalid_broadcast() {
1099        let lhs = Shape::new([1, 3, 2, 4]);
1100        let rhs = Shape::new([2, 2, 4, 2]);
1101        let out = calculate_matmul_output(&lhs, &rhs);
1102        assert_eq!(
1103            out,
1104            Err(ShapeError::IncompatibleDims {
1105                left: 3,
1106                right: 2,
1107                dim: 1
1108            })
1109        );
1110    }
1111
1112    #[test]
1113    fn test_shape_cat() {
1114        let s1 = Shape::new([2, 3, 4, 5]);
1115        let s2 = Shape::new([1, 3, 4, 5]);
1116        let s3 = Shape::new([4, 3, 4, 5]);
1117
1118        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
1119        assert_eq!(out, Shape::new([7, 3, 4, 5]));
1120
1121        let s1 = Shape::new([2, 3, 4, 5]);
1122        let s2 = Shape::new([2, 3, 2, 5]);
1123        let s3 = Shape::new([2, 3, 1, 5]);
1124
1125        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
1126        assert_eq!(out, Shape::new([2, 3, 7, 5]));
1127    }
1128
1129    #[test]
1130    fn test_shape_cat_empty() {
1131        let out = Shape::cat(&[], 0);
1132        assert_eq!(out, Err(ShapeError::Empty));
1133    }
1134
1135    #[test]
1136    fn test_shape_cat_dim_out_of_bounds() {
1137        let s1 = Shape::new([2, 3, 4, 5]);
1138        let s2 = Shape::new([2, 3, 4, 5]);
1139        let out = Shape::cat(&[s1, s2], 4);
1140        assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 }));
1141    }
1142
1143    #[test]
1144    fn test_shape_cat_rank_mismatch() {
1145        let s1 = Shape::new([2, 3, 4, 5]);
1146        let s2 = Shape::new([2, 3, 4, 5, 6]);
1147        let out = Shape::cat(&[s1, s2], 0);
1148        assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 }));
1149    }
1150
1151    #[test]
1152    fn test_shape_cat_incompatible_shapes() {
1153        let s1 = Shape::new([2, 3, 4, 5]);
1154        let s2 = Shape::new([1, 3, 4, 5]);
1155        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);
1156
1157        assert_eq!(
1158            out,
1159            Err(ShapeError::IncompatibleShapes {
1160                left: s1,
1161                right: s2
1162            })
1163        );
1164    }
1165
1166    #[test]
1167    fn test_shape_slice_output_shape_basic() {
1168        // Test basic slicing with step=1
1169        let slices = [
1170            Slice::new(0, Some(5), 1), // 5 elements
1171            Slice::new(2, Some(8), 1), // 6 elements
1172        ];
1173        let original_shape = Shape::new([10, 10, 10]);
1174        let result = original_shape.slice(&slices).unwrap();
1175        assert_eq!(result, Shape::new([5, 6, 10]));
1176    }
1177
1178    #[test]
1179    fn test_shape_slice_output_shape_with_positive_steps() {
1180        // Test slicing with various positive steps
1181        let slices = [
1182            Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements
1183            Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements
1184            Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements
1185        ];
1186        let original_shape = Shape::new([20, 20, 20, 30]);
1187        let result = original_shape.slice(&slices).unwrap();
1188        assert_eq!(result, Shape::new([5, 3, 2, 30]));
1189    }
1190
1191    #[test]
1192    fn test_shape_slice_output_shape_with_negative_steps() {
1193        // Test slicing with negative steps (backward iteration)
1194        let slices = [
1195            Slice::new(0, Some(10), -1), // 10 elements traversed backward
1196            Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements
1197        ];
1198        let original_shape = Shape::new([20, 20, 20]);
1199        let result = original_shape.slice(&slices).unwrap();
1200        assert_eq!(result, Shape::new([10, 3, 20]));
1201    }
1202
1203    #[test]
1204    fn test_shape_slice_output_shape_mixed_steps() {
1205        // Test with a mix of positive, negative, and unit steps
1206        let slices = [
1207            Slice::from_range_stepped(1..6, 1),   // 5 elements
1208            Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements
1209            Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements
1210        ];
1211        let original_shape = Shape::new([20, 20, 20]);
1212        let result = original_shape.slice(&slices).unwrap();
1213        assert_eq!(result, Shape::new([5, 4, 3]));
1214    }
1215
1216    #[test]
1217    fn test_shape_slice_output_shape_partial_dims() {
1218        // Test when slices has fewer dimensions than original shape
1219        let slices = [
1220            Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements
1221        ];
1222        let original_shape = Shape::new([10, 20, 30, 40]);
1223        let result = original_shape.slice(&slices).unwrap();
1224        assert_eq!(result, Shape::new([3, 20, 30, 40]));
1225    }
1226
1227    #[test]
1228    fn test_shape_slice_output_shape_edge_cases() {
1229        // Test edge cases with small ranges and large steps
1230        let slices = [
1231            Slice::from_range_stepped(0..1, 1),    // Single element
1232            Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element
1233            Slice::from_range_stepped(5..5, 1),    // Empty range -> 0 elements
1234        ];
1235        let original_shape = Shape::new([10, 20, 30]);
1236        let result = original_shape.slice(&slices).unwrap();
1237        assert_eq!(result, Shape::new([1, 1, 0]));
1238    }
1239
1240    #[test]
1241    fn test_shape_slice_output_shape_empty() {
1242        // Test with no slice infos (should return original shape)
1243        let slices = [];
1244        let original_shape = Shape::new([10, 20, 30]);
1245        let result = original_shape.slice(&slices).unwrap();
1246        assert_eq!(result, Shape::new([10, 20, 30]));
1247    }
1248
1249    #[test]
1250    fn test_shape_slice_output_shape_uneven_division() {
1251        // Test cases where range size doesn't divide evenly by step
1252        let slices = [
1253            Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]
1254            Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]
1255            Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]
1256        ];
1257        let original_shape = Shape::new([20, 20, 20]);
1258        let result = original_shape.slice(&slices).unwrap();
1259        assert_eq!(result, Shape::new([3, 3, 2]));
1260    }
1261
1262    #[test]
1263    fn test_shape_expand() {
1264        let shape = Shape::new([1, 3, 1]);
1265        let expanded = Shape::new([2, 3, 4]);
1266        let out = shape.expand(expanded.clone()).unwrap();
1267        assert_eq!(out, expanded);
1268    }
1269
1270    #[test]
1271    fn test_shape_expand_higher_rank() {
1272        let shape = Shape::new([1, 4]);
1273        let expanded = Shape::new([2, 3, 4]);
1274        let out = shape.expand(expanded.clone()).unwrap();
1275        assert_eq!(out, expanded);
1276    }
1277
1278    #[test]
1279    fn test_shape_expand_invalid_rank() {
1280        let shape = Shape::new([1, 3, 1]);
1281        let expanded = Shape::new([3, 4]);
1282        let out = shape.expand(expanded);
1283        assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 2 }));
1284    }
1285
1286    #[test]
1287    fn test_shape_expand_incompatible_dims() {
1288        let shape = Shape::new([1, 3, 2]);
1289        let expanded = Shape::new([2, 3, 4]);
1290        let out = shape.expand(expanded);
1291        assert_eq!(
1292            out,
1293            Err(ShapeError::IncompatibleDims {
1294                left: 2,
1295                right: 4,
1296                dim: 2
1297            })
1298        );
1299    }
1300
1301    #[test]
1302    fn test_shape_reshape() {
1303        let shape = Shape::new([2, 3, 4, 5]);
1304        let reshaped = Shape::new([1, 2, 12, 5]);
1305        let out = shape.reshape(reshaped.clone()).unwrap();
1306        assert_eq!(out, reshaped);
1307    }
1308
1309    #[test]
1310    fn test_shape_reshape_invalid() {
1311        let shape = Shape::new([2, 3, 4, 5]);
1312        let reshaped = Shape::new([2, 2, 12, 5]);
1313        let out = shape.clone().reshape(reshaped.clone());
1314        assert_eq!(
1315            out,
1316            Err(ShapeError::IncompatibleShapes {
1317                left: shape,
1318                right: reshaped
1319            })
1320        );
1321    }
1322
1323    #[test]
1324    fn test_flatten_dims() {
1325        let shape = Shape::new([2, 3, 4, 5]);
1326        let flattened = shape.flatten_dims(-2, 3);
1327        assert_eq!(flattened, Shape::new([2, 3, 20]));
1328    }
1329}