// burn_std/tensor/indexing.rs

//! A module for indexing utility machinery.

3use core::fmt::Debug;
4
/// Helper trait for implementing indexing with support for negative indices.
///
/// Implementors convert themselves into a signed [`isize`] so that negative
/// values can be interpreted as offsets from the end of a range (`-1` is the
/// last element). The signed value is then bounds checked and canonicalized by
/// helpers such as [`canonicalize_dim`] and [`canonicalize_index`].
///
/// # Example
/// ```rust
/// use burn_std::{AsIndex, canonicalize_dim};
///
/// fn resolve_dim<I: AsIndex>(dim: I, rank: usize) -> usize {
///     canonicalize_dim(dim, rank, false)
/// }
///
/// assert_eq!(resolve_dim(-1, 3), 2);
/// assert_eq!(resolve_dim(1_u8, 3), 1);
/// ```
pub trait AsIndex: Debug + Copy + Sized {
    /// Converts `self` into a signed index.
    ///
    /// Negative results are interpreted by the canonicalization helpers as
    /// counting back from the end of the indexed range.
    fn index(self) -> isize;
}

21impl AsIndex for usize {
22    fn index(self) -> isize {
23        self as isize
24    }
25}
26
27impl AsIndex for isize {
28    fn index(self) -> isize {
29        self
30    }
31}
32
33impl AsIndex for i64 {
34    fn index(self) -> isize {
35        self as isize
36    }
37}
38
39impl AsIndex for u64 {
40    fn index(self) -> isize {
41        self as isize
42    }
43}
44
45// Default integer type
46impl AsIndex for i32 {
47    fn index(self) -> isize {
48        self as isize
49    }
50}
51
52impl AsIndex for u32 {
53    fn index(self) -> isize {
54        self as isize
55    }
56}
57
58impl AsIndex for i16 {
59    fn index(self) -> isize {
60        self as isize
61    }
62}
63
64impl AsIndex for u16 {
65    fn index(self) -> isize {
66        self as isize
67    }
68}
69
70impl AsIndex for i8 {
71    fn index(self) -> isize {
72        self as isize
73    }
74}
75
76impl AsIndex for u8 {
77    fn index(self) -> isize {
78        self as isize
79    }
80}
81
82/// Canonicalizes and bounds checks an index with negative indexing support.
83///
84/// ## Arguments
85///
86/// * `idx` - The index to canonicalize.
87/// * `size` - The size of the index range.
88/// * `wrap_scalar` - If true, pretend scalars have rank=1.
89///
90/// ## Returns
91///
92/// The canonicalized dimension index.
93///
94/// ## Panics
95///
96/// * If `wrap_scalar` is false and the tensor has no dimensions.
97/// * If the dimension index is out of range.
98#[must_use]
99pub fn canonicalize_index<Index>(idx: Index, size: usize, wrap_scalar: bool) -> usize
100where
101    Index: AsIndex,
102{
103    canonicalize_named_index("index", "size", idx, size, wrap_scalar)
104}
105
106/// Canonicalizes and bounds checks a dimension index with negative indexing support.
107///
108/// ## Arguments
109///
110/// * `idx` - The dimension index to canonicalize.
111/// * `rank` - The number of dimensions.
112/// * `wrap_scalar` - If true, pretend scalars have rank=1.
113///
114/// ## Returns
115///
116/// The canonicalized dimension index.
117///
118/// ## Panics
119///
120/// * If `wrap_scalar` is false and the tensor has no dimensions.
121/// * If the dimension index is out of range.
122#[must_use]
123pub fn canonicalize_dim<Dim>(idx: Dim, rank: usize, wrap_scalar: bool) -> usize
124where
125    Dim: AsIndex,
126{
127    canonicalize_named_index("dimension index", "rank", idx, rank, wrap_scalar)
128}
129
130/// Canonicalizes and bounds checks an index with negative indexing support.
131///
132/// ## Arguments
133///
134/// * `name` - The name of the index (for error messages).
135/// * `size_name` - The name of the size (for error messages).
136/// * `idx` - The index to canonicalize.
137/// * `size` - The size of the index range.
138/// * `wrap_scalar` - If true, treat 0-size ranges as having size 1.
139///
140/// ## Returns
141///
142/// The canonicalized index.
143///
144/// ## Panics
145///
146/// * If `wrap_scalar` is false and the size is 0.
147/// * If the index is out of range for the dimension size.
148#[inline(always)]
149#[must_use]
150fn canonicalize_named_index<I>(
151    name: &str,
152    size_name: &str,
153    idx: I,
154    size: usize,
155    wrap_scalar: bool,
156) -> usize
157where
158    I: AsIndex,
159{
160    let idx = idx.index();
161
162    let rank = if size > 0 {
163        size
164    } else {
165        if !wrap_scalar {
166            panic!("{name} {idx} used when {size_name} is 0");
167        }
168        1
169    };
170
171    if idx >= 0 && (idx as usize) < rank {
172        return idx as usize;
173    }
174
175    let _idx = if idx < 0 { idx + rank as isize } else { idx };
176
177    if _idx < 0 || (_idx as usize) >= rank {
178        let rank = rank as isize;
179        let lower = -rank;
180        let upper = rank - 1;
181        panic!("{name} {idx} out of range: ({lower}..={upper})");
182    }
183
184    _idx as usize
185}
186
187/// Wraps a dimension index to be within the bounds of the dimension size.
188///
189/// ## Arguments
190///
191/// * `idx` - The dimension index to wrap.
192/// * `size` - The size of the dimension.
193///
194/// ## Returns
195///
196/// The positive wrapped dimension index.
197#[inline]
198#[must_use]
199pub fn wrap_index<I>(idx: I, size: usize) -> usize
200where
201    I: AsIndex,
202{
203    if size == 0 {
204        return 0; // Avoid modulo by zero
205    }
206    let wrapped = idx.index().rem_euclid(size as isize);
207    if wrapped < 0 {
208        (wrapped + size as isize) as usize
209    } else {
210        wrapped as usize
211    }
212}
213
214/// Compute the ravel index for the given coordinates.
215///
216/// This returns the row-major order raveling:
217/// * `strides[-1] = 1`
218/// * `strides[i] = strides[i+1] * dims[i+1]`
219/// * `dim_strides = coords * strides`
220/// * `ravel = sum(dim_strides)`
221///
222/// # Arguments
223/// - `indices`: the index for each dimension; must be the same length as `shape`.
224/// - `shape`: the shape of each dimension; be the same length as `indices`.
225///
226/// # Returns
227/// - the ravel offset index.
228pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
229    assert_eq!(
230        shape.len(),
231        indices.len(),
232        "Coordinate rank mismatch: expected {}, got {}",
233        shape.len(),
234        indices.len(),
235    );
236
237    let mut ravel_idx = 0;
238    let mut stride = 1;
239
240    for (i, &dim) in shape.iter().enumerate().rev() {
241        let coord = canonicalize_index(indices[i], dim, false);
242        ravel_idx += coord * stride;
243        stride *= dim;
244    }
245
246    ravel_idx
247}
248
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;

    #[test]
    fn test_ravel() {
        let shape = [2, 3, 4, 5];

        assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
        assert_eq!(
            ravel_index(&[1, 2, 3, 4], &shape),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
        // Negative coordinates count back from the end of each dimension.
        assert_eq!(
            ravel_index(&[-1, -1, -1, -1], &shape),
            1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4
        );
    }

    #[test]
    #[should_panic = "Coordinate rank mismatch: expected 4, got 3"]
    fn test_ravel_error_rank_mismatch() {
        let _r = ravel_index(&[0, 0, 0], &[2, 3, 4, 5]);
    }

    #[test]
    fn test_wrap_idx() {
        assert_eq!(wrap_index(0, 3), 0_usize);
        assert_eq!(wrap_index(3, 3), 0_usize);
        assert_eq!(wrap_index(2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);

        assert_eq!(wrap_index(1, 3), 1_usize);
        assert_eq!(wrap_index(1 + 3, 3), 1_usize);
        assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);

        assert_eq!(wrap_index(2, 3), 2_usize);
        assert_eq!(wrap_index(2 + 3, 3), 2_usize);
        assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);

        // A zero-sized dimension has no valid position; 0 is returned.
        assert_eq!(wrap_index(7, 0), 0_usize);
        assert_eq!(wrap_index(-7, 0), 0_usize);
    }

    #[test]
    fn test_canonicalize_dim() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_dim(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_dim(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_dim(2, 3, wrap_scalar), 2_usize);

        assert_eq!(canonicalize_dim(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_dim(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_dim(-3, 3, wrap_scalar), (3 - 3) as usize);

        let wrap_scalar = true;
        assert_eq!(canonicalize_dim(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_dim(-1, 0, wrap_scalar), 0);
    }

    #[test]
    #[should_panic = "dimension index 0 used when rank is 0"]
    fn test_canonicalize_dim_error_no_dims() {
        let _d = canonicalize_dim(0, 0, false);
    }

    #[test]
    #[should_panic = "dimension index 3 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_big() {
        let _d = canonicalize_dim(3, 3, false);
    }

    #[test]
    #[should_panic = "dimension index -4 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_small() {
        let _d = canonicalize_dim(-4, 3, false);
    }

    #[test]
    fn test_canonicalize_index() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_index(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_index(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_index(2, 3, wrap_scalar), 2_usize);

        assert_eq!(canonicalize_index(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_index(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_index(-3, 3, wrap_scalar), (3 - 3) as usize);

        let wrap_scalar = true;
        assert_eq!(canonicalize_index(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_index(-1, 0, wrap_scalar), 0);
    }

    #[test]
    #[should_panic = "index 0 used when size is 0"]
    fn test_canonicalize_index_error_no_size() {
        let _d = canonicalize_index(0, 0, false);
    }

    #[test]
    #[should_panic = "index 3 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_big() {
        let _d = canonicalize_index(3, 3, false);
    }

    #[test]
    #[should_panic = "index -4 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_small() {
        let _d = canonicalize_index(-4, 3, false);
    }
}