1pub use super::type_conversion::AsIndex;
4pub use crate::errors::BoundsError;
5use crate::errors::IndexKind;
6#[allow(unused_imports)]
7use alloc::format;
8#[allow(unused_imports)]
9use alloc::string::{String, ToString};
10use core::fmt::Debug;
11
12#[derive(Debug)]
14pub struct IndexWrap {
15 kind: IndexKind,
16 wrap_scalar: bool,
17}
18
19impl IndexWrap {
20 pub fn index() -> Self {
22 Self {
23 kind: IndexKind::Element,
24 wrap_scalar: false,
25 }
26 }
27
28 pub fn dim() -> Self {
30 Self {
31 kind: IndexKind::Dimension,
32 wrap_scalar: false,
33 }
34 }
35
36 pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
42 Self {
43 wrap_scalar,
44 ..self
45 }
46 }
47
48 pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
50 try_wrap(idx, size, self.kind, self.wrap_scalar)
51 }
52
53 pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
55 expect_wrap(idx, size, self.kind, self.wrap_scalar)
56 }
57
58 pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
60 Self::index().expect_wrap(idx, size)
61 }
62
63 pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
65 Self::dim().expect_wrap(idx, size)
66 }
67}
68
69pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
82where
83 I: AsIndex,
84{
85 try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
86}
87
88pub fn try_wrap<I>(
101 idx: I,
102 size: usize,
103 kind: IndexKind,
104 wrap_scalar: bool,
105) -> Result<usize, BoundsError>
106where
107 I: AsIndex,
108{
109 let source_idx = idx.as_index();
110 let source_size = size;
111
112 let size = if source_size > 0 {
113 source_size
114 } else {
115 if !wrap_scalar {
116 return Err(BoundsError::index(kind, source_idx, 0..0));
117 }
118 1
119 };
120
121 if source_idx >= 0 && (source_idx as usize) < size {
122 return Ok(source_idx as usize);
123 }
124
125 let _idx = if source_idx < 0 {
126 source_idx + size as isize
127 } else {
128 source_idx
129 };
130
131 if _idx < 0 || (_idx as usize) >= size {
132 let rank = size as isize;
133
134 return Err(BoundsError::index(kind, source_idx, 0..rank));
135 }
136
137 Ok(_idx as usize)
138}
139
140#[inline]
151#[must_use]
152pub fn wrap_index<I>(idx: I, size: usize) -> usize
153where
154 I: AsIndex,
155{
156 if size == 0 {
157 return 0; }
159 let wrapped = idx.as_index().rem_euclid(size as isize);
160 if wrapped < 0 {
161 (wrapped + size as isize) as usize
162 } else {
163 wrapped as usize
164 }
165}
166
167pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
182 assert_eq!(
183 shape.len(),
184 indices.len(),
185 "Coordinate rank mismatch: expected {}, got {}",
186 shape.len(),
187 indices.len(),
188 );
189
190 let mut ravel_idx = 0;
191 let mut stride = 1;
192
193 for (i, &dim) in shape.iter().enumerate().rev() {
194 let idx = indices[i];
195 let coord = IndexWrap::index().expect_wrap(idx, dim);
196 ravel_idx += coord * stride;
197 stride *= dim;
198 }
199
200 ravel_idx
201}
202
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::errors::IndexKind;
    use alloc::vec;

    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];

        assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
        assert_eq!(
            ravel_index(&[1, 2, 3, 4], &shape),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    fn test_ravel_negative_coords() {
        let shape = vec![2, 3, 4, 5];

        // Negative coordinates count back from the end of each axis.
        assert_eq!(
            ravel_index(&[-1, -1, -1, -1], &shape),
            ravel_index(&[1, 2, 3, 4], &shape)
        );
        assert_eq!(ravel_index(&[-2, -3, -4, -5], &shape), 0);
    }

    #[test]
    #[should_panic(expected = "Coordinate rank mismatch")]
    fn test_ravel_rank_mismatch() {
        let shape = vec![2, 3];
        let _ = ravel_index(&[0], &shape);
    }

    #[test]
    #[should_panic(expected = "valid index")]
    fn test_ravel_out_of_bounds() {
        let shape = vec![2, 3];
        let _ = ravel_index(&[2, 0], &shape);
    }

    #[test]
    fn test_wrap_idx() {
        assert_eq!(wrap_index(0, 3), 0_usize);
        assert_eq!(wrap_index(3, 3), 0_usize);
        assert_eq!(wrap_index(2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);

        assert_eq!(wrap_index(1, 3), 1_usize);
        assert_eq!(wrap_index(1 + 3, 3), 1_usize);
        assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);

        assert_eq!(wrap_index(2, 3), 2_usize);
        assert_eq!(wrap_index(2 + 3, 3), 2_usize);
        assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);
    }

    #[test]
    fn test_wrap_idx_empty() {
        // A zero-sized axis has nothing to wrap into; every index maps to 0.
        assert_eq!(wrap_index(0, 0), 0_usize);
        assert_eq!(wrap_index(5, 0), 0_usize);
        assert_eq!(wrap_index(-1, 0), 0_usize);
    }

    #[test]
    fn test_negative_wrap() {
        assert_eq!(IndexWrap::index().expect_wrap(0, 3), 0);
        assert_eq!(IndexWrap::index().expect_wrap(1, 3), 1);
        assert_eq!(IndexWrap::index().expect_wrap(2, 3), 2);
        assert_eq!(IndexWrap::index().expect_wrap(-1, 3), 2);
        assert_eq!(IndexWrap::index().expect_wrap(-2, 3), 1);
        assert_eq!(IndexWrap::index().expect_wrap(-3, 3), 0);

        assert_eq!(IndexWrap::dim().expect_wrap(0, 3), 0);
        assert_eq!(IndexWrap::dim().expect_wrap(1, 3), 1);
        assert_eq!(IndexWrap::dim().expect_wrap(2, 3), 2);
        assert_eq!(IndexWrap::dim().expect_wrap(-1, 3), 2);
        assert_eq!(IndexWrap::dim().expect_wrap(-2, 3), 1);
        assert_eq!(IndexWrap::dim().expect_wrap(-3, 3), 0);

        assert_eq!(
            IndexWrap::index().try_wrap(3, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 3,
                bounds: 0..3,
            })
        );
        assert_eq!(
            IndexWrap::index().try_wrap(-4, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: -4,
                bounds: 0..3,
            })
        );
        assert_eq!(
            IndexWrap::dim().try_wrap(3, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Dimension,
                index: 3,
                bounds: 0..3,
            })
        );
        assert_eq!(
            IndexWrap::dim().try_wrap(-4, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Dimension,
                index: -4,
                bounds: 0..3,
            })
        );
    }

    #[test]
    fn test_static_helpers() {
        assert_eq!(IndexWrap::expect_elem(-1, 3), 2);
        assert_eq!(IndexWrap::expect_elem(0, 3), 0);
        assert_eq!(IndexWrap::expect_dim(-1, 4), 3);
        assert_eq!(IndexWrap::expect_dim(2, 4), 2);
    }

    #[test]
    fn test_negative_wrap_scalar() {
        assert_eq!(
            IndexWrap::index().try_wrap(0, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 0,
                bounds: 0..0,
            })
        );

        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).expect_wrap(0, 0),
            0
        );
        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).expect_wrap(-1, 0),
            0
        );

        assert_eq!(
            IndexWrap::index().with_wrap_scalar(false).try_wrap(1, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 1,
                bounds: 0..0,
            })
        );

        // With scalar wrapping the axis behaves like length 1, so the reported bounds are 0..1.
        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).try_wrap(1, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 1,
                bounds: 0..1,
            })
        );
        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).try_wrap(-2, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: -2,
                bounds: 0..1,
            })
        );
    }
}