burn_tensor/tensor/indexing/
mod.rs

1//! A module for indexing utility machinery.
2
3use core::fmt::Debug;
4
/// Helper trait for implementing indexing with support for negative indices.
///
/// Implemented for every primitive integer type. Each implementation converts
/// the value to `isize` with a checked conversion and panics if it does not
/// fit. A large unsigned value (e.g. `usize::MAX`) can therefore never be
/// silently reinterpreted as a negative, count-from-the-end index.
///
/// # Example
/// ```rust
/// use burn_tensor::indexing::{AsIndex, canonicalize_dim};
///
/// fn example<I: AsIndex, const D: usize>(dim: I) -> usize {
///     canonicalize_dim(dim, D, false)
/// }
///
/// // Negative indices count from the end.
/// assert_eq!(example::<_, 3>(-1), 2);
/// assert_eq!(example::<_, 3>(1_usize), 1);
/// ```
pub trait AsIndex: Debug + Copy + Sized {
    /// Converts into a signed slice index.
    ///
    /// # Panics
    ///
    /// The provided implementations panic if the value cannot be represented
    /// as an `isize` (e.g. `u64` values above `isize::MAX`).
    fn index(self) -> isize;
}

/// Implements [`AsIndex`] for primitive integer types via a checked conversion.
macro_rules! impl_as_index {
    ($($ty:ty),*) => {
        $(
            impl AsIndex for $ty {
                #[inline]
                fn index(self) -> isize {
                    // A plain `as` cast would wrap large unsigned values (and,
                    // on 32-bit targets, truncate 64-bit ones) into bogus
                    // negative indices that then silently count from the end.
                    isize::try_from(self).unwrap_or_else(|_| {
                        panic!("index {} does not fit in isize", self)
                    })
                }
            }
        )*
    };
}

