// cubek_test_utils/test_tensor/strides.rs

/// Describes how the strides of a test tensor should be derived from its shape.
///
/// Strides are expressed in elements, one entry per dimension of the shape.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum StrideSpec {
    /// Contiguous layout where the last dimension has stride 1.
    #[default]
    RowMajor,
    /// Layout where the last two dimensions are stored column-major: dimension `n - 2`
    /// has stride 1 and dimension `n - 1` has stride `shape[n - 2]`.
    ColMajor,
    /// Explicit per-dimension strides; must have exactly as many entries as the shape
    /// has dimensions.
    Custom(Vec<usize>),
}

9impl StrideSpec {
10    pub fn compute_strides(&self, shape: &[usize]) -> Vec<usize> {
11        let n = shape.len();
12
13        match self {
14            StrideSpec::RowMajor => {
15                assert!(n >= 2, "RowMajor requires at least 2 dimensions");
16                let mut strides = vec![0; n];
17                strides[n - 1] = 1;
18                for i in (0..n - 1).rev() {
19                    strides[i] = strides[i + 1] * shape[i + 1];
20                }
21                strides
22            }
23            StrideSpec::ColMajor => {
24                assert!(n >= 2, "ColMajor requires at least 2 dimensions");
25                let mut strides = vec![0; n];
26                strides[n - 2] = 1;
27                strides[n - 1] = shape[n - 2];
28                for i in (0..n - 2).rev() {
29                    strides[i] = strides[i + 1] * shape[i + 1];
30                }
31                strides
32            }
33            StrideSpec::Custom(strides) => {
34                assert!(
35                    strides.len() == n,
36                    "Custom strides must have the same rank as the shape"
37                );
38                strides.clone()
39            }
40        }
41    }
42}