burn_std/tensor/
indexing.rs

1//! A module for indexing utility machinery.
2
3use crate::IndexKind;
4pub use crate::errors::BoundsError;
5#[allow(unused_imports)]
6use alloc::format;
7#[allow(unused_imports)]
8use alloc::string::{String, ToString};
9use core::fmt::Debug;
10
11pub use crate::tensor::index_conversion::AsIndex;
12
/// Wraps an index with negative indexing support.
///
/// Holds the settings shared by every wrap call made through this value.
#[derive(Debug)]
pub struct IndexWrap {
    /// Kind of index being wrapped; forwarded to the wrapping functions,
    /// which pass it on when constructing a `BoundsError`.
    kind: IndexKind,
    /// Policy for ranges with `size == 0`: when `true` such a range is treated
    /// as having size 1, when `false` wrapping into it is an error.
    wrap_scalar: bool,
}
19
20impl IndexWrap {
21    /// Get an instance for wrapping negative indices.
22    pub fn index() -> Self {
23        Self {
24            kind: IndexKind::Element,
25            wrap_scalar: false,
26        }
27    }
28
29    /// Get an instance for wrapping negative dimensions.
30    pub fn dim() -> Self {
31        Self {
32            kind: IndexKind::Dimension,
33            wrap_scalar: false,
34        }
35    }
36
37    /// Set the policy for wrapping 0-size ranges.
38    ///
39    /// When ``size`` == 0:
40    ///   - if `wrap_scalar`; then ``size == 1``
41    ///   - otherwise; an error.
42    pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
43        Self {
44            wrap_scalar,
45            ..self
46        }
47    }
48
49    /// Wrap an index with negative indexing support.
50    pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
51        try_wrap(idx, size, self.kind, self.wrap_scalar)
52    }
53
54    /// Wrap an index with negative indexing support.
55    pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
56        expect_wrap(idx, size, self.kind, self.wrap_scalar)
57    }
58
59    /// Short-form [`NegativeWrap::index().expect_wrap(idx, size)`].
60    pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
61        Self::index().expect_wrap(idx, size)
62    }
63
64    /// Short-form [`NegativeWrap::dim().expect_wrap(idx, size)`].
65    pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
66        Self::dim().expect_wrap(idx, size)
67    }
68}
69
70/// Wraps an index with negative indexing support.
71///
72/// ## Arguments
73/// - `idx` - The index to canonicalize.
74/// - `size` - The size of the index range.
75/// - `kind` - The kind of index (for error messages).
76/// - `size_name` - The name of the size (for error messages).
77/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1.
78///
79/// ## Returns
80///
81/// A `Result<usize, BoundsError>` of the canonicalized index.
82pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
83where
84    I: AsIndex,
85{
86    try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
87}
88
89/// Wraps an index with negative indexing support.
90///
91/// ## Arguments
92/// - `idx` - The index to canonicalize.
93/// - `size` - The size of the index range.
94/// - `kind` - The kind of index (for error messages).
95/// - `size_name` - The name of the size (for error messages).
96/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1.
97///
98/// ## Returns
99///
100/// A `Result<usize, BoundsError>` of the canonicalized index.
101pub fn try_wrap<I>(
102    idx: I,
103    size: usize,
104    kind: IndexKind,
105    wrap_scalar: bool,
106) -> Result<usize, BoundsError>
107where
108    I: AsIndex,
109{
110    let source_idx = idx.as_index();
111    let source_size = size;
112
113    let size = if source_size > 0 {
114        source_size
115    } else {
116        if !wrap_scalar {
117            return Err(BoundsError::index(kind, source_idx, 0..0));
118        }
119        1
120    };
121
122    if source_idx >= 0 && (source_idx as usize) < size {
123        return Ok(source_idx as usize);
124    }
125
126    let _idx = if source_idx < 0 {
127        source_idx + size as isize
128    } else {
129        source_idx
130    };
131
132    if _idx < 0 || (_idx as usize) >= size {
133        let rank = size as isize;
134
135        return Err(BoundsError::index(kind, source_idx, 0..rank));
136    }
137
138    Ok(_idx as usize)
139}
140
141/// Wraps a dimension index to be within the bounds of the dimension size.
142///
143/// ## Arguments
144///
145/// * `idx` - The dimension index to wrap.
146/// * `size` - The size of the dimension.
147///
148/// ## Returns
149///
150/// The positive wrapped dimension index.
151#[inline]
152#[must_use]
153pub fn wrap_index<I>(idx: I, size: usize) -> usize
154where
155    I: AsIndex,
156{
157    if size == 0 {
158        return 0; // Avoid modulo by zero
159    }
160    let wrapped = idx.as_index().rem_euclid(size as isize);
161    if wrapped < 0 {
162        (wrapped + size as isize) as usize
163    } else {
164        wrapped as usize
165    }
166}
167
168/// Compute the ravel index for the given coordinates.
169///
170/// This returns the row-major order raveling:
171/// * `strides[-1] = 1`
172/// * `strides[i] = strides[i+1] * dims[i+1]`
173/// * `dim_strides = coords * strides`
174/// * `ravel = sum(dim_strides)`
175///
176/// # Arguments
177/// - `indices`: the index for each dimension; must be the same length as `shape`.
178/// - `shape`: the shape of each dimension; be the same length as `indices`.
179///
180/// # Returns
181/// - the ravel offset index.
182pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
183    assert_eq!(
184        shape.len(),
185        indices.len(),
186        "Coordinate rank mismatch: expected {}, got {}",
187        shape.len(),
188        indices.len(),
189    );
190
191    let mut ravel_idx = 0;
192    let mut stride = 1;
193
194    for (i, &dim) in shape.iter().enumerate().rev() {
195        let idx = indices[i];
196        let coord = IndexWrap::index().expect_wrap(idx, dim);
197        ravel_idx += coord * stride;
198        stride *= dim;
199    }
200
201    ravel_idx
202}
203
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use alloc::vec;

    /// Raveling the first and the last coordinate of a 4-d shape.
    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];

        let origin = ravel_index(&[0, 0, 0, 0], &shape);
        assert_eq!(origin, 0);

        let expected = 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4;
        let actual = ravel_index(&[1, 2, 3, 4], &shape);
        assert_eq!(actual, expected);
    }

    /// Every residue modulo 3, shifted by 0, ±1 and ±2 whole periods.
    #[test]
    fn test_wrap_idx() {
        for (residue, expected) in [(0, 0_usize), (1, 1_usize), (2, 2_usize)] {
            for shift in [0, 3, 2 * 3, 0 - 3, 0 - 2 * 3] {
                assert_eq!(wrap_index(residue + shift, 3), expected);
            }
        }
    }

    /// In-range and out-of-range wrapping for both index kinds.
    #[test]
    fn test_negative_wrap() {
        // Expected error for an index rejected against a range of size 3.
        let out_of_bounds = |kind, index| {
            Err::<usize, _>(BoundsError::Index {
                kind,
                index,
                bounds: 0..3,
            })
        };

        for (wrapper, kind) in [
            (IndexWrap::index(), IndexKind::Element),
            (IndexWrap::dim(), IndexKind::Dimension),
        ] {
            for (idx, expected) in [(0, 0), (1, 1), (2, 2), (-1, 2), (-2, 1), (-3, 0)] {
                assert_eq!(wrapper.expect_wrap(idx, 3), expected);
            }

            assert_eq!(wrapper.try_wrap(3, 3), out_of_bounds(kind, 3));
            assert_eq!(wrapper.try_wrap(-4, 3), out_of_bounds(kind, -4));
        }
    }

    /// Zero-size ranges: rejected by default, size 1 when `wrap_scalar` is on.
    #[test]
    fn test_negative_wrap_scalar() {
        // Expected error for an element index rejected against an empty range.
        let empty_range = |index| {
            Err::<usize, _>(BoundsError::Index {
                kind: IndexKind::Element,
                index,
                bounds: 0..0,
            })
        };

        assert_eq!(IndexWrap::index().try_wrap(0, 0), empty_range(0));

        let scalar = IndexWrap::index().with_wrap_scalar(true);
        assert_eq!(scalar.expect_wrap(0, 0), 0);
        assert_eq!(scalar.expect_wrap(-1, 0), 0);

        let strict = IndexWrap::index().with_wrap_scalar(false);
        assert_eq!(strict.try_wrap(1, 0), empty_range(1));
    }
}