// `i32` is the default integer literal type, so untyped literals such as `-1`
// resolve to that implementation.
impl_as_index!(usize, isize, i64, u64, i32, u32, i16, u16, i8, u8);
81
82/// Canonicalizes and bounds checks an index with negative indexing support.
83///
84/// ## Arguments
85///
86/// * `idx` - The index to canonicalize.
87/// * `size` - The size of the index range.
88/// * `wrap_scalar` - If true, pretend scalars have rank=1.
89///
90/// ## Returns
91///
92/// The canonicalized dimension index.
93///
94/// ## Panics
95///
96/// * If `wrap_scalar` is false and the tensor has no dimensions.
97/// * If the dimension index is out of range.
98#[must_use]
99pub fn canonicalize_index<Index>(idx: Index, size: usize, wrap_scalar: bool) -> usize
100where
101    Index: AsIndex,
102{
103    canonicalize_named_index("index", "size", idx, size, wrap_scalar)
104}
105
106/// Canonicalizes and bounds checks a dimension index with negative indexing support.
107///
108/// ## Arguments
109///
110/// * `idx` - The dimension index to canonicalize.
111/// * `rank` - The number of dimensions.
112/// * `wrap_scalar` - If true, pretend scalars have rank=1.
113///
114/// ## Returns
115///
116/// The canonicalized dimension index.
117///
118/// ## Panics
119///
120/// * If `wrap_scalar` is false and the tensor has no dimensions.
121/// * If the dimension index is out of range.
122#[must_use]
123pub fn canonicalize_dim<Dim>(idx: Dim, rank: usize, wrap_scalar: bool) -> usize
124where
125    Dim: AsIndex,
126{
127    canonicalize_named_index("dimension index", "rank", idx, rank, wrap_scalar)
128}
129
130/// Canonicalizes and bounds checks an index with negative indexing support.
131///
132/// ## Arguments
133///
134/// * `name` - The name of the index (for error messages).
135/// * `size_name` - The name of the size (for error messages).
136/// * `idx` - The index to canonicalize.
137/// * `size` - The size of the index range.
138/// * `wrap_scalar` - If true, treat 0-size ranges as having size 1.
139///
140/// ## Returns
141///
142/// The canonicalized index.
143///
144/// ## Panics
145///
146/// * If `wrap_scalar` is false and the size is 0.
147/// * If the index is out of range for the dimension size.
148#[inline(always)]
149#[must_use]
150fn canonicalize_named_index<I>(
151    name: &str,
152    size_name: &str,
153    idx: I,
154    size: usize,
155    wrap_scalar: bool,
156) -> usize
157where
158    I: AsIndex,
159{
160    let idx = idx.index();
161
162    let rank = if size > 0 {
163        size
164    } else {
165        if !wrap_scalar {
166            panic!("{name} {idx} used when {size_name} is 0");
167        }
168        1
169    };
170
171    if idx >= 0 && (idx as usize) < rank {
172        return idx as usize;
173    }
174
175    let _idx = if idx < 0 { idx + rank as isize } else { idx };
176
177    if _idx < 0 || (_idx as usize) >= rank {
178        let rank = rank as isize;
179        let lower = -rank;
180        let upper = rank - 1;
181        panic!("{name} {idx} out of range: ({lower}..={upper})");
182    }
183
184    _idx as usize
185}
186
187/// Wraps a dimension index to be within the bounds of the dimension size.
188///
189/// ## Arguments
190///
191/// * `idx` - The dimension index to wrap.
192/// * `size` - The size of the dimension.
193///
194/// ## Returns
195///
196/// The positive wrapped dimension index.
197#[inline]
198#[must_use]
199pub fn wrap_index<I>(idx: I, size: usize) -> usize
200where
201    I: AsIndex,
202{
203    if size == 0 {
204        return 0; // Avoid modulo by zero
205    }
206    let wrapped = idx.index().rem_euclid(size as isize);
207    if wrapped < 0 {
208        (wrapped + size as isize) as usize
209    } else {
210        wrapped as usize
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wrap_idx() {
        assert_eq!(wrap_index(0, 3), 0_usize);
        assert_eq!(wrap_index(3, 3), 0_usize);
        assert_eq!(wrap_index(2 * 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 3, 3), 0_usize);
        assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize);

        assert_eq!(wrap_index(1, 3), 1_usize);
        assert_eq!(wrap_index(1 + 3, 3), 1_usize);
        assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 3, 3), 1_usize);
        assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize);

        assert_eq!(wrap_index(2, 3), 2_usize);
        assert_eq!(wrap_index(2 + 3, 3), 2_usize);
        assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 3, 3), 2_usize);
        assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize);
    }

    #[test]
    fn test_wrap_idx_zero_size() {
        // A zero-sized range has nowhere to wrap into; 0 is returned.
        assert_eq!(wrap_index(0, 0), 0_usize);
        assert_eq!(wrap_index(5, 0), 0_usize);
        assert_eq!(wrap_index(-5, 0), 0_usize);
    }

    #[test]
    fn test_canonicalize_dim() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_dim(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_dim(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_dim(2, 3, wrap_scalar), 2_usize);

        assert_eq!(canonicalize_dim(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_dim(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_dim(-3, 3, wrap_scalar), (3 - 3) as usize);

        let wrap_scalar = true;
        assert_eq!(canonicalize_dim(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_dim(-1, 0, wrap_scalar), 0);
    }

    #[test]
    #[should_panic = "dimension index 0 used when rank is 0"]
    fn test_canonicalize_dim_error_no_dims() {
        let _d = canonicalize_dim(0, 0, false);
    }

    #[test]
    #[should_panic = "dimension index 3 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_big() {
        let _d = canonicalize_dim(3, 3, false);
    }

    #[test]
    #[should_panic = "dimension index -4 out of range: (-3..=2)"]
    fn test_canonicalize_dim_error_too_small() {
        let _d = canonicalize_dim(-4, 3, false);
    }

    #[test]
    fn test_canonicalize_index() {
        let wrap_scalar = false;
        assert_eq!(canonicalize_index(0, 3, wrap_scalar), 0_usize);
        assert_eq!(canonicalize_index(1, 3, wrap_scalar), 1_usize);
        assert_eq!(canonicalize_index(2, 3, wrap_scalar), 2_usize);

        assert_eq!(canonicalize_index(-1, 3, wrap_scalar), (3 - 1) as usize);
        assert_eq!(canonicalize_index(-2, 3, wrap_scalar), (3 - 2) as usize);
        assert_eq!(canonicalize_index(-3, 3, wrap_scalar), (3 - 3) as usize);

        let wrap_scalar = true;
        assert_eq!(canonicalize_index(0, 0, wrap_scalar), 0);
        assert_eq!(canonicalize_index(-1, 0, wrap_scalar), 0);
    }

    #[test]
    #[should_panic = "index 0 used when size is 0"]
    fn test_canonicalize_index_error_no_elements() {
        let _d = canonicalize_index(0, 0, false);
    }

    #[test]
    #[should_panic = "index 3 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_big() {
        let _d = canonicalize_index(3, 3, false);
    }

    #[test]
    #[should_panic = "index -4 out of range: (-3..=2)"]
    fn test_canonicalize_index_error_too_small() {
        let _d = canonicalize_index(-4, 3, false);
    }

    #[test]
    fn test_as_index_in_range() {
        assert_eq!(3_usize.index(), 3);
        assert_eq!((-3_isize).index(), -3);
        assert_eq!((-3_i64).index(), -3);
        assert_eq!(3_u64.index(), 3);
        assert_eq!((-3_i32).index(), -3);
        assert_eq!(3_u32.index(), 3);
        assert_eq!((-3_i16).index(), -3);
        assert_eq!(3_u16.index(), 3);
        assert_eq!((-3_i8).index(), -3);
        assert_eq!(3_u8.index(), 3);
        assert_eq!((isize::MAX as usize).index(), isize::MAX);
    }
}