Skip to main content

cubek_test_utils/test_tensor/strides.rs

1use cubecl::zspace::{Shape, Strides};
2
/// Memory layout requested for a test tensor.
///
/// Converted into concrete per-dimension strides (counted in elements) by
/// [`StrideSpec::compute_strides`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum StrideSpec {
    /// Contiguous layout with the last dimension varying fastest (stride 1).
    /// This is the default layout.
    #[default]
    RowMajor,
    /// The last two dimensions are stored column-major: dimension `n - 2`
    /// has stride 1 and dimension `n - 1` has stride `shape[n - 2]`.
    /// Requires a shape with at least 2 dimensions.
    ColMajor,
    /// Explicit strides, one entry per dimension of the shape.
    Custom(Vec<usize>),
}
10
11/// Number of elements in the physical buffer required to cover every logical
12/// index in `shape` under `strides`, assuming element 0 is at offset 0.
13///
14/// Exceeds `shape.iter().product()` for jumpy strides (e.g. a slice stepping
15/// over padding) and is less than it for broadcast strides (a stride of 0
16/// makes every index in that dim share the same physical offset).
17pub fn physical_extent(shape: &Shape, strides: &Strides) -> usize {
18    let mut max_offset = 0usize;
19    for (s, d) in strides.iter().zip(shape.iter()) {
20        if *d > 0 && *s > 0 {
21            max_offset += (d - 1) * s;
22        }
23    }
24    max_offset + 1
25}
26
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn physical_extent_contiguous_row_major() {
        // Dense 2x3 row-major tensor (strides 3, 1): the buffer holds exactly
        // as many slots as there are logical elements.
        let extent = physical_extent(&Shape::from(vec![2, 3]), &Strides::new(&[3, 1]));
        assert_eq!(extent, 6);
    }

    #[test]
    fn physical_extent_jumpy_strides_exceed_logical() {
        // A 256x256 window into a 256x512 buffer: consecutive rows sit 512
        // apart, so the furthest element is at 255 * 512 + 255 = 130815.
        let extent = physical_extent(&Shape::from(vec![256, 256]), &Strides::new(&[512, 1]));
        assert_eq!(extent, 130816);
        // The row padding makes the buffer strictly larger than the view.
        assert!(extent > 256 * 256);
    }

    #[test]
    fn physical_extent_broadcast_strides_undercount_logical() {
        // Dim 0 is broadcast (stride 0): all four rows alias the same three
        // physical elements, so 3 slots suffice for 12 logical elements.
        let extent = physical_extent(&Shape::from(vec![4, 3]), &Strides::new(&[0, 1]));
        assert_eq!(extent, 3);
        // Broadcasting makes the buffer strictly smaller than the view.
        assert!(extent < 4 * 3);
    }
}
62
63impl StrideSpec {
64    pub fn compute_strides(&self, shape: &Shape) -> Strides {
65        let n = shape.len();
66        match self {
67            StrideSpec::RowMajor => {
68                assert!(n >= 2, "RowMajor requires at least 2 dimensions");
69                let mut strides = vec![0; n];
70                strides[n - 1] = 1;
71                for i in (0..n - 1).rev() {
72                    strides[i] = strides[i + 1] * shape[i + 1];
73                }
74                Strides::new(&strides)
75            }
76            StrideSpec::ColMajor => {
77                assert!(n >= 2, "ColMajor requires at least 2 dimensions");
78                let mut strides = vec![0; n];
79                strides[n - 2] = 1;
80                strides[n - 1] = shape[n - 2];
81                for i in (0..n - 2).rev() {
82                    strides[i] = strides[i + 1] * shape[i + 1];
83                }
84                Strides::new(&strides)
85            }
86            StrideSpec::Custom(strides) => {
87                assert!(
88                    strides.len() == n,
89                    "Custom strides must have the same rank as the shape"
90                );
91                strides.clone().into()
92            }
93        }
94    }
95}