rstsr_common/layout/
shape.rs1use crate::prelude_dev::*;
2
3pub trait DimShapeAPI: DimBaseAPI {
4 fn shape_size(&self) -> usize;
24
25 fn stride_f_contig(&self) -> Self::Stride;
36
37 fn stride_c_contig(&self) -> Self::Stride;
48
49 fn stride_contig(&self) -> Self::Stride;
56
57 unsafe fn unravel_index_f(&self, index: usize) -> Self;
63
64 unsafe fn unravel_index_c(&self, index: usize) -> Self;
70}
71
72impl<const N: usize> DimShapeAPI for Ix<N> {
73 fn shape_size(&self) -> usize {
74 self.iter().product()
75 }
76
77 fn stride_f_contig(&self) -> [isize; N] {
78 let mut stride = [1; N];
79 for i in 1..N {
80 stride[i] = stride[i - 1] * self[i - 1].max(1) as isize;
81 }
82 stride
83 }
84
85 fn stride_c_contig(&self) -> [isize; N] {
86 let mut stride = [1; N];
87 if N == 0 {
88 return stride;
89 }
90 for i in (0..N - 1).rev() {
91 stride[i] = stride[i + 1] * self[i + 1].max(1) as isize;
92 }
93 stride
94 }
95
96 fn stride_contig(&self) -> [isize; N] {
97 match FlagOrder::default() {
98 RowMajor => Self::stride_c_contig(self),
99 ColMajor => Self::stride_f_contig(self),
100 }
101 }
102
103 #[inline]
104 unsafe fn unravel_index_f(&self, index: usize) -> Self {
105 let mut index = index;
106 let mut result = self.new_shape();
107 match self.ndim() {
108 0 => (),
109 1 => {
110 result[0] = index;
111 },
112 2 => {
113 result[1] = index / self[0];
114 result[0] = index % self[0];
115 },
116 3 => {
117 result[2] = index / (self[0] * self[1]);
118 index %= self[0] * self[1];
119 result[1] = index / self[0];
120 result[0] = index % self[0];
121 },
122 4 => {
123 result[3] = index / (self[0] * self[1] * self[2]);
124 index %= self[0] * self[1] * self[2];
125 result[2] = index / (self[0] * self[1]);
126 index %= self[0] * self[1];
127 result[1] = index / self[0];
128 result[0] = index % self[0];
129 },
130 _ => {
131 for i in 0..(self.ndim() - 1) {
132 let dim = self[i];
133 result[i] = index % dim;
134 index /= dim;
135 }
136 result[self.ndim() - 1] = index;
137 },
138 }
139 return result;
140 }
141
142 #[inline]
143 unsafe fn unravel_index_c(&self, index: usize) -> Self {
144 let mut index = index;
145 let mut result = self.new_shape();
146 match self.ndim() {
147 0 => (),
148 1 => {
149 result[0] = index;
150 },
151 2 => {
152 result[0] = index / self[1];
153 result[1] = index % self[1];
154 },
155 3 => {
156 result[0] = index / (self[1] * self[2]);
157 index %= self[1] * self[2];
158 result[1] = index / self[2];
159 result[2] = index % self[2];
160 },
161 4 => {
162 result[0] = index / (self[1] * self[2] * self[3]);
163 index %= self[1] * self[2] * self[3];
164 result[1] = index / (self[2] * self[3]);
165 index %= self[2] * self[3];
166 result[2] = index / self[3];
167 result[3] = index % self[3];
168 },
169 _ => {
170 for i in (1..self.ndim()).rev() {
171 let dim = self[i];
172 result[i] = index % dim;
173 index /= dim;
174 }
175 result[0] = index;
176 },
177 }
178 return result;
179 }
180}
181
182impl DimShapeAPI for IxD {
183 fn shape_size(&self) -> usize {
184 self.iter().product()
185 }
186
187 fn stride_f_contig(&self) -> Vec<isize> {
188 let mut stride = vec![1; self.len()];
189 for i in 1..self.len() {
190 stride[i] = stride[i - 1] * self[i - 1] as isize;
191 }
192 stride
193 }
194
195 fn stride_c_contig(&self) -> Vec<isize> {
196 let mut stride = vec![1; self.len()];
197 if self.is_empty() {
198 return stride;
199 }
200 for i in (0..self.len() - 1).rev() {
201 stride[i] = stride[i + 1] * self[i + 1] as isize;
202 }
203 stride
204 }
205
206 fn stride_contig(&self) -> Vec<isize> {
207 match FlagOrder::default() {
208 RowMajor => Self::stride_c_contig(self),
209 ColMajor => Self::stride_f_contig(self),
210 }
211 }
212
213 #[inline]
214 unsafe fn unravel_index_f(&self, index: usize) -> Self {
215 let mut index = index;
216 let mut result = self.new_shape();
217 if self.ndim() >= 1 {
218 for i in 0..(self.ndim() - 1) {
219 let dim = self[i];
220 result[i] = index % dim;
221 index /= dim;
222 }
223 result[self.ndim() - 1] = index;
224 }
225 return result;
226 }
227
228 #[inline]
229 unsafe fn unravel_index_c(&self, index: usize) -> Self {
230 let mut index = index;
231 let mut result = self.new_shape();
232 if self.ndim() >= 1 {
233 for i in (1..self.ndim()).rev() {
234 let dim = self[i];
235 result[i] = index % dim;
236 index /= dim;
237 }
238 result[0] = index;
239 }
240 return result;
241 }
242}
243
#[cfg(test)]
mod test {
    use super::*;

