burn_std/tensor/
indexing.rs

1//! A module for indexing utility machinery.
2
3use crate::IndexKind;
4pub use crate::errors::BoundsError;
5#[allow(unused_imports)]
6use alloc::format;
7#[allow(unused_imports)]
8use alloc::string::{String, ToString};
9use core::fmt::Debug;
10
11/// Helper trait for implementing indexing with support for negative indices.
12///
13/// # Example
14/// ```rust
15/// use burn_std::AsIndex;
16///
17/// fn example<I: AsIndex, const D: usize>(dim: I, size: usize) -> isize {
18///    let dim: usize = dim.expect_dim_index(D);
19///    unimplemented!()
20/// }
21/// ```
22pub trait AsIndex: Debug + Copy + Sized {
23    /// Converts into a slice index.
24    fn index(self) -> isize;
25
26    /// Short-form [`IndexWrap::expect_index(idx, size)`].
27    fn expect_elem_index(self, size: usize) -> usize {
28        IndexWrap::expect_elem(self, size)
29    }
30
31    /// Short-form [`IndexWrap::expect_dim(idx, size)`].
32    fn expect_dim_index(self, size: usize) -> usize {
33        IndexWrap::expect_dim(self, size)
34    }
35}
36
37impl AsIndex for usize {
38    fn index(self) -> isize {
39        self as isize
40    }
41}
42
43impl AsIndex for isize {
44    fn index(self) -> isize {
45        self
46    }
47}
48
49impl AsIndex for i64 {
50    fn index(self) -> isize {
51        self as isize
52    }
53}
54
55impl AsIndex for u64 {
56    fn index(self) -> isize {
57        self as isize
58    }
59}
60
61// Default integer type
62impl AsIndex for i32 {
63    fn index(self) -> isize {
64        self as isize
65    }
66}
67
68impl AsIndex for u32 {
69    fn index(self) -> isize {
70        self as isize
71    }
72}
73
74impl AsIndex for i16 {
75    fn index(self) -> isize {
76        self as isize
77    }
78}
79
80impl AsIndex for u16 {
81    fn index(self) -> isize {
82        self as isize
83    }
84}
85
86impl AsIndex for i8 {
87    fn index(self) -> isize {
88        self as isize
89    }
90}
91
92impl AsIndex for u8 {
93    fn index(self) -> isize {
94        self as isize
95    }
96}
97
98/// Wraps an index with negative indexing support.
99#[derive(Debug)]
100pub struct IndexWrap {
101    kind: IndexKind,
102    wrap_scalar: bool,
103}
104
105impl IndexWrap {
106    /// Get an instance for wrapping negative indices.
107    pub fn index() -> Self {
108        Self {
109            kind: IndexKind::Element,
110            wrap_scalar: false,
111        }
112    }
113
114    /// Get an instance for wrapping negative dimensions.
115    pub fn dim() -> Self {
116        Self {
117            kind: IndexKind::Dimension,
118            wrap_scalar: false,
119        }
120    }
121
122    /// Set the policy for wrapping 0-size ranges.
123    ///
124    /// When ``size`` == 0:
125    ///   - if `wrap_scalar`; then ``size == 1``
126    ///   - otherwise; an error.
127    pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
128        Self {
129            wrap_scalar,
130            ..self
131        }
132    }
133
134    /// Wrap an index with negative indexing support.
135    pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
136        try_wrap(idx, size, self.kind, self.wrap_scalar)
137    }
138
139    /// Wrap an index with negative indexing support.
140    pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
141        expect_wrap(idx, size, self.kind, self.wrap_scalar)
142    }
143
144    /// Short-form [`NegativeWrap::index().expect_wrap(idx, size)`].
145    pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
146        Self::index().expect_wrap(idx, size)
147    }
148
149    /// Short-form [`NegativeWrap::dim().expect_wrap(idx, size)`].
150    pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
151        Self::dim().expect_wrap(idx, size)
152    }
153}
154
155/// Wraps an index with negative indexing support.
156///
157/// ## Arguments
158/// - `idx` - The index to canonicalize.
159/// - `size` - The size of the index range.
160/// - `kind` - The kind of index (for error messages).
161/// - `size_name` - The name of the size (for error messages).
162/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1.
163///
164/// ## Returns
165///
166/// A `Result<usize, BoundsError>` of the canonicalized index.
167pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
168where
169    I: AsIndex,
170{
171    try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
172}
173
174/// Wraps an index with negative indexing support.
175///
176/// ## Arguments
177/// - `idx` - The index to canonicalize.
178/// - `size` - The size of the index range.
179/// - `kind` - The kind of index (for error messages).
180/// - `size_name` - The name of the size (for error messages).
181/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1.
182///
183/// ## Returns
184///
185/// A `Result<usize, BoundsError>` of the canonicalized index.
186pub fn try_wrap<I>(
187    idx: I,
188    size: usize,
189    kind: IndexKind,
190    wrap_scalar: bool,
191) -> Result<usize, BoundsError>
192where
193    I: AsIndex,
194{
195    let source_idx = idx.index();
196    let source_size = size;
197
198    let size = if source_size > 0 {
199        source_size
200    } else {
201        if !wrap_scalar {
202            return Err(BoundsError::index(kind, source_idx, 0..0));
203        }
204        1
205    };
206
207    if source_idx >= 0 && (source_idx as usize) < size {
208        return Ok(source_idx as usize);
209    }
210
211    let _idx = if source_idx < 0 {
212        source_idx + size as isize
213    } else {
214        source_idx
215    };
216
217    if _idx < 0 || (_idx as usize) >= size {
218        let rank = size as isize;
219
220        return Err(BoundsError::index(kind, source_idx, 0..rank));
221    }
222
223    Ok(_idx as usize)
224}
225
226/// Wraps a dimension index to be within the bounds of the dimension size.
227///
228/// ## Arguments
229///
230/// * `idx` - The dimension index to wrap.
231/// * `size` - The size of the dimension.
232///
233/// ## Returns
234///
235/// The positive wrapped dimension index.
236#[inline]
237#[must_use]
238pub fn wrap_index<I>(idx: I, size: usize) -> usize
239where
240    I: AsIndex,
241{
242    if size == 0 {
243        return 0; // Avoid modulo by zero
244    }
245    let wrapped = idx.index().rem_euclid(size as isize);
246    if wrapped < 0 {
247        (wrapped + size as isize) as usize
248    } else {
249        wrapped as usize
250    }
251}
252
253/// Compute the ravel index for the given coordinates.
254///
255/// This returns the row-major order raveling:
256/// * `strides[-1] = 1`
257/// * `strides[i] = strides[i+1] * dims[i+1]`
258/// * `dim_strides = coords * strides`
259/// * `ravel = sum(dim_strides)`
260///
261/// # Arguments
262/// - `indices`: the index for each dimension; must be the same length as `shape`.
263/// - `shape`: the shape of each dimension; be the same length as `indices`.
264///
265/// # Returns
266/// - the ravel offset index.
267pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
268    assert_eq!(
269        shape.len(),
270        indices.len(),
271        "Coordinate rank mismatch: expected {}, got {}",
272        shape.len(),
273        indices.len(),
274    );
275
276    let mut ravel_idx = 0;
277    let mut stride = 1;
278
279    for (i, &dim) in shape.iter().enumerate().rev() {
280        let idx = indices[i];
281        let coord = IndexWrap::index().expect_wrap(idx, dim);
282        ravel_idx += coord * stride;
283        stride *= dim;
284    }
285
286    ravel_idx
287}
288
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];
        // Row-major strides for a [2, 3, 4, 5] shape.
        let strides = [3 * 4 * 5, 4 * 5, 5, 1];

        assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0);
        assert_eq!(
            ravel_index(&[1, 2, 3, 4], &shape),
            1 * strides[0] + 2 * strides[1] + 3 * strides[2] + 4 * strides[3]
        );
    }

    #[test]
    fn test_wrap_idx() {
        // Every residue, shifted by a whole number of laps in either
        // direction, must wrap back to itself.
        for residue in 0..3_i32 {
            let expected = usize::try_from(residue).unwrap();
            for lap in [0, 1, 2, -1, -2] {
                assert_eq!(wrap_index(residue + lap * 3, 3), expected);
            }
        }
    }

    #[test]
    fn test_negative_wrap() {
        let in_bounds = [(0, 0), (1, 1), (2, 2), (-1, 2), (-2, 1), (-3, 0)];
        for wrapper in [IndexWrap::index(), IndexWrap::dim()] {
            for (idx, expected) in in_bounds {
                assert_eq!(wrapper.expect_wrap(idx, 3), expected);
            }
        }

        for (wrapper, kind) in [
            (IndexWrap::index(), IndexKind::Element),
            (IndexWrap::dim(), IndexKind::Dimension),
        ] {
            for (arg, index) in [(3, 3), (-4, -4)] {
                assert_eq!(
                    wrapper.try_wrap(arg, 3),
                    Err(BoundsError::Index {
                        kind,
                        index,
                        bounds: 0..3,
                    })
                );
            }
        }
    }

    #[test]
    fn test_negative_wrap_scalar() {
        let strict = IndexWrap::index();
        let scalar = IndexWrap::index().with_wrap_scalar(true);

        assert_eq!(
            strict.try_wrap(0, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 0,
                bounds: 0..0,
            })
        );

        assert_eq!(scalar.expect_wrap(0, 0), 0);
        assert_eq!(scalar.expect_wrap(-1, 0), 0);

        assert_eq!(
            strict.with_wrap_scalar(false).try_wrap(1, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 1,
                bounds: 0..0,
            })
        );
    }
}