Skip to main content

cubecl_zspace/indexing/
wrapping.rs

1//! A module for indexing utility machinery.
2
3pub use super::type_conversion::AsIndex;
4pub use crate::errors::BoundsError;
5use crate::errors::IndexKind;
6#[allow(unused_imports)]
7use alloc::format;
8#[allow(unused_imports)]
9use alloc::string::{String, ToString};
10use core::fmt::Debug;
11
12/// Wraps an index with negative indexing support.
13#[derive(Debug)]
14pub struct IndexWrap {
15    kind: IndexKind,
16    wrap_scalar: bool,
17}
18
19impl IndexWrap {
20    /// Get an instance for wrapping negative indices.
21    pub fn index() -> Self {
22        Self {
23            kind: IndexKind::Element,
24            wrap_scalar: false,
25        }
26    }
27
28    /// Get an instance for wrapping negative dimensions.
29    pub fn dim() -> Self {
30        Self {
31            kind: IndexKind::Dimension,
32            wrap_scalar: false,
33        }
34    }
35
36    /// Set the policy for wrapping 0-size ranges.
37    ///
38    /// When ``size`` == 0:
39    ///   - if `wrap_scalar`; then ``size == 1``
40    ///   - otherwise; an error.
41    pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self {
42        Self {
43            wrap_scalar,
44            ..self
45        }
46    }
47
48    /// Wrap an index with negative indexing support.
49    pub fn try_wrap<I: AsIndex>(&self, idx: I, size: usize) -> Result<usize, BoundsError> {
50        try_wrap(idx, size, self.kind, self.wrap_scalar)
51    }
52
53    /// Wrap an index with negative indexing support.
54    pub fn expect_wrap<I: AsIndex>(&self, idx: I, size: usize) -> usize {
55        expect_wrap(idx, size, self.kind, self.wrap_scalar)
56    }
57
58    /// Short-form [`NegativeWrap::index().expect_wrap(idx, size)`].
59    pub fn expect_elem<I: AsIndex>(idx: I, size: usize) -> usize {
60        Self::index().expect_wrap(idx, size)
61    }
62
63    /// Short-form [`NegativeWrap::dim().expect_wrap(idx, size)`].
64    pub fn expect_dim<I: AsIndex>(idx: I, size: usize) -> usize {
65        Self::dim().expect_wrap(idx, size)
66    }
67}
68
69/// Wraps an index with negative indexing support.
70///
71/// ## Arguments
72/// - `idx` - The index to canonicalize.
73/// - `size` - The size of the index range.
74/// - `kind` - The kind of index (for error messages).
75/// - `size_name` - The name of the size (for error messages).
76/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1.
77///
78/// ## Returns
79///
80/// A `Result<usize, BoundsError>` of the canonicalized index.
81pub fn expect_wrap<I>(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize
82where
83    I: AsIndex,
84{
85    try_wrap(idx, size, kind, wrap_scalar).expect("valid index")
86}
87
88/// Wraps an index with negative indexing support.
89///
90/// ## Arguments
91/// - `idx` - The index to canonicalize.
92/// - `size` - The size of the index range.
93/// - `kind` - The kind of index (for error messages).
94/// - `size_name` - The name of the size (for error messages).
95/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1.
96///
97/// ## Returns
98///
99/// A `Result<usize, BoundsError>` of the canonicalized index.
100pub fn try_wrap<I>(
101    idx: I,
102    size: usize,
103    kind: IndexKind,
104    wrap_scalar: bool,
105) -> Result<usize, BoundsError>
106where
107    I: AsIndex,
108{
109    let source_idx = idx.as_index();
110    let source_size = size;
111
112    let size = if source_size > 0 {
113        source_size
114    } else {
115        if !wrap_scalar {
116            return Err(BoundsError::index(kind, source_idx, 0..0));
117        }
118        1
119    };
120
121    if source_idx >= 0 && (source_idx as usize) < size {
122        return Ok(source_idx as usize);
123    }
124
125    let _idx = if source_idx < 0 {
126        source_idx + size as isize
127    } else {
128        source_idx
129    };
130
131    if _idx < 0 || (_idx as usize) >= size {
132        let rank = size as isize;
133
134        return Err(BoundsError::index(kind, source_idx, 0..rank));
135    }
136
137    Ok(_idx as usize)
138}
139
140/// Wraps a dimension index to be within the bounds of the dimension size.
141///
142/// ## Arguments
143///
144/// * `idx` - The dimension index to wrap.
145/// * `size` - The size of the dimension.
146///
147/// ## Returns
148///
149/// The positive wrapped dimension index.
150#[inline]
151#[must_use]
152pub fn wrap_index<I>(idx: I, size: usize) -> usize
153where
154    I: AsIndex,
155{
156    if size == 0 {
157        return 0; // Avoid modulo by zero
158    }
159    let wrapped = idx.as_index().rem_euclid(size as isize);
160    if wrapped < 0 {
161        (wrapped + size as isize) as usize
162    } else {
163        wrapped as usize
164    }
165}
166
167/// Compute the ravel index for the given coordinates.
168///
169/// This returns the row-major order raveling:
170/// * `strides[-1] = 1`
171/// * `strides[i] = strides[i+1] * dims[i+1]`
172/// * `dim_strides = coords * strides`
173/// * `ravel = sum(dim_strides)`
174///
175/// # Arguments
176/// - `indices`: the index for each dimension; must be the same length as `shape`.
177/// - `shape`: the shape of each dimension; be the same length as `indices`.
178///
179/// # Returns
180/// - the ravel offset index.
181pub fn ravel_index<I: AsIndex>(indices: &[I], shape: &[usize]) -> usize {
182    assert_eq!(
183        shape.len(),
184        indices.len(),
185        "Coordinate rank mismatch: expected {}, got {}",
186        shape.len(),
187        indices.len(),
188    );
189
190    let mut ravel_idx = 0;
191    let mut stride = 1;
192
193    for (i, &dim) in shape.iter().enumerate().rev() {
194        let idx = indices[i];
195        let coord = IndexWrap::index().expect_wrap(idx, dim);
196        ravel_idx += coord * stride;
197        stride *= dim;
198    }
199
200    ravel_idx
201}
202
#[cfg(test)]
#[allow(clippy::identity_op, reason = "useful for clarity")]
mod tests {
    use super::*;
    use crate::errors::IndexKind;
    use alloc::vec;

    /// Row-major raveling weights each coordinate by the product of the
    /// dimensions that follow it.
    #[test]
    fn test_ravel() {
        let shape = vec![2, 3, 4, 5];

        let origin = ravel_index(&[0, 0, 0, 0], &shape);
        let corner = ravel_index(&[1, 2, 3, 4], &shape);

        assert_eq!(origin, 0);
        assert_eq!(corner, 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4);
    }

    /// Every index congruent to `expected` modulo 3 must wrap to `expected`.
    #[test]
    fn test_wrap_idx() {
        for expected in 0..3_i32 {
            for period in -2..=2 {
                let idx = expected + period * 3;
                assert_eq!(wrap_index(idx, 3), expected as usize);
            }
        }
    }

    #[test]
    fn test_negative_wrap() {
        let in_bounds: [(i32, usize); 6] = [(0, 0), (1, 1), (2, 2), (-1, 2), (-2, 1), (-3, 0)];
        for wrap in [IndexWrap::index(), IndexWrap::dim()].iter() {
            for &(idx, expected) in &in_bounds {
                assert_eq!(wrap.expect_wrap(idx, 3), expected);
            }
        }

        for (wrap, kind) in [
            (IndexWrap::index(), IndexKind::Element),
            (IndexWrap::dim(), IndexKind::Dimension),
        ] {
            assert_eq!(
                wrap.try_wrap(3, 3),
                Err(BoundsError::Index {
                    kind,
                    index: 3,
                    bounds: 0..3,
                })
            );
            assert_eq!(
                wrap.try_wrap(-4, 3),
                Err(BoundsError::Index {
                    kind,
                    index: -4,
                    bounds: 0..3,
                })
            );
        }
    }

    #[test]
    fn test_negative_wrap_scalar() {
        let scalar = IndexWrap::index().with_wrap_scalar(true);
        assert_eq!(scalar.expect_wrap(0, 0), 0);
        assert_eq!(scalar.expect_wrap(-1, 0), 0);

        let strict = IndexWrap::index().with_wrap_scalar(false);
        assert_eq!(
            IndexWrap::index().try_wrap(0, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 0,
                bounds: 0..0,
            })
        );
        assert_eq!(
            strict.try_wrap(1, 0),
            Err(BoundsError::Index {
                kind: IndexKind::Element,
                index: 1,
                bounds: 0..0,
            })
        );
    }
}