Skip to main content

cubecl_zspace/
shape.rs

1//! Tensor shape definition.
2
3use super::indexing::ravel_index;
4use alloc::format;
5use alloc::string::{String, ToString};
6use alloc::vec::Vec;
7use core::fmt::{Debug, Display, Formatter};
8use core::str::FromStr;
9use core::{
10    ops::{Deref, DerefMut, Index, IndexMut, Range},
11    slice::{Iter, IterMut, SliceIndex},
12};
13use serde::{Deserialize, Serialize};
14use smallvec::{SmallVec, smallvec};
15
16pub use crate::errors::ExpressionError;
17use crate::{
18    INLINE_DIMS,
19    indexing::{AsIndex, AsSize},
20};
21
22/// Shape of a tensor.
23#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
24pub struct Shape {
25    /// The dimensions of the tensor.
26    dims: SmallVec<[usize; INLINE_DIMS]>,
27}
28
29#[allow(missing_docs)]
30#[derive(Debug, Clone, PartialEq, Eq)]
31/// Error that can occur when attempting to modify shapes.
32pub enum MetadataError {
33    /// The operands have different ranks.
34    RankMismatch { left: usize, right: usize },
35    /// A pair of dimensions are incompatible for broadcasting.
36    IncompatibleDims {
37        left: usize,
38        right: usize,
39        dim: usize,
40    },
41    /// Invalid dimension specified for the rank.
42    OutOfBounds { dim: usize, rank: usize },
43    /// A pair of shapes are incompatible for the operation.
44    IncompatibleShapes { left: Shape, right: Shape },
45    /// Invalid shape.
46    Invalid { reason: String },
47}
48
49impl MetadataError {
50    fn empty() -> Self {
51        Self::Invalid {
52            reason: "Shape is empty.".into(),
53        }
54    }
55}
56
57impl Shape {
58    /// Constructs a new `Shape`.
59    pub fn new<const D: usize>(dims: [usize; D]) -> Self {
60        // For backward compat
61        Self {
62            dims: SmallVec::from_slice(&dims),
63        }
64    }
65
66    /// Constructs a new `Shape` from raw backing storage. Mainly intended for macro use.
67    pub fn new_raw(dims: SmallVec<[usize; INLINE_DIMS]>) -> Self {
68        Self { dims }
69    }
70
71    /// Returns the total number of elements of a tensor having this shape
72    pub fn num_elements(&self) -> usize {
73        self.dims.iter().product()
74    }
75
76    /// Returns the number of dimensions.
77    ///
78    /// Alias for `Shape::rank()`.
79    pub fn num_dims(&self) -> usize {
80        self.dims.len()
81    }
82
83    /// Returns the rank (the number of dimensions).
84    ///
85    /// Alias for `Shape::num_dims()`.
86    pub fn rank(&self) -> usize {
87        self.num_dims()
88    }
89
90    // For compat with dims: [usize; D]
91    /// Returns the dimensions of the tensor as an array.
92    pub fn dims<const D: usize>(&self) -> [usize; D] {
93        let mut dims = [1; D];
94        dims[..D].copy_from_slice(&self.dims[..D]);
95        dims
96    }
97
98    /// Change the shape to one dimensional with the same number of elements.
99    pub fn flatten(mut self) -> Self {
100        self.dims = SmallVec::from_slice(&[self.num_elements()]);
101        self
102    }
103
104    /// Flatten the shape along a given range of dimensions.
105    ///
106    /// This function collapses the specified range of dimensions into a single dimension,
107    /// effectively flattening the tensor in that range.
108    ///
109    /// # Arguments
110    ///
111    /// - `start_dim`: The starting dimension of the range to be flattened,
112    ///   supports negative indexing.
113    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),
114    ///   supports negative indexing.
115    ///
116    /// # Returns
117    ///
118    /// A new `Shape` instance with the specified range of dimensions flattened.
119    ///
120    /// # Example
121    ///
122    /// ```rust
123    /// use cubecl_zspace::Shape;
124    ///
125    /// fn example() {
126    ///     let shape = Shape::new([2, 3, 4]);
127    ///
128    ///     let flattened = shape.flatten_dims(1, 2);
129    ///     println!("{flattened}");
130    ///     // [2, 12]
131    /// }
132    /// ```
133    pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self {
134        let rank = self.rank();
135        let start = start_dim.expect_dim_index(rank);
136        let end = end_dim.expect_dim_index(rank);
137
138        assert!(
139            start <= end,
140            "start_dim ({start}) must be <= than end_dim ({end})"
141        );
142
143        let existing = self.dims;
144
145        let flattened_size = existing[start..=end].iter().product();
146
147        let new_rank = rank - (end - start);
148        let mut dims = smallvec![0; new_rank];
149        dims[..start].copy_from_slice(&existing[..start]);
150        dims[start] = flattened_size;
151        dims[start + 1..].copy_from_slice(&existing[end + 1..]);
152
153        Self { dims }
154    }
155
156    /// Compute the ravel index for the given coordinates.
157    ///
158    /// This returns the row-major order raveling:
159    /// * `strides[-1] = 1`
160    /// * `strides[i] = strides[i+1] * dims[i+1]`
161    /// * `dim_strides = coords * strides`
162    /// * `ravel = sum(dim_strides)`
163    ///
164    /// # Arguments
165    /// - `indices`: the index for each dimension; must be the same length as `shape`.
166    ///
167    /// # Returns
168    /// - the ravel offset index.
169    pub fn ravel_index<I: AsIndex>(&self, indices: &[I]) -> usize {
170        ravel_index(indices, &self.dims)
171    }
172
173    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.
174    pub fn into_ranges(self) -> Vec<Range<usize>> {
175        self.iter().map(|&d| 0..d).collect()
176    }
177
178    /// Construct a vector of the dims.
179    pub fn to_vec(&self) -> Vec<usize> {
180        self.dims.to_vec()
181    }
182
183    /// Returns an iterator over the shape dimensions.
184    pub fn iter(&self) -> Iter<'_, usize> {
185        self.dims.iter()
186    }
187
188    /// Mutable iterator over the dimensions.
189    pub fn iter_mut(&mut self) -> IterMut<'_, usize> {
190        self.dims.iter_mut()
191    }
192
193    /// Borrow the underlying dimensions slice.
194    pub fn as_slice(&self) -> &[usize] {
195        &self.dims
196    }
197
198    /// Borrow the underlying dimensions slice mutably.
199    pub fn as_mut_slice(&mut self) -> &mut [usize] {
200        &mut self.dims
201    }
202
203    /// Insert a dimension of `size` at position `index`.
204    pub fn insert(&mut self, index: usize, size: usize) {
205        self.dims.insert(index, size);
206    }
207
208    /// Remove and return the dimension at position `index` from the shape.
209    pub fn remove(&mut self, index: usize) -> usize {
210        self.dims.remove(index)
211    }
212
213    /// Appends a dimension of `size` to the back of the shape.
214    pub fn push(&mut self, size: usize) {
215        self.dims.push(size)
216    }
217
218    /// Extend the shape with the content of another shape or iterator.
219    pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
220        self.dims.extend(iter)
221    }
222
223    /// Swap two dimensions in the shape.
224    pub fn swapped(mut self, dim1: usize, dim2: usize) -> Result<Self, MetadataError> {
225        if dim1 >= self.rank() {
226            return Err(MetadataError::OutOfBounds {
227                dim: dim1,
228                rank: self.rank(),
229            });
230        }
231        if dim2 >= self.rank() {
232            return Err(MetadataError::OutOfBounds {
233                dim: dim2,
234                rank: self.rank(),
235            });
236        }
237        self.dims.swap(dim1, dim2);
238        Ok(self)
239    }
240
241    /// Reorder the shape dimensions according to the permutation of `axes`.
242    pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> {
243        if axes.len() != self.rank() {
244            return Err(MetadataError::RankMismatch {
245                left: self.rank(),
246                right: axes.len(),
247            });
248        }
249        debug_assert!(axes.iter().all(|i| i < &self.rank()));
250
251        self.dims = axes.iter().map(|&i| self.dims[i]).collect();
252        Ok(())
253    }
254
255    /// Reorder the shape dimensions according to the permutation of `axes`.
256    pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> {
257        self.permute(axes)?;
258        Ok(self)
259    }
260
261    /// Repeated the specified `dim` a number of `times`.
262    pub fn repeat(mut self, dim: usize, times: usize) -> Result<Shape, MetadataError> {
263        if dim >= self.rank() {
264            return Err(MetadataError::OutOfBounds {
265                dim,
266                rank: self.rank(),
267            });
268        }
269
270        self.dims[dim] *= times;
271        Ok(self)
272    }
273
274    /// Returns a new shape where the specified `dim` is reduced to size 1.
275    pub fn reduce(mut self, dim: usize) -> Result<Shape, MetadataError> {
276        if dim >= self.rank() {
277            return Err(MetadataError::OutOfBounds {
278                dim,
279                rank: self.rank(),
280            });
281        }
282
283        self.dims[dim] = 1;
284        Ok(self)
285    }
286
287    /// Concatenates all shapes into a new one along the given dimension.
288    pub fn cat<'a, I>(shapes: I, dim: usize) -> Result<Self, MetadataError>
289    where
290        I: IntoIterator<Item = &'a Shape>,
291    {
292        let mut iter = shapes.into_iter();
293
294        let first = iter.next().ok_or(MetadataError::empty())?;
295
296        if dim >= first.rank() {
297            return Err(MetadataError::OutOfBounds {
298                dim,
299                rank: first.rank(),
300            });
301        }
302
303        let mut shape = first.clone();
304
305        for s in iter {
306            if s.rank() != shape.rank() {
307                return Err(MetadataError::RankMismatch {
308                    left: shape.rank(),
309                    right: s.rank(),
310                });
311            }
312
313            if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] {
314                return Err(MetadataError::IncompatibleShapes {
315                    left: shape.clone(),
316                    right: s.clone(),
317                });
318            }
319
320            shape[dim] += s[dim];
321        }
322
323        Ok(shape)
324    }
325
326    /// Compute the output shape for binary operations with broadcasting support.
327    ///
328    /// - Shapes must be of the same rank (missing dimensions are not handled automatically).
329    /// - Two dimensions are compatible if they are equal, or one of them is 1.
330    ///
331    /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]`
332    /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]`
333    /// can *not* be broadcast into `[2, 4]`.
334    pub fn broadcast(&self, other: &Self) -> Result<Self, MetadataError> {
335        Self::broadcast_many([self, other])
336    }
337
338    /// Compute the broadcasted output shape across multiple input shapes.
339    ///
340    /// See also [broadcast](Self::broadcast).
341    pub fn broadcast_many<'a, I>(shapes: I) -> Result<Self, MetadataError>
342    where
343        I: IntoIterator<Item = &'a Shape>,
344    {
345        let mut iter = shapes.into_iter();
346        let mut broadcasted = iter.next().ok_or(MetadataError::empty())?.clone();
347        let rank = broadcasted.rank();
348
349        for shape in iter {
350            if shape.rank() != rank {
351                return Err(MetadataError::RankMismatch {
352                    left: rank,
353                    right: shape.rank(),
354                });
355            }
356
357            for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() {
358                match (*d_lhs, d_rhs) {
359                    (a, b) if a == b => {} // same
360                    (1, b) => *d_lhs = b,  // broadcast to rhs
361                    (_a, 1) => {}          // keep existing dimension
362                    _ => {
363                        return Err(MetadataError::IncompatibleDims {
364                            left: *d_lhs,
365                            right: d_rhs,
366                            dim,
367                        });
368                    }
369                }
370            }
371        }
372
373        Ok(broadcasted)
374    }
375
376    /// Expand this shape to match the target shape, following broadcasting rules.
377    pub fn expand(&self, target: Shape) -> Result<Shape, MetadataError> {
378        let target_rank = target.rank();
379        if self.rank() > target_rank {
380            return Err(MetadataError::RankMismatch {
381                left: self.rank(),
382                right: target_rank,
383            });
384        }
385
386        for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() {
387            if dim_self != dim_target && *dim_self != 1 {
388                return Err(MetadataError::IncompatibleDims {
389                    left: *dim_self,
390                    right: *dim_target,
391                    dim: target_rank - i - 1,
392                });
393            }
394        }
395
396        Ok(target)
397    }
398
399    /// Reshape this shape to the target shape.
400    pub fn reshape<A, T>(&self, args: A) -> Result<Shape, MetadataError>
401    where
402        A: AsRef<[T]> + Debug,
403        T: AsIndex,
404    {
405        let args = args.as_ref();
406        let mut infer_index = None;
407        let mut dims = Vec::new();
408
409        let mut new_size = 1;
410
411        for (idx, &s) in args.iter().enumerate() {
412            let s = s.as_index();
413            if s > 0 {
414                let s = s as usize;
415                new_size *= s;
416                dims.push(s);
417            } else if s == 0 {
418                // We need to find the index of the 0 dimensions and
419                // replace them with the actual dimension value.
420                let s = self.dims[idx];
421                new_size *= s;
422                dims.push(s);
423            } else if s == -1 {
424                match infer_index {
425                    None => {
426                        infer_index = Some(idx);
427                        // Used by / Replaced by handling later.
428                        dims.push(1);
429                    }
430                    Some(_) => {
431                        return Err(MetadataError::Invalid {
432                            reason: "Repeated -1 in reshape".to_string(),
433                        });
434                    }
435                }
436            } else {
437                return Err(MetadataError::Invalid {
438                    reason: "The given shape cannot contain negative dimensions (other than -1)."
439                        .to_string(),
440                });
441            }
442        }
443
444        let source_size = self.num_elements();
445        match infer_index {
446            None => {
447                if source_size != new_size {
448                    return Err(MetadataError::Invalid {
449                        reason: format!(
450                            "The given shape doesn't have the same number of elements as the current shape. Current shape: {self}, target shape: {dims:?}.",
451                        ),
452                    });
453                }
454            }
455            Some(idx) => {
456                if !source_size.is_multiple_of(new_size) {
457                    return Err(MetadataError::Invalid {
458                        reason: format!(
459                            "Cannot infer a valid target shape. Current shape: {self}, target dimensions: {args:?}."
460                        ),
461                    });
462                }
463                dims[idx] = source_size / new_size;
464            }
465        }
466
467        Ok(dims.into())
468    }
469}
470
/// Builds a [`Shape`] with `vec!`-like syntax.
///
/// - `shape![]` creates an empty (rank-0) shape.
/// - `shape![elem; n]` creates a shape of `n` dimensions, each equal to `elem`.
/// - `shape![a, b, c]` creates a shape from the listed dimensions (a trailing
///   comma is accepted).
///
/// All forms forward to `Shape::new_raw`, with storage built through the
/// crate's re-exported `SmallVec`/`smallvec!`.
#[macro_export]
macro_rules! shape {
    // NOTE(review): this arm expands to `1usize` and is not used by the other
    // arms in this file; presumably a counting helper kept for compatibility.
    (@one $x:expr) => (1usize);
    () => (
        $crate::Shape::new_raw($crate::SmallVec::new())
    );
    ($elem:expr; $n:expr) => ({
        $crate::Shape::new_raw($crate::smallvec!($elem; $n))
    });
    ($($x:expr),+$(,)?) => ({
        $crate::Shape::new_raw($crate::smallvec!($($x),*))
    });
}
484
485/// Compute the output shape for matrix multiplication with broadcasting support.
486///
487/// The last two dimensions are treated as matrices, while preceding dimensions
488/// follow broadcast semantics similar to elementwise operations.
489pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result<Shape, MetadataError> {
490    let rank = lhs.rank();
491    if rank != rhs.rank() {
492        return Err(MetadataError::RankMismatch {
493            left: rank,
494            right: rhs.rank(),
495        });
496    }
497
498    if lhs[rank - 1] != rhs[rank - 2] {
499        return Err(MetadataError::IncompatibleShapes {
500            left: lhs.clone(),
501            right: rhs.clone(),
502        });
503    }
504
505    let mut shape = if rank > 2 {
506        // Broadcast leading dims
507        Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))?
508    } else {
509        Shape::new([])
510    };
511    shape.extend([lhs[rank - 2], rhs[rank - 1]]);
512
513    Ok(shape)
514}
515
516impl Display for Shape {
517    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
518        self.dims.fmt(f)
519    }
520}
521
522impl FromStr for Shape {
523    type Err = ExpressionError;
524
525    fn from_str(source: &str) -> Result<Self, Self::Err> {
526        let mut s = source.trim();
527
528        const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")];
529
530        for (open, close) in DELIMS {
531            if let Some(p) = s.strip_prefix(open) {
532                if let Some(p) = p.strip_suffix(close) {
533                    s = p.trim();
534                    break;
535                } else {
536                    return Err(ExpressionError::ParseError {
537                        message: "Unbalanced delimiters".to_string(),
538                        source: source.to_string(),
539                    });
540                }
541            }
542        }
543
544        if s.is_empty() {
545            return Ok(Shape::new([]));
546        }
547
548        let dims = s
549            .split(',')
550            .map(|dim_str| {
551                dim_str
552                    .trim()
553                    .parse::<usize>()
554                    .map_err(|_| ExpressionError::ParseError {
555                        message: "Unable to parse shape".to_string(),
556                        source: source.to_string(),
557                    })
558            })
559            .collect::<Result<SmallVec<_>, ExpressionError>>()?;
560
561        if dims.is_empty() {
562            unreachable!("Split should have returned at least one element");
563        }
564
565        Ok(Shape { dims })
566    }
567}
568
569impl<Idx> Index<Idx> for Shape
570where
571    Idx: SliceIndex<[usize]>,
572{
573    type Output = Idx::Output;
574
575    fn index(&self, index: Idx) -> &Self::Output {
576        &self.dims[index]
577    }
578}
579
580impl<Idx> IndexMut<Idx> for Shape
581where
582    Idx: SliceIndex<[usize]>,
583{
584    fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
585        &mut self.dims[index]
586    }
587}
588
589// Allow `&shape` to behave like a slice `&[usize]` directly
590impl Deref for Shape {
591    type Target = [usize];
592
593    fn deref(&self) -> &Self::Target {
594        &self.dims
595    }
596}
597
598// Allow `&shape` to behave like a mut slice `&mut [usize]` directly
599impl DerefMut for Shape {
600    fn deref_mut(&mut self) -> &mut Self::Target {
601        &mut self.dims
602    }
603}
604// Allow `shape.reshape(other_shape)`.
605//
606// By implementing `AsRef<[usize]>`, `Shape` behaves like a slice of dimensions,
607// similar to how `Vec<T>` can be passed to functions expecting a slice.
608impl AsRef<[usize]> for Shape {
609    fn as_ref(&self) -> &[usize] {
610        &self.dims
611    }
612}
613
614impl From<Shape> for Vec<usize> {
615    fn from(shape: Shape) -> Self {
616        shape.dims.to_vec()
617    }
618}
619
620impl<T, I> From<T> for Shape
621where
622    T: IntoIterator<Item = I>,
623    I: AsSize,
624{
625    fn from(dims: T) -> Self {
626        Shape {
627            dims: dims.into_iter().map(|d| d.as_size()).collect(),
628        }
629    }
630}
631
632impl From<&Shape> for Shape {
633    fn from(value: &Shape) -> Self {
634        value.clone()
635    }
636}
637
638impl<I: AsSize> FromIterator<I> for Shape {
639    fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
640        Shape {
641            dims: iter.into_iter().map(|it| it.as_size()).collect(),
642        }
643    }
644}
645
// Unit tests for `Shape`: formatting/parsing, slice-like access, shape
// manipulation (flatten, swap, permute, repeat, reduce), broadcasting and
// `calculate_matmul_output`.
#[cfg(test)]
// `identity_op` is allowed so expected values such as `1 * (3 * 4 * 5)` in
// `test_ravel` can spell out the row-major stride arithmetic explicitly.
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    // `Display` prints the dimensions as a bracketed list.
    #[test]
    fn test_shape_to_str() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.to_string(), "[2, 3, 4, 5]");
    }

    // `FromStr` accepts `[...]`, `(...)` or bare comma-separated dimensions and
    // distinguishes unbalanced delimiters from unparsable entries.
    #[test]
    fn test_shape_from_str() {
        assert_eq!(
            "[2, 3, 4, 5]".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );
        assert_eq!(
            "(2, 3, 4, 5)".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );
        assert_eq!(
            "2, 3, 4, 5".parse::<Shape>().unwrap(),
            Shape::new([2, 3, 4, 5])
        );

        assert_eq!("[2]".parse::<Shape>().unwrap(), Shape::new([2]));
        assert_eq!("(2)".parse::<Shape>().unwrap(), Shape::new([2]));
        assert_eq!("2".parse::<Shape>().unwrap(), Shape::new([2]));

        assert_eq!("[]".parse::<Shape>().unwrap(), Shape::new([]));
        assert_eq!("".parse::<Shape>().unwrap(), Shape::new([]));

        assert_eq!(
            "[".parse::<Shape>(),
            Err(ExpressionError::ParseError {
                message: "Unbalanced delimiters".to_string(),
                source: "[".to_string()
            })
        );

        // Only one pair of delimiters is stripped; nested brackets fail to parse.
        assert_eq!(
            "[[1]".parse::<Shape>(),
            Err(ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "[[1]".to_string()
            })
        );
        assert_eq!(
            "[[1]]".parse::<Shape>(),
            Err(ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "[[1]]".to_string()
            })
        );
        assert_eq!(
            "[1)".parse::<Shape>(),
            Err(ExpressionError::ParseError {
                message: "Unbalanced delimiters".to_string(),
                source: "[1)".to_string()
            })
        );

        // A lone closing delimiter is not recognised as a delimiter at all.
        assert_eq!(
            "]".parse::<Shape>(),
            Err(ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "]".to_string()
            })
        );

        assert_eq!(
            "[a]".parse::<Shape>(),
            Err(ExpressionError::ParseError {
                message: "Unable to parse shape".to_string(),
                source: "[a]".to_string()
            })
        );
    }

    #[test]
    fn num_dims_and_rank() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(4, shape.num_dims());
        assert_eq!(4, shape.rank());
    }

    #[test]
    fn num_elements() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(120, shape.num_elements());
    }

    // The lint allowance indicates `into_iter()` is called on a reference here
    // (reached through `Deref` to the dimension slice).
    #[test]
    #[allow(clippy::into_iter_on_ref)]
    fn test_shape_into_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.into_iter().sum::<usize>(), 14);
    }

    #[test]
    fn test_into_ranges() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);
    }

    #[test]
    fn test_to_vec() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]);
    }

    // `Index` supports both single positions and ranges.
    #[test]
    fn test_shape_index() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape[0], 2);
        assert_eq!(shape[1], 3);
        assert_eq!(shape[2], 4);
        assert_eq!(shape[3], 5);

        // Works with ranges
        assert_eq!(shape[1..3], *&[3, 4]);
        assert_eq!(shape[1..=2], *&[3, 4]);
        assert_eq!(shape[..], *&[2, 3, 4, 5]);
    }

    // Slice methods (`first`, `last`, `is_empty`) are available through `Deref`.
    #[test]
    fn test_shape_slice_methods() {
        let shape = Shape::new([2, 3, 4, 5]);

        let dim = shape.first();
        assert_eq!(dim, Some(&2));
        let dim = shape.last();
        assert_eq!(dim, Some(&5));

        assert!(!shape.is_empty());
        let shape = Shape::new([]);
        assert!(shape.is_empty());
    }

    #[test]
    fn test_shape_iter() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        for (d, sd) in dims.iter().zip(shape.iter()) {
            assert_eq!(d, sd);
        }
    }

    #[test]
    fn test_shape_iter_mut() {
        let mut shape = Shape::new([2, 3, 4, 5]);

        for d in shape.iter_mut() {
            *d += 1;
        }

        assert_eq!(shape.as_slice(), &[3, 4, 5, 6]);
    }

    #[test]
    fn test_shape_as_slice() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);

        assert_eq!(shape.as_slice(), dims.as_slice());

        // Deref coercion
        let shape_slice: &[usize] = &shape;
        assert_eq!(shape_slice, *&[2, 3, 4, 5]);
    }

    // Mutation via `as_mut_slice` and via `IndexMut` with a full range
    // produce the same result.
    #[test]
    fn test_shape_as_mut_slice() {
        let mut dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);

        let shape_mut = shape.as_mut_slice();
        assert_eq!(shape_mut, dims.as_mut_slice());
        shape_mut[1] = 6;

        assert_eq!(shape_mut, &[2, 6, 4, 5]);

        let mut shape = Shape::new(dims);
        let shape = &mut shape[..];
        shape[1] = 6;

        assert_eq!(shape, shape_mut)
    }

    #[test]
    fn test_shape_flatten() {
        let shape = Shape::new([2, 3, 4, 5]);
        assert_eq!(shape.num_elements(), 120);

        let shape = shape.flatten();
        assert_eq!(shape.num_elements(), 120);
        assert_eq!(shape.as_slice(), &[120]);
    }

    // Row-major ravel: each coordinate is multiplied by the product of the
    // dimensions after it.
    #[test]
    fn test_ravel() {
        let shape = Shape::new([2, 3, 4, 5]);

        assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0);
        assert_eq!(
            shape.ravel_index(&[1, 2, 3, 4]),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    fn test_shape_insert_remove_push() {
        let dims = [2, 3, 4, 5];
        let mut shape = Shape::new(dims);
        let size = 6;
        shape.insert(1, size);

        assert_eq!(shape, Shape::new([2, 6, 3, 4, 5]));

        let removed = shape.remove(1);
        assert_eq!(removed, size);
        assert_eq!(shape, Shape::new(dims));

        shape.push(6);
        assert_eq!(shape, Shape::new([2, 3, 4, 5, 6]));
    }

    // Swapping dims 1 and 2, then permuting them back, restores the original.
    #[test]
    fn test_shape_swap_permute() {
        let dims = [2, 3, 4, 5];
        let shape = Shape::new(dims);
        let shape = shape.swapped(1, 2).unwrap();

        assert_eq!(shape.as_slice(), &[2, 4, 3, 5]);

        let shape = shape.permuted(&[0, 2, 1, 3]).unwrap();
        assert_eq!(shape, Shape::new(dims));
    }

    // The panic comes from `unwrap` on the returned error.
    #[test]
    #[should_panic]
    fn test_shape_swap_out_of_bounds() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.swapped(0, 4).unwrap();
    }

    // Fewer axes than the rank yields an error; `unwrap` turns it into a panic.
    #[test]
    #[should_panic]
    fn test_shape_permute_incomplete() {
        let shape = Shape::new([2, 3, 4, 5]);

        shape.permuted(&[0, 2, 1]).unwrap();
    }

    #[test]
    fn test_shape_repeat() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(2, 3).unwrap();
        assert_eq!(out, Shape::new([2, 3, 12, 5]));
    }

    #[test]
    fn test_shape_repeat_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.repeat(5, 3);
        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 5, rank: 4 }));
    }

    #[test]
    fn test_shape_reduce() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(2).unwrap();
        assert_eq!(out, Shape::new([2, 3, 1, 5]));
    }

    #[test]
    fn test_shape_reduce_invalid() {
        let shape = Shape::new([2, 3, 4, 5]);

        let out = shape.reduce(5);
        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 5, rank: 4 }));
    }

    // Broadcasting: size-1 dimensions stretch to match the other operand.
    #[test]
    fn test_shape_broadcast_binary() {
        let lhs = Shape::new([1, 1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_rank_mismatch() {
        let lhs = Shape::new([1, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 4]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 4 }));
    }

    #[test]
    fn test_shape_broadcast_incompatible_dims() {
        let lhs = Shape::new([1, 2, 2, 4]);
        let rhs = Shape::new([7, 6, 2, 1]);

        let out = lhs.broadcast(&rhs);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 2,
                right: 6,
                dim: 1
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([7, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap();
        assert_eq!(out, Shape::new([7, 6, 2, 4]));
    }

    #[test]
    fn test_shape_broadcast_many_rank_mismatch() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([1, 6, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 3 }));
    }

    // The error reports the already-broadcast left value (7 from `s2`).
    #[test]
    fn test_shape_broadcast_many_incompatible_dims() {
        let s1 = Shape::new([1, 1, 2, 4]);
        let s2 = Shape::new([7, 1, 2, 1]);
        let s3 = Shape::new([4, 6, 1, 1]);

        let out = Shape::broadcast_many([&s1, &s2, &s3]);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 7,
                right: 4,
                dim: 0
            })
        );
    }

    #[test]
    fn test_shape_broadcast_many_empty() {
        let out = Shape::broadcast_many(&[]);
        assert_eq!(out, Err(MetadataError::empty()));
    }

    // Matmul output shape: `[.., m, k] x [.., k, n] -> [.., m, n]`.
    #[test]
    fn test_shape_matmul_2d() {
        let lhs = Shape::new([2, 4]);
        let rhs = Shape::new([4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 2]));
    }

    #[test]
    fn test_shape_matmul_4d_broadcasted() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs).unwrap();
        assert_eq!(out, Shape::new([2, 3, 2, 2]));
    }

    #[test]
    fn test_shape_matmul_invalid_rank() {
        let lhs = Shape::new([3, 2, 4]);
        let rhs = Shape::new([2, 1, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 4 }));
    }

    // Inner dimensions differ (4 vs 3).
    #[test]
    fn test_shape_matmul_invalid_shape() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 1, 3, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleShapes {
                left: lhs,
                right: rhs
            })
        );
    }

    // Batch dimensions 3 and 2 cannot be broadcast together.
    #[test]
    fn test_shape_matmul_invalid_broadcast() {
        let lhs = Shape::new([1, 3, 2, 4]);
        let rhs = Shape::new([2, 2, 4, 2]);
        let out = calculate_matmul_output(&lhs, &rhs);
        assert_eq!(
            out,
            Err(MetadataError::IncompatibleDims {
                left: 3,
                right: 2,
                dim: 1
            })
        );
    }
}
1073
1074    #[test]
1075    fn test_shape_cat() {
1076        let s1 = Shape::new([2, 3, 4, 5]);
1077        let s2 = Shape::new([1, 3, 4, 5]);
1078        let s3 = Shape::new([4, 3, 4, 5]);
1079
1080        let out = Shape::cat(&[s1, s2, s3], 0).unwrap();
1081        assert_eq!(out, Shape::new([7, 3, 4, 5]));
1082
1083        let s1 = Shape::new([2, 3, 4, 5]);
1084        let s2 = Shape::new([2, 3, 2, 5]);
1085        let s3 = Shape::new([2, 3, 1, 5]);
1086
1087        let out = Shape::cat(&[s1, s2, s3], 2).unwrap();
1088        assert_eq!(out, Shape::new([2, 3, 7, 5]));
1089    }
1090
1091    #[test]
1092    fn test_shape_cat_empty() {
1093        let out = Shape::cat(&[], 0);
1094        assert_eq!(out, Err(MetadataError::empty()));
1095    }
1096
1097    #[test]
1098    fn test_shape_cat_dim_out_of_bounds() {
1099        let s1 = Shape::new([2, 3, 4, 5]);
1100        let s2 = Shape::new([2, 3, 4, 5]);
1101        let out = Shape::cat(&[s1, s2], 4);
1102        assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 4, rank: 4 }));
1103    }
1104
1105    #[test]
1106    fn test_shape_cat_rank_mismatch() {
1107        let s1 = Shape::new([2, 3, 4, 5]);
1108        let s2 = Shape::new([2, 3, 4, 5, 6]);
1109        let out = Shape::cat(&[s1, s2], 0);
1110        assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 5 }));
1111    }
1112
1113    #[test]
1114    fn test_shape_cat_incompatible_shapes() {
1115        let s1 = Shape::new([2, 3, 4, 5]);
1116        let s2 = Shape::new([1, 3, 4, 5]);
1117        let out = Shape::cat(&[s1.clone(), s2.clone()], 1);
1118
1119        assert_eq!(
1120            out,
1121            Err(MetadataError::IncompatibleShapes {
1122                left: s1,
1123                right: s2
1124            })
1125        );
1126    }
1127
1128    #[test]
1129    fn test_shape_expand() {
1130        let shape = Shape::new([1, 3, 1]);
1131        let expanded = Shape::new([2, 3, 4]);
1132        let out = shape.expand(expanded.clone()).unwrap();
1133        assert_eq!(out, expanded);
1134    }
1135
1136    #[test]
1137    fn test_shape_expand_higher_rank() {
1138        let shape = Shape::new([1, 4]);
1139        let expanded = Shape::new([2, 3, 4]);
1140        let out = shape.expand(expanded.clone()).unwrap();
1141        assert_eq!(out, expanded);
1142    }
1143
1144    #[test]
1145    fn test_shape_expand_invalid_rank() {
1146        let shape = Shape::new([1, 3, 1]);
1147        let expanded = Shape::new([3, 4]);
1148        let out = shape.expand(expanded);
1149        assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 2 }));
1150    }
1151
1152    #[test]
1153    fn test_shape_expand_incompatible_dims() {
1154        let shape = Shape::new([1, 3, 2]);
1155        let expanded = Shape::new([2, 3, 4]);
1156        let out = shape.expand(expanded);
1157        assert_eq!(
1158            out,
1159            Err(MetadataError::IncompatibleDims {
1160                left: 2,
1161                right: 4,
1162                dim: 2
1163            })
1164        );
1165    }
1166
1167    #[test]
1168    fn test_shape_reshape() {
1169        let shape = Shape::new([2, 3, 4, 5]);
1170        let reshaped = Shape::new([1, 2, 12, 5]);
1171        let out = shape.reshape(reshaped.clone()).unwrap();
1172        assert_eq!(out, reshaped);
1173    }
1174
1175    #[test]
1176    fn test_shape_reshape_invalid() {
1177        let shape = Shape::new([2, 3, 4, 5]);
1178        let reshaped = Shape::new([2, 2, 12, 5]);
1179        let out = shape.reshape(reshaped.clone());
1180        assert_eq!(
1181            out,
1182            Err(MetadataError::Invalid {
1183                reason: "The given shape doesn't have the same number of elements as the current shape. Current shape: [2, 3, 4, 5], target shape: [2, 2, 12, 5].".into(),
1184            })
1185        );
1186    }
1187
1188    #[test]
1189    fn test_shape_reshape_invalid_inferred() {
1190        let shape = Shape::new([2, 4]);
1191        let out = shape.reshape([-1, 3]);
1192        assert_eq!(
1193            out,
1194            Err(MetadataError::Invalid {
1195                reason: "Cannot infer a valid target shape. Current shape: [2, 4], target dimensions: [-1, 3].".into(),
1196            })
1197        );
1198    }
1199
1200    #[test]
1201    fn test_flatten_dims() {
1202        let shape = Shape::new([2, 3, 4, 5]);
1203        let flattened = shape.flatten_dims(-2, 3);
1204        assert_eq!(flattened, Shape::new([2, 3, 20]));
1205    }
1206}