cubek_test_utils/test_tensor/
strides.rs1use cubecl::zspace::{Shape, Strides};
2
/// Describes how the strides of a test tensor are derived from its shape.
///
/// `Clone` is derived so a single spec can be reused for several tensors
/// without having to be rebuilt.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum StrideSpec {
    /// Last dimension has stride 1; this is the default layout.
    #[default]
    RowMajor,
    /// Second-to-last dimension has stride 1.
    ColMajor,
    /// Caller-provided strides, one entry per dimension of the shape.
    Custom(Vec<usize>),
}
10
11pub fn physical_extent(shape: &Shape, strides: &Strides) -> usize {
18 let mut max_offset = 0usize;
19 for (s, d) in strides.iter().zip(shape.iter()) {
20 if *d > 0 && *s > 0 {
21 max_offset += (d - 1) * s;
22 }
23 }
24 max_offset + 1
25}
26
#[cfg(test)]
mod tests {
    use super::*;

    /// A dense 2x3 row-major layout needs exactly `2 * 3` elements.
    #[test]
    fn physical_extent_contiguous_row_major() {
        let dims = Shape::from(vec![2, 3]);
        let steps = Strides::new(&[3, 1]);
        let extent = physical_extent(&dims, &steps);
        assert_eq!(extent, 6);
    }

    /// A row stride of 512 is wider than the 256-element row, so the buffer
    /// must be larger than the logical element count.
    #[test]
    fn physical_extent_jumpy_strides_exceed_logical() {
        let dims = Shape::from(vec![256, 256]);
        let steps = Strides::new(&[512, 1]);
        let extent = physical_extent(&dims, &steps);
        // Furthest offset is 255 * 512 + 255 * 1 = 130815, plus one.
        assert_eq!(extent, 130816);
        assert!(extent > 256 * 256);
    }

    /// A stride of 0 on the first dimension makes all four rows share the same
    /// three elements, so the buffer is smaller than the logical element count.
    #[test]
    fn physical_extent_broadcast_strides_undercount_logical() {
        let dims = Shape::from(vec![4, 3]);
        let steps = Strides::new(&[0, 1]);
        let extent = physical_extent(&dims, &steps);
        // Only the last dimension moves the offset: 2 * 1, plus one.
        assert_eq!(extent, 3);
        assert!(extent < 4 * 3);
    }
}
62
63impl StrideSpec {
64 pub fn compute_strides(&self, shape: &Shape) -> Strides {
65 let n = shape.len();
66 match self {
67 StrideSpec::RowMajor => {
68 assert!(n >= 2, "RowMajor requires at least 2 dimensions");
69 let mut strides = vec![0; n];
70 strides[n - 1] = 1;
71 for i in (0..n - 1).rev() {
72 strides[i] = strides[i + 1] * shape[i + 1];
73 }
74 Strides::new(&strides)
75 }
76 StrideSpec::ColMajor => {
77 assert!(n >= 2, "ColMajor requires at least 2 dimensions");
78 let mut strides = vec![0; n];
79 strides[n - 2] = 1;
80 strides[n - 1] = shape[n - 2];
81 for i in (0..n - 2).rev() {
82 strides[i] = strides[i + 1] * shape[i + 1];
83 }
84 Strides::new(&strides)
85 }
86 StrideSpec::Custom(strides) => {
87 assert!(
88 strides.len() == n,
89 "Custom strides must have the same rank as the shape"
90 );
91 strides.clone().into()
92 }
93 }
94 }
95}