redstone_ml/common/methods.rs
1use crate::iterator::collapse_contiguous::collapse_to_uniform_stride;
2use crate::ndarray::flags::NdArrayFlags;
3
4#[allow(clippy::len_without_is_empty)]
5pub trait StridedMemory: Sized {
6 /// Returns the dimensions of the ndarray along each axis.
7 ///
8 /// ```
9 /// # use redstone_ml::*;
10 ///
11 /// let a = NdArray::new([3, 4, 5]);
12 /// assert_eq!(a.shape(), &[3]);
13 ///
14 /// let b = NdArray::new([[3], [5]]);
15 /// assert_eq!(b.shape(), &[2, 1]);
16 ///
17 /// let c = NdArray::scalar(0);
18 /// assert_eq!(c.shape(), &[]);
19 /// ```
20 fn shape(&self) -> &[usize];
21
22 /// Returns the stride of the ndarray.
23 ///
24 /// The stride represents the distance in memory between elements in an ndarray along each axis.
25 ///
26 /// ```
27 /// # use redstone_ml::*;
28 ///
29 /// let a = NdArray::new([[3, 4], [5, 6]]);
30 /// assert_eq!(a.stride(), &[2, 1]);
31 /// ```
32 fn stride(&self) -> &[usize];
33
34 /// Returns the number of dimensions in the ndarray.
35 ///
36 /// ```
37 /// # use redstone_ml::*;
38 /// let a = NdArray::new([3, 4, 5]);
39 /// assert_eq!(a.ndims(), 1);
40 ///
41 /// let b = NdArray::new([[3], [5]]);
42 /// assert_eq!(b.ndims(), 2);
43 ///
44 /// let c = NdArray::scalar(0);
45 /// assert_eq!(c.ndims(), 0);
46 /// ```
47 fn ndims(&self) -> usize {
48 self.shape().len()
49 }
50
51 /// Returns the length along the first dimension of the ndarray.
52 /// If the ndarray is a scalar, this returns 0.
53 ///
54 /// # Examples
55 ///
56 /// ```
57 /// # use redstone_ml::*;
58 /// let a = NdArray::new([3, 4, 5]);
59 /// assert_eq!(a.len(), 3);
60 ///
61 /// let b = NdArray::new([[3], [5]]);
62 /// assert_eq!(b.len(), 2);
63 ///
64 /// let c = NdArray::scalar(0);
65 /// assert_eq!(c.len(), 0);
66 /// ```
67 #[inline]
68 fn len(&self) -> usize {
69 if self.shape().is_empty() {
70 return 0;
71 }
72
73 self.shape()[0]
74 }
75
76 /// Returns the total number of elements in the ndarray.
77 ///
78 /// ```
79 /// # use redstone_ml::*;
80 /// let a = NdArray::new([3, 4, 5]);
81 /// assert_eq!(a.size(), 3);
82 ///
83 /// let b = NdArray::new([[3], [5]]);
84 /// assert_eq!(b.size(), 2);
85 ///
86 /// let c = NdArray::scalar(0);
87 /// assert_eq!(c.size(), 1);
88 /// ```
89 #[inline]
90 fn size(&self) -> usize {
91 self.shape().iter().product()
92 }
93
94 /// Returns flags containing information about various ndarray metadata.
95 fn flags(&self) -> NdArrayFlags;
96
97 /// Returns whether this ndarray is stored contiguously in memory.
98 ///
99 /// ```
100 /// # use redstone_ml::*;
101 /// let a = NdArray::new([[3, 4], [5, 6]]);
102 /// assert!(a.is_contiguous());
103 ///
104 /// let b = a.slice_along(Axis(1), 0);
105 /// assert!(!b.is_contiguous());
106 /// ```
107 #[inline]
108 fn is_contiguous(&self) -> bool {
109 self.flags().contains(NdArrayFlags::Contiguous)
110 }
111
112 /// Returns whether this ndarray is slice of another ndarray.
113 ///
114 /// ```
115 /// # use redstone_ml::*;
116 /// let a = NdArray::new([[3, 4], [5, 6]]);
117 /// assert!(!a.is_view());
118 ///
119 /// let b = a.slice_along(Axis(1), 0);
120 /// assert!(b.is_view());
121 /// ```
122 #[inline]
123 fn is_view(&self) -> bool {
124 !self.flags().contains(NdArrayFlags::Owned)
125 }
126
127 /// Whether the elements of this ndarray are stored in memory with a uniform distance between them.
128 ///
129 /// Contiguous arrays are always uniformly strided. Views may sometimes be uniformly strided.
130 ///
131 /// ```
132 /// # use redstone_ml::*;
133 /// let a = NdArray::new([[3, 4, 5], [6, 7, 8]]);
134 /// assert!(a.is_uniformly_strided());
135 ///
136 /// let b = a.slice_along(Axis(1), 0);
137 /// assert!(b.is_uniformly_strided());
138 ///
139 /// let c = a.slice_along(Axis(1), ..2);
140 /// assert!(!c.is_uniformly_strided());
141 /// ```
142 #[inline]
143 fn is_uniformly_strided(&self) -> bool {
144 self.flags().contains(NdArrayFlags::UniformStride)
145 }
146
147 /// If the elements of this ndarray are stored in memory with a uniform distance between them,
148 /// returns this distance.
149 ///
150 /// Contiguous arrays always have a uniform stride of 1.
151 /// NdArray views may sometimes be uniformly strided.
152 ///
153 /// ```
154 /// # use redstone_ml::*;
155 /// let a = NdArray::new([[3, 4, 5], [6, 7, 8]]);
156 /// assert_eq!(a.has_uniform_stride(), Some(1));
157 ///
158 /// let b = a.slice_along(Axis(1), 0);
159 /// assert_eq!(b.has_uniform_stride(), Some(3));
160 ///
161 /// let c = a.slice_along(Axis(1), ..2);
162 /// assert_eq!(c.has_uniform_stride(), None);
163 /// ```
164 #[inline]
165 fn has_uniform_stride(&self) -> Option<usize> {
166 if !self.is_uniformly_strided() {
167 return None;
168 }
169
170 if self.ndims() == 0 {
171 return Some(0);
172 }
173
174 let (_, new_stride) = collapse_to_uniform_stride(self.shape(), self.stride());
175 Some(new_stride[0])
176 }
177}