redstone_ml/common/
methods.rs

1use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
2use crate::ndarray::flags::NdArrayFlags;
3
4#[allow(clippy::len_without_is_empty)]
5pub trait StridedMemory: Sized {
6    /// Returns the dimensions of the ndarray along each axis.
7    ///
8    /// ```
9    /// # use redstone_ml::*;
10    ///
11    /// let a = NdArray::new([3, 4, 5]);
12    /// assert_eq!(a.shape(), &[3]);
13    ///
14    /// let b = NdArray::new([[3], [5]]);
15    /// assert_eq!(b.shape(), &[2, 1]);
16    ///
17    /// let c = NdArray::scalar(0);
18    /// assert_eq!(c.shape(), &[]);
19    /// ```
20    fn shape(&self) -> &[usize];
21
22    /// Returns the stride of the ndarray.
23    ///
24    /// The stride represents the distance in memory between elements in an ndarray along each axis.
25    ///
26    /// ```
27    /// # use redstone_ml::*;
28    ///
29    /// let a = NdArray::new([[3, 4], [5, 6]]);
30    /// assert_eq!(a.stride(), &[2, 1]);
31    /// ```
32    fn stride(&self) -> &[usize];
33
34    /// Returns the number of dimensions in the ndarray.
35    ///
36    /// ```
37    /// # use redstone_ml::*;
38    /// let a = NdArray::new([3, 4, 5]);
39    /// assert_eq!(a.ndims(), 1);
40    ///
41    /// let b = NdArray::new([[3], [5]]);
42    /// assert_eq!(b.ndims(), 2);
43    ///
44    /// let c = NdArray::scalar(0);
45    /// assert_eq!(c.ndims(), 0);
46    /// ```
47    fn ndims(&self) -> usize {
48        self.shape().len()
49    }
50
51    /// Returns the length along the first dimension of the ndarray.
52    /// If the ndarray is a scalar, this returns 0.
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// # use redstone_ml::*;
58    /// let a = NdArray::new([3, 4, 5]);
59    /// assert_eq!(a.len(), 3);
60    ///
61    /// let b = NdArray::new([[3], [5]]);
62    /// assert_eq!(b.len(), 2);
63    ///
64    /// let c = NdArray::scalar(0);
65    /// assert_eq!(c.len(), 0);
66    /// ```
67    #[inline]
68    fn len(&self) -> usize {
69        if self.shape().is_empty() {
70            return 0;
71        }
72
73        self.shape()[0]
74    }
75
76    /// Returns the total number of elements in the ndarray.
77    ///
78    /// ```
79    /// # use redstone_ml::*;
80    /// let a = NdArray::new([3, 4, 5]);
81    /// assert_eq!(a.size(), 3);
82    ///
83    /// let b = NdArray::new([[3], [5]]);
84    /// assert_eq!(b.size(), 2);
85    ///
86    /// let c = NdArray::scalar(0);
87    /// assert_eq!(c.size(), 1);
88    /// ```
89    #[inline]
90    fn size(&self) -> usize {
91        self.shape().iter().product()
92    }
93
94    /// Returns flags containing information about various ndarray metadata.
95    fn flags(&self) -> NdArrayFlags;
96
97    /// Returns whether this ndarray is stored contiguously in memory.
98    ///
99    /// ```
100    /// # use redstone_ml::*;
101    /// let a = NdArray::new([[3, 4], [5, 6]]);
102    /// assert!(a.is_contiguous());
103    ///
104    /// let b = a.slice_along(Axis(1), 0);
105    /// assert!(!b.is_contiguous());
106    /// ```
107    #[inline]
108    fn is_contiguous(&self) -> bool {
109        self.flags().contains(NdArrayFlags::Contiguous)
110    }
111
112    /// Returns whether this ndarray is slice of another ndarray.
113    ///
114    /// ```
115    /// # use redstone_ml::*;
116    /// let a = NdArray::new([[3, 4], [5, 6]]);
117    /// assert!(!a.is_view());
118    ///
119    /// let b = a.slice_along(Axis(1), 0);
120    /// assert!(b.is_view());
121    /// ```
122    #[inline]
123    fn is_view(&self) -> bool {
124        !self.flags().contains(NdArrayFlags::Owned)
125    }
126
127    /// Whether the elements of this ndarray are stored in memory with a uniform distance between them.
128    ///
129    /// Contiguous arrays are always uniformly strided. Views may sometimes be uniformly strided.
130    ///
131    /// ```
132    /// # use redstone_ml::*;
133    /// let a = NdArray::new([[3, 4, 5], [6, 7, 8]]);
134    /// assert!(a.is_uniformly_strided());
135    ///
136    /// let b = a.slice_along(Axis(1), 0);
137    /// assert!(b.is_uniformly_strided());
138    ///
139    /// let c = a.slice_along(Axis(1), ..2);
140    /// assert!(!c.is_uniformly_strided());
141    /// ```
142    #[inline]
143    fn is_uniformly_strided(&self) -> bool {
144        self.flags().contains(NdArrayFlags::UniformStride)
145    }
146
147    /// If the elements of this ndarray are stored in memory with a uniform distance between them,
148    /// returns this distance.
149    ///
150    /// Contiguous arrays always have a uniform stride of 1.
151    /// NdArray views may sometimes be uniformly strided.
152    ///
153    /// ```
154    /// # use redstone_ml::*;
155    /// let a = NdArray::new([[3, 4, 5], [6, 7, 8]]);
156    /// assert_eq!(a.has_uniform_stride(), Some(1));
157    ///
158    /// let b = a.slice_along(Axis(1), 0);
159    /// assert_eq!(b.has_uniform_stride(), Some(3));
160    ///
161    /// let c = a.slice_along(Axis(1), ..2);
162    /// assert_eq!(c.has_uniform_stride(), None);
163    /// ```
164    #[inline]
165    fn has_uniform_stride(&self) -> Option<usize> {
166        if !self.is_uniformly_strided() {
167            return None;
168        }
169
170        if self.ndims() == 0 {
171            return Some(0);
172        }
173
174        let (_, new_stride) = collapse_to_uniform_stride(self.shape(), self.stride());
175        Some(new_stride[0])
176    }
177}