1use crate::IndexKind;
4pub use crate::errors::BoundsError;
5#[allow(unused_imports)]
6use alloc::format;
7#[allow(unused_imports)]
8use alloc::string::{String, ToString};
9use core::fmt::Debug;
10
11pub trait AsIndex: Debug + Copy + Sized {
23 fn index(self) -> isize;
25
26 fn expect_elem_index(self, size: usize) -> usize {
28 IndexWrap::expect_elem(self, size)
29 }
30
31 fn expect_dim_index(self, size: usize) -> usize {
33 IndexWrap::expect_dim(self, size)
34 }
35}
36
37impl AsIndex for usize {
38 fn index(self) -> isize {
39 self as isize
40 }
41}
42
43impl AsIndex for isize {
44 fn index(self) -> isize {
45 self
46 }
47}
48
49impl AsIndex for i64 {
50 fn index(self) -> isize {
51 self as isize
52 }
53}
54
55impl AsIndex for u64 {
56 fn index(self) -> isize {
57 self as isize
58 }
59}
60
61impl AsIndex for i32 {
63 fn index(self) -> isize {
64 self as isize
65 }
66}
67
68impl AsIndex for u32 {
69 fn index(self) -> isize {
70 self as isize
71 }
72}
73
74impl AsIndex for i16 {
75 fn index(self) -> isize {
76 self as isize
77 }
78}
79
80impl AsIndex for u16 {
81 fn index(self) -> isize {
82 self as isize
83 }
84}
85
86impl AsIndex for i8 {
87 fn index(self) -> isize {
88 self as isize
89 }
90}
91
92impl AsIndex for u8 {
93 fn index(self) -> isize {
94 self as isize
95 }
96}
97
98#[derive(Debug)]
100pub struct IndexWrap {
101 kind: IndexKind,
102 wrap_scalar: bool,
103}
104
105impl IndexWrap {
106 pub fn index() -> Self {
108 Self {
109 kind: IndexKind::Element,
110 wrap_scalar: false,
111 }
112 }
113
114 pub fn dim() -> Self {
116 Self {
117 kind: IndexKind::Dimension,
118 wrap_scalar: false,
119 }
120 }
121
122 pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
128 Self {
129 wrap_scalar,
130 ..self
131 }
132 }
133
134 pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
136 try_wrap(idx, size, self.kind, self.wrap_scalar)
137 }
138
139 pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
141 expect_wrap(idx, size, self.kind, self.wrap_scalar)
142 }
143
144 pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
146 Self::index().expect_wrap(idx, size)
147 }
148
149 pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
151 Self::dim().expect_wrap(idx, size)
152 }
153}
154
155pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
168where
169 I: AsIndex,
170{
171 try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
172}
173
174pub fn try_wrap<I>(
187 idx: I,
188 size: usize,
189 kind: IndexKind,
190 wrap_scalar: bool,
191) -> Result<usize, BoundsError>
192where
193 I: AsIndex,
194{
195 let source_idx = idx.index();
196 let source_size = size;
197
198 let size = if source_size > 0 {
199 source_size
200 } else {
201 if !wrap_scalar {
202 return Err(BoundsError::index(kind, source_idx, 0..0));
203 }
204 1
205 };
206
207 if source_idx >= 0 && (source_idx as usize) < size {
208 return Ok(source_idx as usize);
209 }
210
211 let _idx = if source_idx < 0 {
212 source_idx + size as isize
213 } else {
214 source_idx
215 };
216
217 if _idx < 0 || (_idx as usize) >= size {
218 let rank = size as isize;
219
220 return Err(BoundsError::index(kind, source_idx, 0..rank));
221 }
222
223 Ok(_idx as usize)
224}
225
226#[inline]
237#[must_use]
238pub fn wrap_index<I>(idx: I, size: usize) -> usize
239where
240 I: AsIndex,
241{
242 if size == 0 {
243 return 0; }
245 let wrapped = idx.index().rem_euclid(size as isize);
246 if wrapped < 0 {
247 (wrapped + size as isize) as usize
248 } else {
249 wrapped as usize
250 }
251}
252
253pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
268 assert_eq!(
269 shape.len(),
270 indices.len(),
271 "Coordinate rank mismatch: expected {}, got {}",
272 shape.len(),
273 indices.len(),
274 );
275
276 let mut ravel_idx = 0;
277 let mut stride = 1;
278
279 for (i, &dim) in shape.iter().enumerate().rev() {
280 let idx = indices[i];
281 let coord = IndexWrap::index().expect_wrap(idx, dim);
282 ravel_idx += coord * stride;
283 stride *= dim;
284 }
285
286 ravel_idx
287}
288
289#[cfg(test)]
290#[allow(clippy::identity_op, reason = "useful for clarity")]
291mod tests {
292 use super::*;
293 use alloc::vec;
294
295 #[test]
296 fn test_ravel() {
297 let shape = vec![2, 3, 4, 5];
298
299 assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
300 assert_eq!(
301 ravel_index(&[1, 2, 3, 4], &shape),
302 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
303 );
304 }
305
306 #[test]
307 fn test_wrap_idx() {
308 assert_eq!(wrap_index(0, 3), 0_usize);
309 assert_eq!(wrap_index(3, 3), 0_usize);
310 assert_eq!(wrap_index(2 * 3, 3), 0_usize);
311 assert_eq!(wrap_index(0 - 3, 3), 0_usize);
312 assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);
313
314 assert_eq!(wrap_index(1, 3), 1_usize);
315 assert_eq!(wrap_index(1 + 3, 3), 1_usize);
316 assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
317 assert_eq!(wrap_index(1 - 3, 3), 1_usize);
318 assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);
319
320 assert_eq!(wrap_index(2, 3), 2_usize);
321 assert_eq!(wrap_index(2 + 3, 3), 2_usize);
322 assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
323 assert_eq!(wrap_index(2 - 3, 3), 2_usize);
324 assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);
325 }
326
327 #[test]
328 fn test_negative_wrap() {
329 assert_eq!(IndexWrap::index().expect_wrap(0, 3), 0);
330 assert_eq!(IndexWrap::index().expect_wrap(1, 3), 1);
331 assert_eq!(IndexWrap::index().expect_wrap(2, 3), 2);
332 assert_eq!(IndexWrap::index().expect_wrap(-1, 3), 2);
333 assert_eq!(IndexWrap::index().expect_wrap(-2, 3), 1);
334 assert_eq!(IndexWrap::index().expect_wrap(-3, 3), 0);
335
336 assert_eq!(IndexWrap::dim().expect_wrap(0, 3), 0);
337 assert_eq!(IndexWrap::dim().expect_wrap(1, 3), 1);
338 assert_eq!(IndexWrap::dim().expect_wrap(2, 3), 2);
339 assert_eq!(IndexWrap::dim().expect_wrap(-1, 3), 2);
340 assert_eq!(IndexWrap::dim().expect_wrap(-2, 3), 1);
341 assert_eq!(IndexWrap::dim().expect_wrap(-3, 3), 0);
342
343 assert_eq!(
344 IndexWrap::index().try_wrap(3, 3),
345 Err(BoundsError::Index {
346 kind: IndexKind::Element,
347 index: 3,
348 bounds: 0..3,
349 })
350 );
351 assert_eq!(
352 IndexWrap::index().try_wrap(-4, 3),
353 Err(BoundsError::Index {
354 kind: IndexKind::Element,
355 index: -4,
356 bounds: 0..3,
357 })
358 );
359 assert_eq!(
360 IndexWrap::dim().try_wrap(3, 3),
361 Err(BoundsError::Index {
362 kind: IndexKind::Dimension,
363 index: 3,
364 bounds: 0..3,
365 })
366 );
367 assert_eq!(
368 IndexWrap::dim().try_wrap(-4, 3),
369 Err(BoundsError::Index {
370 kind: IndexKind::Dimension,
371 index: -4,
372 bounds: 0..3,
373 })
374 );
375 }
376
377 #[test]
378 fn test_negative_wrap_scalar() {
379 assert_eq!(
380 IndexWrap::index().try_wrap(0, 0),
381 Err(BoundsError::Index {
382 kind: IndexKind::Element,
383 index: 0,
384 bounds: 0..0,
385 })
386 );
387
388 assert_eq!(
389 IndexWrap::index().with_wrap_scalar(true).expect_wrap(0, 0),
390 0
391 );
392 assert_eq!(
393 IndexWrap::index().with_wrap_scalar(true).expect_wrap(-1, 0),
394 0
395 );
396
397 assert_eq!(
398 IndexWrap::index().with_wrap_scalar(false).try_wrap(1, 0),
399 Err(BoundsError::Index {
400 kind: IndexKind::Element,
401 index: 1,
402 bounds: 0..0,
403 })
404 );
405 }
406}