1use crate::IndexKind;
4pub use crate::errors::BoundsError;
5#[allow(unused_imports)]
6use alloc::format;
7#[allow(unused_imports)]
8use alloc::string::{String, ToString};
9use core::fmt::Debug;
10
11pub use crate::tensor::index_conversion::AsIndex;
12
13#[derive(Debug)]
15pub struct IndexWrap {
16 kind: IndexKind,
17 wrap_scalar: bool,
18}
19
20impl IndexWrap {
21 pub fn index() -> Self {
23 Self {
24 kind: IndexKind::Element,
25 wrap_scalar: false,
26 }
27 }
28
29 pub fn dim() -> Self {
31 Self {
32 kind: IndexKind::Dimension,
33 wrap_scalar: false,
34 }
35 }
36
37 pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
43 Self {
44 wrap_scalar,
45 ..self
46 }
47 }
48
49 pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
51 try_wrap(idx, size, self.kind, self.wrap_scalar)
52 }
53
54 pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
56 expect_wrap(idx, size, self.kind, self.wrap_scalar)
57 }
58
59 pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
61 Self::index().expect_wrap(idx, size)
62 }
63
64 pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
66 Self::dim().expect_wrap(idx, size)
67 }
68}
69
70pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
83where
84 I: AsIndex,
85{
86 try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
87}
88
89pub fn try_wrap<I>(
102 idx: I,
103 size: usize,
104 kind: IndexKind,
105 wrap_scalar: bool,
106) -> Result<usize, BoundsError>
107where
108 I: AsIndex,
109{
110 let source_idx = idx.as_index();
111 let source_size = size;
112
113 let size = if source_size > 0 {
114 source_size
115 } else {
116 if !wrap_scalar {
117 return Err(BoundsError::index(kind, source_idx, 0..0));
118 }
119 1
120 };
121
122 if source_idx >= 0 && (source_idx as usize) < size {
123 return Ok(source_idx as usize);
124 }
125
126 let _idx = if source_idx < 0 {
127 source_idx + size as isize
128 } else {
129 source_idx
130 };
131
132 if _idx < 0 || (_idx as usize) >= size {
133 let rank = size as isize;
134
135 return Err(BoundsError::index(kind, source_idx, 0..rank));
136 }
137
138 Ok(_idx as usize)
139}
140
141#[inline]
152#[must_use]
153pub fn wrap_index<I>(idx: I, size: usize) -> usize
154where
155 I: AsIndex,
156{
157 if size == 0 {
158 return 0; }
160 let wrapped = idx.as_index().rem_euclid(size as isize);
161 if wrapped < 0 {
162 (wrapped + size as isize) as usize
163 } else {
164 wrapped as usize
165 }
166}
167
168pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
183 assert_eq!(
184 shape.len(),
185 indices.len(),
186 "Coordinate rank mismatch: expected {}, got {}",
187 shape.len(),
188 indices.len(),
189 );
190
191 let mut ravel_idx = 0;
192 let mut stride = 1;
193
194 for (i, &dim) in shape.iter().enumerate().rev() {
195 let idx = indices[i];
196 let coord = IndexWrap::index().expect_wrap(idx, dim);
197 ravel_idx += coord * stride;
198 stride *= dim;
199 }
200
201 ravel_idx
202}
203
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];

        assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
        assert_eq!(
            ravel_index(&[1, 2, 3, 4], &shape),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    fn test_ravel_negative_coords() {
        let shape = vec![2, 3, 4, 5];

        // Negative coordinates count back from the end of each axis.
        assert_eq!(
            ravel_index(&[-1, -1, -1, -1], &shape),
            ravel_index(&[1, 2, 3, 4], &shape)
        );
        assert_eq!(ravel_index(&[0, -3, 0, -5], &shape), 0);
    }

    #[test]
    #[should_panic(expected = "Coordinate rank mismatch")]
    fn test_ravel_rank_mismatch() {
        let _ = ravel_index(&[0, 0], &[2, 3, 4]);
    }

    #[test]
    fn test_wrap_idx() {
        assert_eq!(wrap_index(0, 3), 0_usize);
        assert_eq!(wrap_index(3, 3), 0_usize);
        assert_eq!(wrap_index(2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);

        assert_eq!(wrap_index(1, 3), 1_usize);
        assert_eq!(wrap_index(1 + 3, 3), 1_usize);
        assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);

        assert_eq!(wrap_index(2, 3), 2_usize);
        assert_eq!(wrap_index(2 + 3, 3), 2_usize);
        assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);
    }

    #[test]
    fn test_wrap_idx_zero_size() {
        // A zero-sized extent short-circuits to 0 instead of dividing by zero.
        assert_eq!(wrap_index(0, 0), 0_usize);
        assert_eq!(wrap_index(5, 0), 0_usize);
        assert_eq!(wrap_index(-7, 0), 0_usize);
    }

    #[test]
    fn test_negative_wrap() {
        assert_eq!(IndexWrap::index().expect_wrap(0, 3), 0);
        assert_eq!(IndexWrap::index().expect_wrap(1, 3), 1);
        assert_eq!(IndexWrap::index().expect_wrap(2, 3), 2);
        assert_eq!(IndexWrap::index().expect_wrap(-1, 3), 2);
        assert_eq!(IndexWrap::index().expect_wrap(-2, 3), 1);
        assert_eq!(IndexWrap::index().expect_wrap(-3, 3), 0);

        assert_eq!(IndexWrap::dim().expect_wrap(0, 3), 0);
        assert_eq!(IndexWrap::dim().expect_wrap(1, 3), 1);
        assert_eq!(IndexWrap::dim().expect_wrap(2, 3), 2);
        assert_eq!(IndexWrap::dim().expect_wrap(-1, 3), 2);
        assert_eq!(IndexWrap::dim().expect_wrap(-2, 3), 1);
        assert_eq!(IndexWrap::dim().expect_wrap(-3, 3), 0);

        assert_eq!(
            IndexWrap::index().try_wrap(3, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 3,
                bounds: 0..3,
            })
        );
        assert_eq!(
            IndexWrap::index().try_wrap(-4, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: -4,
                bounds: 0..3,
            })
        );
        assert_eq!(
            IndexWrap::dim().try_wrap(3, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Dimension,
                index: 3,
                bounds: 0..3,
            })
        );
        assert_eq!(
            IndexWrap::dim().try_wrap(-4, 3),
            Err(BoundsError::Index {
                kind: IndexKind::Dimension,
                index: -4,
                bounds: 0..3,
            })
        );
    }

    #[test]
    #[should_panic(expected = "valid index")]
    fn test_expect_wrap_panics_out_of_range() {
        let _ = IndexWrap::index().expect_wrap(3, 3);
    }

    #[test]
    fn test_expect_shorthands() {
        assert_eq!(IndexWrap::expect_elem(-1, 4), 3);
        assert_eq!(IndexWrap::expect_elem(2, 4), 2);
        assert_eq!(IndexWrap::expect_dim(-4, 4), 0);
        assert_eq!(IndexWrap::expect_dim(3, 4), 3);
    }

    #[test]
    fn test_negative_wrap_scalar() {
        assert_eq!(
            IndexWrap::index().try_wrap(0, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 0,
                bounds: 0..0,
            })
        );

        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).expect_wrap(0, 0),
            0
        );
        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).expect_wrap(-1, 0),
            0
        );

        assert_eq!(
            IndexWrap::index().with_wrap_scalar(false).try_wrap(1, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 1,
                bounds: 0..0,
            })
        );
    }

    #[test]
    fn test_wrap_scalar_out_of_range() {
        // With scalar wrapping, a zero-sized extent behaves like size 1, so the
        // reported bounds are `0..1`.
        assert_eq!(
            IndexWrap::index().with_wrap_scalar(true).try_wrap(1, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 1,
                bounds: 0..1,
            })
        );
        assert_eq!(
            IndexWrap::dim().with_wrap_scalar(true).try_wrap(-2, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Dimension,
                index: -2,
                bounds: 0..1,
            })
        );
    }
}