rstsr_common/layout/
shape.rs

1use crate::prelude_dev::*;
2
3pub trait DimShapeAPI: DimBaseAPI {
4    /// Total number of elements in tensor.
5    ///
6    /// # Note
7    ///
8    /// For 0-dimension tensor, it contains one element.
9    /// For multi-dimension tensor with a dimension that have zero length, it
10    /// contains zero elements.
11    ///
12    /// # Examples
13    ///
14    /// ```
15    /// use rstsr_core::prelude_dev::*;
16    ///
17    /// let shape = [2, 3];
18    /// assert_eq!(shape.shape_size(), 6);
19    ///
20    /// let shape = vec![];
21    /// assert_eq!(shape.shape_size(), 1);
22    /// ```
23    fn shape_size(&self) -> usize;
24
25    /// Stride for a f-contiguous tensor using this shape.
26    ///
27    /// # Examples
28    ///
29    /// ```
30    /// use rstsr_core::prelude_dev::*;
31    ///
32    /// let stride = [2, 3, 5].stride_f_contig();
33    /// assert_eq!(stride, [1, 2, 6]);
34    /// ```
35    fn stride_f_contig(&self) -> Self::Stride;
36
37    /// Stride for a c-contiguous tensor using this shape.
38    ///
39    /// # Examples
40    ///
41    /// ```
42    /// use rstsr_core::prelude_dev::*;
43    ///
44    /// let stride = [2, 3, 5].stride_c_contig();
45    /// assert_eq!(stride, [15, 5, 1]);
46    /// ```
47    fn stride_c_contig(&self) -> Self::Stride;
48
49    /// Stride for contiguous tensor using this shape.
50    ///
51    /// # Cargo feature dependent
52    ///
53    /// Whether c-contiguous or f-contiguous will depends on cargo feature
54    /// `f_prefer`.
55    fn stride_contig(&self) -> Self::Stride;
56
57    /// Index (col-major) of tensor by list of indexes.
58    ///
59    /// # Safety
60    ///
61    /// This function does not check whether index is out of bounds.
62    unsafe fn unravel_index_f(&self, index: usize) -> Self;
63
64    /// Index (row-major) of tensor by list of indexes.
65    ///
66    /// # Safety
67    ///
68    /// This function does not check whether index is out of bounds.
69    unsafe fn unravel_index_c(&self, index: usize) -> Self;
70}
71
72impl<const N: usize> DimShapeAPI for Ix<N> {
73    fn shape_size(&self) -> usize {
74        self.iter().product()
75    }
76
77    fn stride_f_contig(&self) -> [isize; N] {
78        let mut stride = [1; N];
79        for i in 1..N {
80            stride[i] = stride[i - 1] * self[i - 1].max(1) as isize;
81        }
82        stride
83    }
84
85    fn stride_c_contig(&self) -> [isize; N] {
86        let mut stride = [1; N];
87        if N == 0 {
88            return stride;
89        }
90        for i in (0..N - 1).rev() {
91            stride[i] = stride[i + 1] * self[i + 1].max(1) as isize;
92        }
93        stride
94    }
95
96    fn stride_contig(&self) -> [isize; N] {
97        match FlagOrder::default() {
98            RowMajor => Self::stride_c_contig(self),
99            ColMajor => Self::stride_f_contig(self),
100        }
101    }
102
103    #[inline]
104    unsafe fn unravel_index_f(&self, index: usize) -> Self {
105        let mut index = index;
106        let mut result = self.new_shape();
107        match self.ndim() {
108            0 => (),
109            1 => {
110                result[0] = index;
111            },
112            2 => {
113                result[1] = index / self[0];
114                result[0] = index % self[0];
115            },
116            3 => {
117                result[2] = index / (self[0] * self[1]);
118                index %= self[0] * self[1];
119                result[1] = index / self[0];
120                result[0] = index % self[0];
121            },
122            4 => {
123                result[3] = index / (self[0] * self[1] * self[2]);
124                index %= self[0] * self[1] * self[2];
125                result[2] = index / (self[0] * self[1]);
126                index %= self[0] * self[1];
127                result[1] = index / self[0];
128                result[0] = index % self[0];
129            },
130            _ => {
131                for i in 0..(self.ndim() - 1) {
132                    let dim = self[i];
133                    result[i] = index % dim;
134                    index /= dim;
135                }
136                result[self.ndim() - 1] = index;
137            },
138        }
139        return result;
140    }
141
142    #[inline]
143    unsafe fn unravel_index_c(&self, index: usize) -> Self {
144        let mut index = index;
145        let mut result = self.new_shape();
146        match self.ndim() {
147            0 => (),
148            1 => {
149                result[0] = index;
150            },
151            2 => {
152                result[0] = index / self[1];
153                result[1] = index % self[1];
154            },
155            3 => {
156                result[0] = index / (self[1] * self[2]);
157                index %= self[1] * self[2];
158                result[1] = index / self[2];
159                result[2] = index % self[2];
160            },
161            4 => {
162                result[0] = index / (self[1] * self[2] * self[3]);
163                index %= self[1] * self[2] * self[3];
164                result[1] = index / (self[2] * self[3]);
165                index %= self[2] * self[3];
166                result[2] = index / self[3];
167                result[3] = index % self[3];
168            },
169            _ => {
170                for i in (1..self.ndim()).rev() {
171                    let dim = self[i];
172                    result[i] = index % dim;
173                    index /= dim;
174                }
175                result[0] = index;
176            },
177        }
178        return result;
179    }
180}
181
182impl DimShapeAPI for IxD {
183    fn shape_size(&self) -> usize {
184        self.iter().product()
185    }
186
187    fn stride_f_contig(&self) -> Vec<isize> {
188        let mut stride = vec![1; self.len()];
189        for i in 1..self.len() {
190            stride[i] = stride[i - 1] * self[i - 1] as isize;
191        }
192        stride
193    }
194
195    fn stride_c_contig(&self) -> Vec<isize> {
196        let mut stride = vec![1; self.len()];
197        if self.is_empty() {
198            return stride;
199        }
200        for i in (0..self.len() - 1).rev() {
201            stride[i] = stride[i + 1] * self[i + 1] as isize;
202        }
203        stride
204    }
205
206    fn stride_contig(&self) -> Vec<isize> {
207        match FlagOrder::default() {
208            RowMajor => Self::stride_c_contig(self),
209            ColMajor => Self::stride_f_contig(self),
210        }
211    }
212
213    #[inline]
214    unsafe fn unravel_index_f(&self, index: usize) -> Self {
215        let mut index = index;
216        let mut result = self.new_shape();
217        if self.ndim() >= 1 {
218            for i in 0..(self.ndim() - 1) {
219                let dim = self[i];
220                result[i] = index % dim;
221                index /= dim;
222            }
223            result[self.ndim() - 1] = index;
224        }
225        return result;
226    }
227
228    #[inline]
229    unsafe fn unravel_index_c(&self, index: usize) -> Self {
230        let mut index = index;
231        let mut result = self.new_shape();
232        if self.ndim() >= 1 {
233            for i in (1..self.ndim()).rev() {
234                let dim = self[i];
235                result[i] = index % dim;
236                index /= dim;
237            }
238            result[0] = index;
239        }
240        return result;
241    }
242}
243
#[cfg(test)]
mod test {
    use super::*;

