// burn_tensor/tensor/indexing/mod.rs
use core::fmt::Debug;
4
/// Conversion of an integer-like value into a signed index.
///
/// The helpers in this module ([`canonicalize_index`], [`canonicalize_dim`] and
/// [`wrap_index`]) accept any `AsIndex` value and work on the `isize` it yields,
/// which is what allows them to take negative indices.
pub trait AsIndex: Debug + Copy + Sized {
    /// Returns this value as an `isize`.
    fn index(self) -> isize;
}
20
21impl AsIndex for usize {
22 fn index(self) -> isize {
23 self as isize
24 }
25}
26
27impl AsIndex for isize {
28 fn index(self) -> isize {
29 self
30 }
31}
32
33impl AsIndex for i64 {
34 fn index(self) -> isize {
35 self as isize
36 }
37}
38
39impl AsIndex for u64 {
40 fn index(self) -> isize {
41 self as isize
42 }
43}
44
45impl AsIndex for i32 {
47 fn index(self) -> isize {
48 self as isize
49 }
50}
51
52impl AsIndex for u32 {
53 fn index(self) -> isize {
54 self as isize
55 }
56}
57
58impl AsIndex for i16 {
59 fn index(self) -> isize {
60 self as isize
61 }
62}
63
64impl AsIndex for u16 {
65 fn index(self) -> isize {
66 self as isize
67 }
68}
69
70impl AsIndex for i8 {
71 fn index(self) -> isize {
72 self as isize
73 }
74}
75
76impl AsIndex for u8 {
77 fn index(self) -> isize {
78 self as isize
79 }
80}
81
82#[must_use]
99pub fn canonicalize_index<Index>(idx: Index, size: usize, wrap_scalar: bool) -> usize
100where
101 Index: AsIndex,
102{
103 canonicalize_named_index("index", "size", idx, size, wrap_scalar)
104}
105
106#[must_use]
123pub fn canonicalize_dim<Dim>(idx: Dim, rank: usize, wrap_scalar: bool) -> usize
124where
125 Dim: AsIndex,
126{
127 canonicalize_named_index("dimension index", "rank", idx, rank, wrap_scalar)
128}
129
130#[inline(always)]
149#[must_use]
150fn canonicalize_named_index<I>(
151 name: &str,
152 size_name: &str,
153 idx: I,
154 size: usize,
155 wrap_scalar: bool,
156) -> usize
157where
158 I: AsIndex,
159{
160 let idx = idx.index();
161
162 let rank = if size > 0 {
163 size
164 } else {
165 if !wrap_scalar {
166 panic!("{name} {idx} used when {size_name} is 0");
167 }
168 1
169 };
170
171 if idx >= 0 && (idx as usize) < rank {
172 return idx as usize;
173 }
174
175 let _idx = if idx < 0 { idx + rank as isize } else { idx };
176
177 if _idx < 0 || (_idx as usize) >= rank {
178 let rank = rank as isize;
179 let lower = -rank;
180 let upper = rank - 1;
181 panic!("{name} {idx} out of range: ({lower}..={upper})");
182 }
183
184 _idx as usize
185}
186
187#[inline]
198#[must_use]
199pub fn wrap_index<I>(idx: I, size: usize) -> usize
200where
201 I: AsIndex,
202{
203 if size == 0 {
204 return 0; }
206 let wrapped = idx.index().rem_euclid(size as isize);
207 if wrapped < 0 {
208 (wrapped + size as isize) as usize
209 } else {
210 wrapped as usize
211 }
212}
213
// Unit tests for the index helpers: wrapping, canonicalization of indices and
// dimensions, and the panic messages raised for invalid input.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wrap_idx() {
        assert_eq!(wrap_index(0, 3), 0_usize);
        assert_eq!(wrap_index(3, 3), 0_usize);
        assert_eq!(wrap_index(2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);

        assert_eq!(wrap_index(1, 3), 1_usize);
        assert_eq!(wrap_index(1 + 3, 3), 1_usize);
        assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);

        assert_eq!(wrap_index(2, 3), 2_usize);
        assert_eq!(wrap_index(2 + 3, 3), 2_usize);
        assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);
    }

    #[test]
    fn test_wrap_idx_zero_size() {
        // A size of 0 short-circuits to 0 regardless of the index.
        assert_eq!(wrap_index(0, 0), 0_usize);
        assert_eq!(wrap_index(5, 0), 0_usize);
        assert_eq!(wrap_index(-5, 0), 0_usize);
    }

    #[test]
    fn test_canonicalize_dim() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_dim(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_dim(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_dim(2, 3, wrap_scalar), 2_usize);

        assert_eq!(canonicalize_dim(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_dim(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_dim(-3, 3, wrap_scalar), (3 - 3) as usize);

        let wrap_scalar = true;
        assert_eq!(canonicalize_dim(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_dim(-1, 0, wrap_scalar), 0);
    }

    #[test]
    #[should_panic = "dimension index 0 used when rank is 0"]
    fn test_canonicalize_dim_error_no_dims() {
        let _d = canonicalize_dim(0, 0, false);
    }

    #[test]
    #[should_panic = "dimension index 3 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_big() {
        let _d = canonicalize_dim(3, 3, false);
    }

    #[test]
    #[should_panic = "dimension index -4 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_small() {
        let _d = canonicalize_dim(-4, 3, false);
    }

    #[test]
    fn test_canonicalize_index() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_index(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_index(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_index(2, 3, wrap_scalar), 2_usize);

        assert_eq!(canonicalize_index(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_index(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_index(-3, 3, wrap_scalar), (3 - 3) as usize);

        let wrap_scalar = true;
        assert_eq!(canonicalize_index(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_index(-1, 0, wrap_scalar), 0);
    }

    #[test]
    fn test_canonicalize_index_integer_types() {
        // The helpers are generic over every `AsIndex` implementor.
        assert_eq!(canonicalize_index(2_usize, 3, false), 2_usize);
        assert_eq!(canonicalize_index(-1_i64, 3, false), 2_usize);
        assert_eq!(canonicalize_index(1_u8, 3, false), 1_usize);
    }

    #[test]
    #[should_panic = "index 0 used when size is 0"]
    fn test_canonicalize_index_error_no_size() {
        let _d = canonicalize_index(0, 0, false);
    }

    #[test]
    #[should_panic = "index 3 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_big() {
        let _d = canonicalize_index(3, 3, false);
    }

    #[test]
    #[should_panic = "index -4 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_small() {
        let _d = canonicalize_index(-4, 3, false);
    }
}