Skip to main content

cubecl_zspace/
metadata.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{MetadataError, shape::Shape, strides::Strides};
4
5#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
6pub struct Metadata {
7    pub shape: Shape,
8    pub strides: Strides,
9}
10
11impl Metadata {
12    pub fn new(shape: impl Into<Shape>, strides: impl Into<Strides>) -> Self {
13        let shape = shape.into();
14        let strides = strides.into();
15        debug_assert_eq!(
16            shape.rank(),
17            strides.rank(),
18            "Rank of shape and strides must be the same"
19        );
20
21        Self { shape, strides }
22    }
23
24    pub fn shape(&self) -> &Shape {
25        &self.shape
26    }
27
28    pub fn shape_mut(&mut self) -> &mut Shape {
29        &mut self.shape
30    }
31
32    pub fn strides(&self) -> &Strides {
33        &self.strides
34    }
35
36    pub fn strides_mut(&mut self) -> &mut Strides {
37        &mut self.strides
38    }
39
40    pub fn rank(&self) -> usize {
41        self.num_dims()
42    }
43
44    pub fn num_dims(&self) -> usize {
45        self.shape.num_dims()
46    }
47
48    /// Returns the total number of elements of a tensor having this shape
49    pub fn num_elements(&self) -> usize {
50        self.shape.num_elements()
51    }
52
53    pub fn swapped(mut self, dim0: usize, dim1: usize) -> Self {
54        self.swap(dim0, dim1);
55        self
56    }
57
58    pub fn swap(&mut self, dim0: usize, dim1: usize) {
59        debug_assert!(dim0 < self.rank(), "dim0 is out of bounds");
60        debug_assert!(dim1 < self.rank(), "dim1 is out of bounds");
61        self.shape.swap(dim0, dim1);
62        self.strides.swap(dim0, dim1);
63    }
64
65    /// Reorder the shape dimensions according to the permutation of `axes`.
66    pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> {
67        self.shape.permute(axes)?;
68        self.strides.permute(axes)?;
69
70        Ok(())
71    }
72
73    pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> {
74        self.permute(axes)?;
75        Ok(self)
76    }
77
78    /// Insert a dimension of `shape` with `stride` at position `index`.
79    pub fn insert(&mut self, index: usize, shape: usize, stride: usize) {
80        self.shape.insert(index, shape);
81        self.strides.insert(index, stride);
82    }
83
84    /// Remove and return the dimension at position `index` from the metadata.
85    pub fn remove(&mut self, index: usize) -> (usize, usize) {
86        let shape = self.shape.remove(index);
87        let stride = self.strides.remove(index);
88        (shape, stride)
89    }
90
91    /// Appends a dimension of `shape` with `stride` to the back of the metadata.
92    pub fn push(&mut self, shape: usize, stride: usize) {
93        self.shape.push(shape);
94        self.strides.push(stride);
95    }
96}