    /// Ravel a multi-index back to a flat index with the given strides,
    /// asserting that every component lies within its axis.
    fn ravel(index: &[usize], shape: &[usize], stride: &[isize]) -> isize {
        assert_eq!(index.len(), shape.len());
        index
            .iter()
            .zip(shape)
            .zip(stride)
            .map(|((&i, &d), &s)| {
                assert!(i < d, "component {i} out of range for axis length {d}");
                i as isize * s
            })
            .sum()
    }

    #[test]
    fn test_ndim() {
        // general test
        let shape = [2, 3];
        assert_eq!(shape.ndim(), 2);
        let shape = vec![2, 3];
        assert_eq!(shape.ndim(), 2);
        // empty dimension test
        let shape = [];
        assert_eq!(shape.ndim(), 0);
        let shape = vec![];
        assert_eq!(shape.ndim(), 0);
    }

    #[test]
    fn test_size() {
        // general test
        let shape = [2, 3];
        assert_eq!(shape.shape_size(), 6);
        let shape = vec![2, 3];
        assert_eq!(shape.shape_size(), 6);
        // empty dimension test
        let shape = [];
        assert_eq!(shape.shape_size(), 1);
        let shape = vec![];
        assert_eq!(shape.shape_size(), 1);
        // zero element test
        let shape = [1, 2, 0, 4];
        assert_eq!(shape.shape_size(), 0);
        let shape = vec![1, 2, 0, 4];
        assert_eq!(shape.shape_size(), 0);
    }

