// cubek_test_utils/test_tensor/strides.rs
/// Memory layout used to derive per-dimension element strides for a tensor shape.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum StrideSpec {
    /// Row-major (C-order) layout: the last dimension is contiguous and every earlier
    /// dimension steps over the product of all later dimension sizes.
    #[default]
    RowMajor,
    /// Column-major layout of the trailing two dimensions (dimension `n - 2` is contiguous,
    /// dimension `n - 1` steps over one full column), with any leading dimensions laid out
    /// row-major over whole `shape[n - 2] * shape[n - 1]` matrices.
    ColMajor,
    /// Explicit strides, returned verbatim; must have exactly one entry per dimension.
    Custom(Vec<usize>),
}

impl StrideSpec {
    /// Computes the strides for a tensor of the given `shape` according to this spec.
    ///
    /// For `RowMajor` and `ColMajor`, the strides are in elements (the contiguous dimension
    /// has stride 1). For `Custom`, a copy of the stored strides is returned unchanged.
    ///
    /// # Panics
    ///
    /// - `RowMajor` / `ColMajor`: if `shape` has fewer than 2 dimensions.
    /// - `Custom`: if the number of stored strides differs from `shape.len()`.
    pub fn compute_strides(&self, shape: &[usize]) -> Vec<usize> {
        let n = shape.len();

        match self {
            StrideSpec::RowMajor => {
                assert!(n >= 2, "RowMajor requires at least 2 dimensions");
                let mut strides = vec![0; n];
                strides[n - 1] = 1;
                for i in (0..n - 1).rev() {
                    strides[i] = strides[i + 1] * shape[i + 1];
                }
                strides
            }
            StrideSpec::ColMajor => {
                assert!(n >= 2, "ColMajor requires at least 2 dimensions");
                let mut strides = vec![0; n];
                // Trailing two dims form a column-major matrix: dim `n - 2` is contiguous
                // and dim `n - 1` jumps one whole column of `shape[n - 2]` elements.
                strides[n - 2] = 1;
                strides[n - 1] = shape[n - 2];
                // Leading (batch) dims are row-major over whole matrices. The innermost
                // batch dim must skip a full `shape[n - 2] * shape[n - 1]` matrix; the plain
                // row-major recurrence would derive it from `strides[n - 2]` (= 1) and skip
                // only one column, making consecutive batches alias each other.
                if n > 2 {
                    strides[n - 3] = shape[n - 2] * shape[n - 1];
                    for i in (0..n - 3).rev() {
                        strides[i] = strides[i + 1] * shape[i + 1];
                    }
                }
                strides
            }
            StrideSpec::Custom(strides) => {
                assert!(
                    strides.len() == n,
                    "Custom strides must have the same rank as the shape"
                );
                strides.clone()
            }
        }
    }
}