acme_tensor/shape/
stride.rs1use super::{dim, Axis, Rank};
6use core::borrow::{Borrow, BorrowMut};
7use core::ops::{Deref, DerefMut, Index, IndexMut};
8use core::slice::{Iter as SliceIter, IterMut as SliceIterMut};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12pub trait IntoStride {
13 fn into_stride(self) -> Stride;
14}
15
16impl<S> IntoStride for S
17where
18 S: Into<Stride>,
19{
20 fn into_stride(self) -> Stride {
21 self.into()
22 }
23}
24
25pub enum Strides {
26 Contiguous,
27 Fortran,
28 Stride(Stride),
29}
30
/// An owned list of `usize` stride values, stored in a `Vec<usize>`.
///
/// Equality, ordering and hashing are derived and therefore follow those of
/// the wrapped vector. The inner vector is reachable directly only from
/// within the crate.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Stride(pub(crate) Vec<usize>);
34
35impl Stride {
36 pub fn new(stride: Vec<usize>) -> Self {
37 Self(stride)
38 }
39
40 pub fn with_capacity(capacity: usize) -> Self {
41 Self(Vec::with_capacity(capacity))
42 }
43
44 pub fn zeros(rank: Rank) -> Self {
45 Self(vec![0; *rank])
46 }
47 pub fn as_slice(&self) -> &[usize] {
49 &self.0
50 }
51 pub fn as_slice_mut(&mut self) -> &mut [usize] {
53 &mut self.0
54 }
55 pub fn capacity(&self) -> usize {
57 self.0.capacity()
58 }
59 pub fn clear(&mut self) {
61 self.0.clear()
62 }
63 pub fn get(&self, axis: Axis) -> Option<&usize> {
65 self.0.get(*axis)
66 }
67 pub fn iter(&self) -> SliceIter<usize> {
69 self.0.iter()
70 }
71 pub fn iter_mut(&mut self) -> SliceIterMut<usize> {
73 self.0.iter_mut()
74 }
75 pub fn rank(&self) -> Rank {
77 self.0.len().into()
78 }
79 pub fn remove(&mut self, axis: Axis) -> usize {
81 self.0.remove(*axis)
82 }
83 pub fn remove_axis(&self, axis: Axis) -> Self {
85 let mut stride = self.clone();
86 stride.remove(axis);
87 stride
88 }
89 pub fn reverse(&mut self) {
91 self.0.reverse()
92 }
93 pub fn reversed(&self) -> Self {
95 let mut stride = self.clone();
96 stride.reverse();
97 stride
98 }
99 pub fn set(&mut self, axis: Axis, value: usize) -> Option<usize> {
101 self.0.get_mut(*axis).map(|v| core::mem::replace(v, value))
102 }
103 pub fn stride_offset<Idx>(&self, index: &Idx) -> isize
105 where
106 Idx: AsRef<[usize]>,
107 {
108 index
109 .as_ref()
110 .iter()
111 .copied()
112 .zip(self.iter().copied())
113 .fold(0, |acc, (i, s)| acc + dim::stride_offset(i, s))
114 }
115 pub fn swap(&mut self, a: usize, b: usize) {
117 self.0.swap(a, b)
118 }
119 pub fn swap_axes(&self, a: Axis, b: Axis) -> Self {
121 let mut stride = self.clone();
122 stride.swap(a.axis(), b.axis());
123 stride
124 }
125}
126
127impl Stride {
129 pub(crate) fn _fastest_varying_stride_order(&self) -> Self {
130 let mut indices = self.clone();
131 for (i, elt) in indices.as_slice_mut().iter_mut().enumerate() {
132 *elt = i;
133 }
134 let strides = self.as_slice();
135 indices
136 .as_slice_mut()
137 .sort_by_key(|&i| (strides[i] as isize).abs());
138 indices
139 }
140}
141
142impl AsRef<[usize]> for Stride {
143 fn as_ref(&self) -> &[usize] {
144 &self.0
145 }
146}
147
148impl AsMut<[usize]> for Stride {
149 fn as_mut(&mut self) -> &mut [usize] {
150 &mut self.0
151 }
152}
153
154impl Borrow<[usize]> for Stride {
155 fn borrow(&self) -> &[usize] {
156 &self.0
157 }
158}
159
160impl BorrowMut<[usize]> for Stride {
161 fn borrow_mut(&mut self) -> &mut [usize] {
162 &mut self.0
163 }
164}
165
166impl Deref for Stride {
167 type Target = [usize];
168
169 fn deref(&self) -> &Self::Target {
170 &self.0
171 }
172}
173
174impl DerefMut for Stride {
175 fn deref_mut(&mut self) -> &mut Self::Target {
176 &mut self.0
177 }
178}
179
180impl Extend<usize> for Stride {
181 fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
182 self.0.extend(iter)
183 }
184}
185
186impl FromIterator<usize> for Stride {
187 fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
188 Stride(Vec::from_iter(iter))
189 }
190}
191
192impl Index<usize> for Stride {
193 type Output = usize;
194
195 fn index(&self, index: usize) -> &Self::Output {
196 &self.0[index]
197 }
198}
199
200impl IndexMut<usize> for Stride {
201 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
202 &mut self.0[index]
203 }
204}
205
206impl Index<Axis> for Stride {
207 type Output = usize;
208
209 fn index(&self, index: Axis) -> &Self::Output {
210 &self.0[*index]
211 }
212}
213
214impl IndexMut<Axis> for Stride {
215 fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
216 &mut self.0[*index]
217 }
218}
219
220impl IntoIterator for Stride {
221 type Item = usize;
222 type IntoIter = std::vec::IntoIter<Self::Item>;
223
224 fn into_iter(self) -> Self::IntoIter {
225 self.0.into_iter()
226 }
227}
228
229impl From<Vec<usize>> for Stride {
230 fn from(v: Vec<usize>) -> Self {
231 Stride(v)
232 }
233}
234
235impl From<&[usize]> for Stride {
236 fn from(v: &[usize]) -> Self {
237 Stride(v.to_vec())
238 }
239}