    #[test]
    fn test_stride_f_contig() {
        // general test
        let stride = [2, 3, 5].stride_f_contig();
        assert_eq!(stride, [1, 2, 6]);
        let stride = vec![2, 3, 5].stride_f_contig();
        assert_eq!(stride, vec![1, 2, 6]);
        // empty dimension test
        let stride = [].stride_f_contig();
        assert_eq!(stride, []);
        let stride = vec![].stride_f_contig();
        assert_eq!(stride, vec![]);
        // zero element test: a zero-length axis counts as length 1
        let stride = [1, 2, 0, 4].stride_f_contig();
        assert_eq!(stride, [1, 1, 2, 2]);
    }

    #[test]
    fn test_stride_c_contig() {
        // general test
        let stride = [2, 3, 5].stride_c_contig();
        assert_eq!(stride, [15, 5, 1]);
        let stride = vec![2, 3, 5].stride_c_contig();
        assert_eq!(stride, vec![15, 5, 1]);
        // empty dimension test
        let stride = [].stride_c_contig();
        assert_eq!(stride, []);
        let stride = vec![].stride_c_contig();
        assert_eq!(stride, vec![]);
        // zero element test: a zero-length axis counts as length 1
        let stride = [1, 2, 0, 4].stride_c_contig();
        assert_eq!(stride, [8, 4, 4, 1]);
    }

    #[test]
    fn test_unravel_index() {
        // Round trip: unravel every flat index, then ravel it back with the
        // matching contiguous strides. Covers the hand-unrolled ranks 0-4 of
        // `Ix<N>` as well as the generic loops.
        macro_rules! check_round_trip {
            ($shape:expr) => {{
                let shape = $shape;
                let stride_c = shape.stride_c_contig();
                let stride_f = shape.stride_f_contig();
                for idx in 0..shape.shape_size() {
                    // SAFETY: `idx < shape_size()`, so it addresses a valid element.
                    let c = unsafe { shape.unravel_index_c(idx) };
                    // SAFETY: as above.
                    let f = unsafe { shape.unravel_index_f(idx) };
                    assert_eq!(ravel(&c, &shape, &stride_c), idx as isize);
                    assert_eq!(ravel(&f, &shape, &stride_f), idx as isize);
                }
            }};
        }

        check_round_trip!([0usize; 0]);
        check_round_trip!([5usize]);
        check_round_trip!([3usize, 4]);
        check_round_trip!([2usize, 3, 4]);
        check_round_trip!([2usize, 3, 4, 5]);
        check_round_trip!([2usize, 3, 4, 5, 6]);
        check_round_trip!(Vec::<usize>::new());
        check_round_trip!(vec![3usize, 4]);
        check_round_trip!(vec![2usize, 3, 4, 5, 6]);

        // hand-checked values: 17 = 1*12 + 1*4 + 1 (c) = 1 + 2*2 + 2*6 (f)
        // SAFETY: 17 < 2 * 3 * 4.
        assert_eq!(unsafe { [2usize, 3, 4].unravel_index_c(17) }, [1, 1, 1]);
        // SAFETY: 17 < 2 * 3 * 4.
        assert_eq!(unsafe { [2usize, 3, 4].unravel_index_f(17) }, [1, 2, 2]);
    }
}