1use core::ops::{Deref, DerefMut};
2
3use serde::{Deserialize, Serialize};
4use smallvec::SmallVec;
5
6use crate::{INLINE_DIMS, MetadataError, indexing::AsSize};
7
/// Per-axis strides: one `usize` entry per dimension, indexed by axis.
///
/// Backed by a `SmallVec` with `INLINE_DIMS` inline slots, so small ranks are stored
/// without a heap allocation. The rank is the number of stored entries.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)]
pub struct Strides {
    /// Stride values; entry `k` belongs to axis `k`.
    dims: SmallVec<[usize; INLINE_DIMS]>,
}
12
13impl Strides {
14 pub fn new(dims: &[usize]) -> Self {
15 Self {
17 dims: SmallVec::from_slice(dims),
18 }
19 }
20
21 pub fn new_raw(dims: SmallVec<[usize; INLINE_DIMS]>) -> Self {
22 Self { dims }
23 }
24
25 pub fn rank(&self) -> usize {
26 self.dims.len()
27 }
28
29 pub fn insert(&mut self, index: usize, stride: usize) {
31 self.dims.insert(index, stride);
32 }
33
34 pub fn remove(&mut self, index: usize) -> usize {
36 self.dims.remove(index)
37 }
38
39 pub fn push(&mut self, stride: usize) {
41 self.dims.push(stride)
42 }
43
44 pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) {
46 self.dims.extend(iter)
47 }
48
49 pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> {
51 if axes.len() != self.rank() {
52 return Err(MetadataError::RankMismatch {
53 left: self.rank(),
54 right: axes.len(),
55 });
56 }
57 debug_assert!(axes.iter().all(|i| i < &self.rank()));
58
59 self.dims = axes.iter().map(|&i| self.dims[i]).collect();
60 Ok(())
61 }
62
63 pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> {
65 self.permute(axes)?;
66 Ok(self)
67 }
68}
69
70impl Deref for Strides {
71 type Target = [usize];
72
73 fn deref(&self) -> &Self::Target {
74 &self.dims
75 }
76}
77
78impl DerefMut for Strides {
79 fn deref_mut(&mut self) -> &mut Self::Target {
80 &mut self.dims
81 }
82}
83
/// Builds a `Strides` value with `vec!`-like syntax.
///
/// - `strides![]` produces empty strides.
/// - `strides![elem; n]` produces `n` copies of `elem`.
/// - `strides![a, b, c]` lists the strides explicitly; a trailing comma is accepted.
///
/// The expansion refers to `$crate::Strides`, `$crate::SmallVec` and `$crate::smallvec!`,
/// so all three must be reachable from this crate's root.
#[macro_export]
macro_rules! strides {
    // NOTE(review): no other arm of this macro invokes `@one`; it looks like a leftover
    // element-counting helper. Confirm nothing external relies on it before removing.
    (@one $x:expr) => (1usize);
    // Empty strides; the element type is fixed by `Strides::new_raw`'s parameter.
    () => (
        $crate::Strides::new_raw($crate::SmallVec::new())
    );
    // Repeated form: `n` copies of `elem`.
    ($elem:expr; $n:expr) => ({
        $crate::Strides::new_raw($crate::smallvec!($elem; $n))
    });
    // Explicit list form, optional trailing comma.
    ($($x:expr),+$(,)?) => ({
        $crate::Strides::new_raw($crate::smallvec!($($x),*))
    });
}
97
98impl<T, I> From<T> for Strides
99where
100 T: IntoIterator<Item = I>,
101 I: AsSize,
102{
103 fn from(dims: T) -> Self {
104 Strides {
105 dims: dims.into_iter().map(|d| d.as_size()).collect(),
106 }
107 }
108}
109
110impl From<&Strides> for Strides {
111 fn from(value: &Strides) -> Self {
112 value.clone()
113 }
114}
115
116impl<I: AsSize> FromIterator<I> for Strides {
117 fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self {
118 Strides {
119 dims: iter.into_iter().map(|it| it.as_size()).collect(),
120 }
121 }
122}