acme_tensor/shape/
stride.rs

/*
    Appellation: stride <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use super::{dim, Axis, Rank};
6use core::borrow::{Borrow, BorrowMut};
7use core::ops::{Deref, DerefMut, Index, IndexMut};
8use core::slice::{Iter as SliceIter, IterMut as SliceIterMut};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12pub trait IntoStride {
13    fn into_stride(self) -> Stride;
14}
15
16impl<S> IntoStride for S
17where
18    S: Into<Stride>,
19{
20    fn into_stride(self) -> Stride {
21        self.into()
22    }
23}
24
25pub enum Strides {
26    Contiguous,
27    Fortran,
28    Stride(Stride),
29}
30
/// An owned, growable list of per-axis stride values.
///
/// Equality, ordering and hashing are derived from the single wrapped `Vec<usize>`, so two
/// strides compare exactly as their element sequences do. With the `serde` feature enabled
/// the type additionally derives `Serialize` and `Deserialize`.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)]
pub struct Stride(pub(crate) Vec<usize>);
34
35impl Stride {
36    pub fn new(stride: Vec<usize>) -> Self {
37        Self(stride)
38    }
39
40    pub fn with_capacity(capacity: usize) -> Self {
41        Self(Vec::with_capacity(capacity))
42    }
43
44    pub fn zeros(rank: Rank) -> Self {
45        Self(vec![0; *rank])
46    }
47    /// Returns a reference to the stride.
48    pub fn as_slice(&self) -> &[usize] {
49        &self.0
50    }
51    /// Returns a mutable reference to the stride.
52    pub fn as_slice_mut(&mut self) -> &mut [usize] {
53        &mut self.0
54    }
55    /// Returns the capacity of the stride.
56    pub fn capacity(&self) -> usize {
57        self.0.capacity()
58    }
59    /// Clears the stride, removing all elements.
60    pub fn clear(&mut self) {
61        self.0.clear()
62    }
63    /// Gets the element at the specified axis, returning None if the axis is out of bounds.
64    pub fn get(&self, axis: Axis) -> Option<&usize> {
65        self.0.get(*axis)
66    }
67    /// Returns an iterator over references to the elements of the stride.
68    pub fn iter(&self) -> SliceIter<usize> {
69        self.0.iter()
70    }
71    /// Returns an iterator over mutable references to the elements of the stride.
72    pub fn iter_mut(&mut self) -> SliceIterMut<usize> {
73        self.0.iter_mut()
74    }
75    /// Returns the rank of the stride; i.e., the number of dimensions.
76    pub fn rank(&self) -> Rank {
77        self.0.len().into()
78    }
79    /// Removes and returns the stride of the axis.
80    pub fn remove(&mut self, axis: Axis) -> usize {
81        self.0.remove(*axis)
82    }
83    /// Returns a new stride with the axis removed.
84    pub fn remove_axis(&self, axis: Axis) -> Self {
85        let mut stride = self.clone();
86        stride.remove(axis);
87        stride
88    }
89    /// Reverses the stride.
90    pub fn reverse(&mut self) {
91        self.0.reverse()
92    }
93    /// Returns a new stride with the elements reversed.
94    pub fn reversed(&self) -> Self {
95        let mut stride = self.clone();
96        stride.reverse();
97        stride
98    }
99    /// Sets the element at the specified axis, returning None if the axis is out of bounds.
100    pub fn set(&mut self, axis: Axis, value: usize) -> Option<usize> {
101        self.0.get_mut(*axis).map(|v| core::mem::replace(v, value))
102    }
103    ///
104    pub fn stride_offset<Idx>(&self, index: &Idx) -> isize
105    where
106        Idx: AsRef<[usize]>,
107    {
108        index
109            .as_ref()
110            .iter()
111            .copied()
112            .zip(self.iter().copied())
113            .fold(0, |acc, (i, s)| acc + dim::stride_offset(i, s))
114    }
115    /// Swaps two elements in the stride, inplace.
116    pub fn swap(&mut self, a: usize, b: usize) {
117        self.0.swap(a, b)
118    }
119    /// Returns a new shape with the two axes swapped.
120    pub fn swap_axes(&self, a: Axis, b: Axis) -> Self {
121        let mut stride = self.clone();
122        stride.swap(a.axis(), b.axis());
123        stride
124    }
125}
126
127// Internal methods
128impl Stride {
129    pub(crate) fn _fastest_varying_stride_order(&self) -> Self {
130        let mut indices = self.clone();
131        for (i, elt) in indices.as_slice_mut().iter_mut().enumerate() {
132            *elt = i;
133        }
134        let strides = self.as_slice();
135        indices
136            .as_slice_mut()
137            .sort_by_key(|&i| (strides[i] as isize).abs());
138        indices
139    }
140}
141
142impl AsRef<[usize]> for Stride {
143    fn as_ref(&self) -> &[usize] {
144        &self.0
145    }
146}
147
148impl AsMut<[usize]> for Stride {
149    fn as_mut(&mut self) -> &mut [usize] {
150        &mut self.0
151    }
152}
153
154impl Borrow<[usize]> for Stride {
155    fn borrow(&self) -> &[usize] {
156        &self.0
157    }
158}
159
160impl BorrowMut<[usize]> for Stride {
161    fn borrow_mut(&mut self) -> &mut [usize] {
162        &mut self.0
163    }
164}
165
166impl Deref for Stride {
167    type Target = [usize];
168
169    fn deref(&self) -> &Self::Target {
170        &self.0
171    }
172}
173
174impl DerefMut for Stride {
175    fn deref_mut(&mut self) -> &mut Self::Target {
176        &mut self.0
177    }
178}
179
180impl Extend<usize> for Stride {
181    fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
182        self.0.extend(iter)
183    }
184}
185
186impl FromIterator<usize> for Stride {
187    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
188        Stride(Vec::from_iter(iter))
189    }
190}
191
192impl Index<usize> for Stride {
193    type Output = usize;
194
195    fn index(&self, index: usize) -> &Self::Output {
196        &self.0[index]
197    }
198}
199
200impl IndexMut<usize> for Stride {
201    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
202        &mut self.0[index]
203    }
204}
205
206impl Index<Axis> for Stride {
207    type Output = usize;
208
209    fn index(&self, index: Axis) -> &Self::Output {
210        &self.0[*index]
211    }
212}
213
214impl IndexMut<Axis> for Stride {
215    fn index_mut(&mut self, index: Axis) -> &mut Self::Output {
216        &mut self.0[*index]
217    }
218}
219
220impl IntoIterator for Stride {
221    type Item = usize;
222    type IntoIter = std::vec::IntoIter<Self::Item>;
223
224    fn into_iter(self) -> Self::IntoIter {
225        self.0.into_iter()
226    }
227}
228
229impl From<Vec<usize>> for Stride {
230    fn from(v: Vec<usize>) -> Self {
231        Stride(v)
232    }
233}
234
235impl From<&[usize]> for Stride {
236    fn from(v: &[usize]) -> Self {
237        Stride(v.to_vec())
238    }
239}