    /// Linear offset of a multi-index under the given element strides.
    fn offset(index: &[usize], stride: &[isize]) -> isize {
        index.iter().zip(stride).map(|(&i, &s)| i as isize * s).sum()
    }

    /// Unravelling every flat index and re-linearising it with the matching
    /// contiguous strides must give back the same flat index (both orders).
    macro_rules! assert_roundtrip {
        ($dims:expr) => {{
            let dims = $dims;
            let sc = dims.stride_c_contig();
            let sf = dims.stride_f_contig();
            for idx in 0..dims.shape_size() {
                let c = unsafe { dims.unravel_index_c(idx) };
                let f = unsafe { dims.unravel_index_f(idx) };
                assert_eq!(offset(&c, &sc), idx as isize);
                assert_eq!(offset(&f, &sf), idx as isize);
            }
        }};
    }

    #[test]
    fn test_ndim() {
        let shape: [usize; 2] = [2, 3];
        assert_eq!(shape.ndim(), 2);
        let shape: Vec<usize> = vec![2, 3];
        assert_eq!(shape.ndim(), 2);
        let shape: [usize; 0] = [];
        assert_eq!(shape.ndim(), 0);
        let shape: Vec<usize> = vec![];
        assert_eq!(shape.ndim(), 0);
    }

    #[test]
    fn test_size() {
        let shape: [usize; 2] = [2, 3];
        assert_eq!(shape.shape_size(), 6);
        // A shape with no axes (scalar) holds exactly one element.
        let shape: [usize; 0] = [];
        assert_eq!(shape.shape_size(), 1);
        let shape: Vec<usize> = vec![];
        assert_eq!(shape.shape_size(), 1);
        // Any zero extent makes the shape empty.
        let shape: [usize; 4] = [1, 2, 0, 4];
        assert_eq!(shape.shape_size(), 0);
        let shape: Vec<usize> = vec![1, 2, 0, 4];
        assert_eq!(shape.shape_size(), 0);
    }

    #[test]
    fn test_stride_f_contig() {
        assert_eq!([2usize, 3, 5].stride_f_contig(), [1, 2, 6]);
        assert_eq!(vec![2usize, 3, 5].stride_f_contig(), vec![1, 2, 6]);
        let empty: [usize; 0] = [];
        assert_eq!(empty.stride_f_contig(), [0isize; 0]);
        let empty: Vec<usize> = vec![];
        assert_eq!(empty.stride_f_contig(), Vec::<isize>::new());
        // Fixed rank treats zero extents as 1, so outer strides stay non-zero.
        assert_eq!([1usize, 2, 0, 4].stride_f_contig(), [1, 1, 2, 2]);
    }

    #[test]
    fn test_stride_c_contig() {
        assert_eq!([2usize, 3, 5].stride_c_contig(), [15, 5, 1]);
        assert_eq!(vec![2usize, 3, 5].stride_c_contig(), vec![15, 5, 1]);
        let empty: [usize; 0] = [];
        assert_eq!(empty.stride_c_contig(), [0isize; 0]);
        let empty: Vec<usize> = vec![];
        assert_eq!(empty.stride_c_contig(), Vec::<isize>::new());
        // Fixed rank treats zero extents as 1, so outer strides stay non-zero.
        assert_eq!([1usize, 2, 0, 4].stride_c_contig(), [8, 4, 4, 1]);
    }

    #[test]
    fn test_unravel_index() {
        let shape: [usize; 3] = [2, 3, 4];
        assert_eq!(unsafe { shape.unravel_index_c(5) }, [0, 1, 1]);
        assert_eq!(unsafe { shape.unravel_index_f(5) }, [1, 2, 0]);
        let shape: Vec<usize> = vec![2, 3, 4];
        assert_eq!(unsafe { shape.unravel_index_c(5) }, vec![0, 1, 1]);
        assert_eq!(unsafe { shape.unravel_index_f(5) }, vec![1, 2, 0]);

        // Fixed rank: the unrolled ranks 1-4 and the generic branch (rank 5).
        assert_roundtrip!([7usize]);
        assert_roundtrip!([2usize, 3]);
        assert_roundtrip!([2usize, 3, 4]);
        assert_roundtrip!([2usize, 3, 4, 5]);
        assert_roundtrip!([2usize, 1, 3, 2, 2]);

        // Dynamic rank, including the 0-dimensional shape.
        assert_roundtrip!(Vec::<usize>::new());
        assert_roundtrip!(vec![7usize]);
        assert_roundtrip!(vec![2usize, 3]);
        assert_roundtrip!(vec![2usize, 3, 4]);
        assert_roundtrip!(vec![2usize, 1, 3, 2, 2]);
    }
}