use cubecl::{
prelude::*,
quant::scheme::QuantScheme,
zspace::{Strides, strides},
};
use crate::InvalidConfigError;
/// Memory layout of the two innermost dimensions of a (possibly batched) matrix.
///
/// Only the last two dimensions determine the variant; any leading dimensions are
/// treated as batch dimensions.
#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
pub enum MatrixLayout {
/// The last dimension is contiguous (unit stride). This is the default layout.
#[default]
RowMajor,
/// The second-to-last dimension is contiguous (unit stride).
ColMajor,
}
impl MatrixLayout {
    /// Infers the layout of the two innermost dimensions from a tensor's shape and strides.
    ///
    /// If `scheme` is provided and reports a packing dimension, the layout is decided by
    /// that dimension alone: `0` gives [`MatrixLayout::RowMajor`] and `1` gives
    /// [`MatrixLayout::ColMajor`]. In that case `shape` and `strides` are only used for
    /// the rank check. NOTE(review): this presumes `packing_dim` counts from the
    /// innermost dimension; confirm against `QuantScheme::packing_dim`.
    ///
    /// Otherwise only the last two dimensions are examined:
    /// - unit stride on the last dimension, and an outer stride at least as large as the
    ///   last dimension's size, gives [`MatrixLayout::RowMajor`];
    /// - unit stride on the second-to-last dimension, and an inner stride at least as
    ///   large as that dimension's size, gives [`MatrixLayout::ColMajor`].
    ///
    /// Row-major is checked first, so it wins when both criteria hold (for example with
    /// size-1 dimensions). Leading (batch) dimensions are not inspected.
    ///
    /// # Errors
    ///
    /// Returns an error if the packing dimension is neither `0` nor `1`, or if the strides
    /// of the last two dimensions match neither layout.
    ///
    /// # Panics
    ///
    /// Panics if `shape` has fewer than two dimensions or if `shape` and `strides`
    /// differ in length.
    pub fn from_shape_and_strides(
        shape: &[usize],
        strides: &[usize],
        scheme: Option<&QuantScheme>,
    ) -> Result<Self, InvalidConfigError> {
        assert!(
            shape.len() >= 2 && shape.len() == strides.len(),
            "Shape/stride mismatch or not a matrix"
        );

        // Packed quantized data: the packing dimension fixes the layout, strides are ignored.
        if let Some(packing_dim) = scheme.and_then(|s| s.packing_dim()) {
            if packing_dim == 0 {
                return Ok(MatrixLayout::RowMajor);
            }
            if packing_dim == 1 {
                return Ok(MatrixLayout::ColMajor);
            }
            return Err(Box::new(format!(
                "Invalid or non-contiguous matrix layout: packing_dim={packing_dim:?}"
            )));
        }

        let n = shape.len();
        let outer = shape[n - 2];
        let inner = shape[n - 1];
        let stride_outer = strides[n - 2];
        let stride_inner = strides[n - 1];

        // Requiring the non-unit stride to span at least the contiguous dimension rules
        // out overlapping rows/columns (but still accepts padded ones).
        if (stride_inner == 1) && stride_outer >= inner {
            return Ok(MatrixLayout::RowMajor);
        }
        if (stride_outer == 1) && stride_inner >= outer {
            return Ok(MatrixLayout::ColMajor);
        }

        Err(Box::new(format!(
            "Invalid or non-contiguous matrix layout: shape={shape:?}, strides={strides:?}",
        )))
    }

    /// Builds fully contiguous strides for `shape`, using `self` as the layout of the last
    /// two dimensions.
    ///
    /// Leading (batch) dimensions are laid out contiguously, outermost first. The innermost
    /// batch dimension steps over exactly one `rows * cols` matrix, whatever the matrix
    /// layout.
    ///
    /// # Panics
    ///
    /// Panics if `shape` has fewer than two dimensions.
    pub fn to_strides(&self, shape: &[usize]) -> Strides {
        assert!(shape.len() >= 2, "Shape must have at least 2 dimensions");
        let n = shape.len();
        let rows = shape[n - 2];
        let cols = shape[n - 1];

        let mut strides = strides![0; n];
        match self {
            MatrixLayout::RowMajor => {
                strides[n - 1] = 1;
                strides[n - 2] = cols;
            }
            MatrixLayout::ColMajor => {
                strides[n - 2] = 1;
                strides[n - 1] = rows;
            }
        }

        if n > 2 {
            // Seed from the matrix size, not from `strides[n - 2]`: for ColMajor that
            // stride is 1, which would make consecutive batch matrices overlap.
            strides[n - 3] = rows * cols;
            for i in (0..n - 3).rev() {
                strides[i] = strides[i + 1] * shape[i + 1];
            }
        }

        strides
    }
}
/// Converts a comptime [`MatrixLayout`] into the equivalent [`cmma::MatrixLayout`].
///
/// The mapping is one-to-one: `RowMajor` → `RowMajor` and `ColMajor` → `ColMajor`.
/// `layout` is a `#[comptime]` parameter, so the match is presumably resolved when the
/// kernel is expanded rather than at runtime. TODO confirm against cubecl comptime docs.
#[cube]
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
match layout {
MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
}
}
/// Converts a comptime [`cmma::MatrixLayout`] back into a [`MatrixLayout`].
///
/// `RowMajor` and `ColMajor` map to their counterparts. `Undefined` has no counterpart
/// and falls back to [`MatrixLayout::RowMajor`].
/// NOTE(review): the fallback is silent. Confirm that no caller needs to tell an
/// undefined cmma layout apart from a row-major one.
#[cube]
pub fn from_cmma_layout(#[comptime] layout: cmma::MatrixLayout) -> comptime_type!(MatrixLayout) {
match layout {
cmma::MatrixLayout::RowMajor => MatrixLayout::RowMajor,
cmma::MatrixLayout::ColMajor => MatrixLayout::ColMajor,
// Fallback: an undefined cmma layout is treated as row-major.
cmma::MatrixLayout::Undefined => MatrixLayout::RowMajor,
}
}