// cubek-test-utils 0.2.0-pre.3
//
// CubeK: Test Utils
// Documentation
use cubecl::zspace::{Shape, Strides};

/// Describes how strides should be derived for a tensor of a given shape.
///
/// The default is [`StrideSpec::RowMajor`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum StrideSpec {
    /// Contiguous layout in which the last dimension has stride 1 and every
    /// preceding dimension strides over all the dimensions after it.
    #[default]
    RowMajor,
    /// Layout in which the last two dimensions are stored column-major: the
    /// second-to-last dimension has stride 1 and the last dimension strides
    /// over it.
    ColMajor,
    /// Explicit strides, one entry per dimension of the shape they are
    /// applied to.
    Custom(Vec<usize>),
}

impl StrideSpec {
    /// Computes the strides that this specification yields for `shape`.
    ///
    /// * `RowMajor` — the last dimension has stride 1; every other dimension's
    ///   stride is the product of the extents of all dimensions after it.
    /// * `ColMajor` — the last two dimensions are laid out column-major
    ///   (second-to-last stride is 1, last stride is the extent of the
    ///   second-to-last dimension). Any leading (batch) dimensions stride over
    ///   whole matrices, e.g. shape `[B, M, N]` gives strides `[M * N, 1, M]`.
    /// * `Custom` — the stored strides are returned as-is.
    ///
    /// # Panics
    ///
    /// * `RowMajor` / `ColMajor`: if `shape` has fewer than 2 dimensions.
    /// * `Custom`: if the number of strides differs from the rank of `shape`.
    pub fn compute_strides(&self, shape: &Shape) -> Strides {
        let n = shape.len();
        match self {
            StrideSpec::RowMajor => {
                assert!(n >= 2, "RowMajor requires at least 2 dimensions");
                let mut strides = vec![0; n];
                strides[n - 1] = 1;
                // Walk outwards: each stride covers everything to its right.
                for i in (0..n - 1).rev() {
                    strides[i] = strides[i + 1] * shape[i + 1];
                }
                Strides::new(&strides)
            }
            StrideSpec::ColMajor => {
                assert!(n >= 2, "ColMajor requires at least 2 dimensions");
                let mut strides = vec![0; n];
                // Column-major matrix: rows are contiguous, columns step over
                // one full column of `shape[n - 2]` elements.
                strides[n - 2] = 1;
                strides[n - 1] = shape[n - 2];
                if n > 2 {
                    // The innermost batch dimension must step over one whole
                    // matrix (rows * cols). Deriving it from `strides[n - 2]`
                    // would make it equal to the column stride and alias
                    // distinct elements.
                    strides[n - 3] = shape[n - 2] * shape[n - 1];
                    // Remaining batch dimensions follow the usual recurrence.
                    for i in (0..n - 3).rev() {
                        strides[i] = strides[i + 1] * shape[i + 1];
                    }
                }
                Strides::new(&strides)
            }
            StrideSpec::Custom(strides) => {
                assert!(
                    strides.len() == n,
                    "Custom strides must have the same rank as the shape"
                );
                strides.clone().into()
            }
        }
    }
}