// rstsr_common/layout/stride.rs

1use crate::prelude_dev::*;
2use core::ops::{Deref, DerefMut};
3
4#[derive(Debug, Clone, PartialEq)]
5pub struct Stride<D>(pub D::Stride)
6where
7    D: DimBaseAPI;
8
9impl<D> Deref for Stride<D>
10where
11    D: DimBaseAPI,
12{
13    type Target = D::Stride;
14
15    fn deref(&self) -> &Self::Target {
16        &self.0
17    }
18}
19
20impl<D> DerefMut for Stride<D>
21where
22    D: DimBaseAPI,
23{
24    fn deref_mut(&mut self) -> &mut Self::Target {
25        &mut self.0
26    }
27}
28
29pub trait DimStrideAPI: DimBaseAPI {
30    /// Number of dimensions of the shape.
31    fn ndim(stride: &Stride<Self>) -> usize;
32}
33
34impl<D> Stride<D>
35where
36    D: DimStrideAPI,
37{
38    pub fn ndim(&self) -> usize {
39        <D as DimStrideAPI>::ndim(self)
40    }
41}
42
43impl<const N: usize> DimStrideAPI for Ix<N> {
44    fn ndim(stride: &Stride<Ix<N>>) -> usize {
45        stride.len()
46    }
47}
48
49impl DimStrideAPI for IxD {
50    fn ndim(stride: &Stride<IxD>) -> usize {
51        stride.len()
52    }
53}
54
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ndim() {
        // fixed-size dimension types, including the zero-dimensional case
        assert_eq!(Stride::<Ix2>([2, 3]).ndim(), 2);
        assert_eq!(Stride::<Ix0>([]).ndim(), 0);

        // dynamic dimension type, including an empty stride
        assert_eq!(Stride::<IxD>(vec![2, 3]).ndim(), 2);
        assert_eq!(Stride::<IxD>(vec![]).ndim(), 0);
    }
}