use cubecl_core::{quant::scheme::QuantScheme, zspace::Strides};
use serde::{Deserialize, Serialize};
/// Memory-layout classification of a (possibly batched) matrix, as produced by
/// [`matrix_batch_layout`] from the tensor's strides.
///
/// The classification only looks at the relative order of the strides (no
/// shape is consulted), so it says nothing about whether the data is densely
/// packed.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum MatrixBatchLayout {
    /// The matrix dimensions are not transposed and the batch dimensions are
    /// neither reordered nor broadcast. Also reported for rank 0 and rank 1.
    Contiguous,
    /// The layout deviates from `Contiguous` in a way fully described by the
    /// two flags below.
    MildlyPermuted {
        /// Without a quantization packing dimension: the row stride is smaller
        /// than the column stride. With one: the packing dimension is `1`.
        transposed: bool,
        /// Some batch stride is smaller than the batch stride just inside it,
        /// or a batch dimension has a zero (broadcast) stride.
        batch_swap: bool,
    },
    /// A layout the two flags cannot express. Examples: a zero row or column
    /// stride, a non-zero batch stride smaller than the row or column stride,
    /// or an unsupported quantization packing dimension.
    HighlyPermuted,
}
/// Classifies the memory layout of a (possibly batched) matrix from its strides.
///
/// Stride roles:
/// - The last stride is the column stride.
/// - The second-to-last stride is the row stride.
/// - Every earlier stride belongs to a batch dimension.
///
/// Only the relative ordering of the strides is inspected; no shape is
/// available here.
///
/// Rules:
/// - Rank 0 or 1 is always [`MatrixBatchLayout::Contiguous`].
/// - A zero row or column stride gives [`MatrixBatchLayout::HighlyPermuted`].
/// - Transposition with a packing dimension reported by `scheme`:
///   - `0` means "not transposed";
///   - `1` means "transposed";
///   - any other value gives `HighlyPermuted`.
/// - Transposition without a packing dimension: the matrix counts as
///   transposed when its row stride is smaller than its column stride.
/// - Batch dimensions are visited from the innermost outward:
///   - A stride smaller than the row or column stride gives `HighlyPermuted`.
///     The exception is a zero (broadcast) stride, which only sets `batch_swap`.
///   - A stride smaller than the previously visited one also sets
///     `batch_swap`. For the innermost batch dimension, the previous one is
///     the row stride.
///
/// The result is `Contiguous` when neither flag is set, otherwise
/// `MildlyPermuted` carrying both flags.
pub fn matrix_batch_layout(strides: &Strides, scheme: Option<&QuantScheme>) -> MatrixBatchLayout {
    let packing_dim = scheme.and_then(|quant| quant.packing_dim());
    let rank = strides.len();
    if rank < 2 {
        return MatrixBatchLayout::Contiguous;
    }

    let (row_stride, col_stride) = (strides[rank - 2], strides[rank - 1]);
    if row_stride == 0 || col_stride == 0 {
        return MatrixBatchLayout::HighlyPermuted;
    }

    // A quantization packing dimension, when present, decides transposition on
    // its own; the stride comparison is only the fallback.
    let transposed = match packing_dim {
        None => row_stride < col_stride,
        Some(0) => false,
        Some(1) => true,
        Some(_) => return MatrixBatchLayout::HighlyPermuted,
    };

    let mut batch_swap = false;
    let mut inner_stride = row_stride;
    for batch_dim in (0..rank - 2).rev() {
        let stride = strides[batch_dim];
        let below_matrix = stride < row_stride || stride < col_stride;
        // Only a broadcast (zero-stride) batch may sit below the matrix strides.
        if below_matrix && stride != 0 {
            return MatrixBatchLayout::HighlyPermuted;
        }
        if below_matrix || stride < inner_stride {
            batch_swap = true;
        }
        inner_stride = stride;
    }

    match (transposed, batch_swap) {
        (false, false) => MatrixBatchLayout::Contiguous,
        _ => MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        },
    }
}
#[cfg(test)]
mod tests {
    use cubecl_core::zspace::strides;

    use super::*;

    /// Builds the expected `MildlyPermuted` value.
    ///
    /// This keeps every check a single `assert_eq!`, which prints the actual
    /// layout on failure. An `if let … else { unreachable!() }` would only
    /// report "entered unreachable code".
    fn mildly_permuted(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        let strides = strides![8, 4, 2, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn vector_is_contiguous() {
        let strides = strides![1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn matrix_is_contiguous() {
        let strides = strides![4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn matrix_is_transposed() {
        let strides = strides![1, 4];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, false)
        );
    }

    #[test]
    fn layout_is_transposed_only() {
        let strides = strides![8, 4, 1, 2];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, false)
        );
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        let strides = strides![4, 8, 2, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        let strides = strides![4, 8, 1, 2];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, true)
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        let strides = strides![8, 2, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        let strides = strides![1, 4, 2, 8];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        let strides = strides![0, 0, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_row_broadcasted() {
        let strides = strides![0, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_col_broadcasted() {
        let strides = strides![1, 0];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        let strides = strides![0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_batch_broadcasted_and_is_transposed() {
        let strides = strides![0, 1, 4];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, true)
        );
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        let strides = strides![0, 0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }
}