Skip to main content

cubecl_zspace/
strides.rs

1use core::ops::{Deref, DerefMut};
2
3use serde::{Deserialize, Serialize};
4use smallvec::SmallVec;
5
6use crate::{INLINE_DIMS, MetadataError, indexing::AsSize};
7
8#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
9pub struct Strides {
10    dims: SmallVec<[usize; INLINE_DIMS]>,
11}
12
13impl Strides {
14    pub fn new(dims: &[usize]) -> Self {
15        // For backward compat
16        Self {
17            dims: SmallVec::from_slice(dims),
18        }
19    }
20
21    pub fn new_raw(dims: SmallVec<[usize; INLINE_DIMS]>) -> Self {
22        Self { dims }
23    }
24
25    pub fn rank(&self) -> usize {
26        self.dims.len()
27    }
28
29    /// Insert a dimension of `stride` at position `index`.
30    pub fn insert(&mut self, index: usize, stride: usize) {
31        self.dims.insert(index, stride);
32    }
33
34    /// Remove and return the dimension at position `index` from the strides.
35    pub fn remove(&mut self, index: usize) -> usize {
36        self.dims.remove(index)
37    }
38
39    /// Appends a dimension of `stride` to the back of the strides.
40    pub fn push(&mut self, stride: usize) {
41        self.dims.push(stride)
42    }
43
44    /// Extend the strides with the content of another shape or iterator.
45    pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
46        self.dims.extend(iter)
47    }
48
49    /// Reorder the strides dimensions according to the permutation of `axes`.
50    pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> {
51        if axes.len() != self.rank() {
52            return Err(MetadataError::RankMismatch {
53                left: self.rank(),
54                right: axes.len(),
55            });
56        }
57        debug_assert!(axes.iter().all(|i| i < &self.rank()));
58
59        self.dims = axes.iter().map(|&i| self.dims[i]).collect();
60        Ok(())
61    }
62
63    /// Reorder the strides dimensions according to the permutation of `axes`.
64    pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> {
65        self.permute(axes)?;
66        Ok(self)
67    }
68}
69
70impl Deref for Strides {
71    type Target = [usize];
72
73    fn deref(&self) -> &Self::Target {
74        &self.dims
75    }
76}
77
78impl DerefMut for Strides {
79    fn deref_mut(&mut self) -> &mut Self::Target {
80        &mut self.dims
81    }
82}
83
/// Builds a `Strides` value with `vec!`-like syntax.
///
/// Supported forms:
/// - `strides![]` creates empty strides.
/// - `strides![value; count]` forwards `value; count` to `smallvec!`.
/// - `strides![a, b, c]` (trailing comma allowed) lists the entries.
///
/// The `@one` arm is an internal rule that expands to `1usize`; no other arm
/// of this macro refers to it.
#[macro_export]
macro_rules! strides {
    (@one $x:expr) => { 1usize };
    () => {
        $crate::Strides::new_raw($crate::SmallVec::new())
    };
    ($value:expr; $count:expr) => {{
        $crate::Strides::new_raw($crate::smallvec!($value; $count))
    }};
    ($($stride:expr),+$(,)?) => {{
        $crate::Strides::new_raw($crate::smallvec!($($stride),*))
    }};
}
97
98impl<T, I> From<T> for Strides
99where
100    T: IntoIterator<Item = I>,
101    I: AsSize,
102{
103    fn from(dims: T) -> Self {
104        Strides {
105            dims: dims.into_iter().map(|d| d.as_size()).collect(),
106        }
107    }
108}
109
110impl From<&Strides> for Strides {
111    fn from(value: &Strides) -> Self {
112        value.clone()
113    }
114}
115
116impl<I: AsSize> FromIterator<I> for Strides {
117    fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
118        Strides {
119            dims: iter.into_iter().map(|it| it.as_size()).collect(),
120        }
121    }
122}