cubecl_zspace/
metadata.rs1use serde::{Deserialize, Serialize};
2
3use crate::{MetadataError, shape::Shape, strides::Strides};
4
5#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
6pub struct Metadata {
7 pub shape: Shape,
8 pub strides: Strides,
9}
10
11impl Metadata {
12 pub fn new(shape: impl Into<Shape>, strides: impl Into<Strides>) -> Self {
13 let shape = shape.into();
14 let strides = strides.into();
15 debug_assert_eq!(
16 shape.rank(),
17 strides.rank(),
18 "Rank of shape and strides must be the same"
19 );
20
21 Self { shape, strides }
22 }
23
24 pub fn shape(&self) -> &Shape {
25 &self.shape
26 }
27
28 pub fn shape_mut(&mut self) -> &mut Shape {
29 &mut self.shape
30 }
31
32 pub fn strides(&self) -> &Strides {
33 &self.strides
34 }
35
36 pub fn strides_mut(&mut self) -> &mut Strides {
37 &mut self.strides
38 }
39
40 pub fn rank(&self) -> usize {
41 self.num_dims()
42 }
43
44 pub fn num_dims(&self) -> usize {
45 self.shape.num_dims()
46 }
47
48 pub fn num_elements(&self) -> usize {
50 self.shape.num_elements()
51 }
52
53 pub fn swapped(mut self, dim0: usize, dim1: usize) -> Self {
54 self.swap(dim0, dim1);
55 self
56 }
57
58 pub fn swap(&mut self, dim0: usize, dim1: usize) {
59 debug_assert!(dim0 < self.rank(), "dim0 is out of bounds");
60 debug_assert!(dim1 < self.rank(), "dim1 is out of bounds");
61 self.shape.swap(dim0, dim1);
62 self.strides.swap(dim0, dim1);
63 }
64
65 pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> {
67 self.shape.permute(axes)?;
68 self.strides.permute(axes)?;
69
70 Ok(())
71 }
72
73 pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> {
74 self.permute(axes)?;
75 Ok(self)
76 }
77
78 pub fn insert(&mut self, index: usize, shape: usize, stride: usize) {
80 self.shape.insert(index, shape);
81 self.strides.insert(index, stride);
82 }
83
84 pub fn remove(&mut self, index: usize) -> (usize, usize) {
86 let shape = self.shape.remove(index);
87 let stride = self.strides.remove(index);
88 (shape, stride)
89 }
90
91 pub fn push(&mut self, shape: usize, stride: usize) {
93 self.shape.push(shape);
94 self.strides.push(stride);
95 }
96}