use cubecl::{
zspace::{Shape, Strides},
{AutotuneKey, Runtime, quant::scheme::QuantScheme},
{client::ComputeClient, ir::StorageType},
};
use cubek_std::MatmulProblemSize;
use serde::{Deserialize, Serialize};
use cubecl::std::tensor::{MatrixBatchLayout, matrix_batch_layout};
use crate::definition::MatmulKind;
/// Autotune key identifying a class of matmul problems.
///
/// Combines the raw problem description (sizes, layouts, element types) with a
/// coarse derived analysis so autotuning results can be cached and reused for
/// problems that fall in the same class.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct MatmulAutotuneKey {
    /// Raw problem description: anchored m/n/k, alignment factors, element
    /// types and batch layouts.
    pub definition: MatmulProblemDefinition,
    /// Derived characteristics (global scale, matmul kind) used when deciding
    /// which kernel variants are worth benchmarking.
    pub analysis: MatmulAutotuneAnalysis,
}
/// Upper bound on the stride alignment factor reported by `stride_align`
/// (a trailing-zero count, i.e. a power-of-two exponent). Capping it keeps the
/// autotune key space small for highly aligned strides.
const MAX_STRIDE_FACTOR: u32 = 10;
/// Raw description of a matmul problem, used as the main component of the
/// autotune key.
///
/// The m/n/k sizes are anchored (`#[autotune(anchor)]`) so that problems with
/// nearby sizes map to the same key rather than each size getting its own
/// cache entry.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct MatmulProblemDefinition {
    /// Number of output rows (second-to-last lhs dimension).
    #[autotune(anchor)]
    pub m: usize,
    /// Number of output columns (last rhs dimension).
    #[autotune(anchor)]
    pub n: usize,
    /// Reduction dimension (last lhs dimension).
    #[autotune(anchor)]
    pub k: usize,
    /// Power-of-two exponent (capped at 4) dividing the lhs inner dimension;
    /// 0 when the lhs layout is highly permuted.
    pub lhs_pow2_factor: u8,
    /// Byte alignment (trailing-zero count, capped) of the lhs strides,
    /// excluding the contiguous dimension; 0 for layouts where it is unused.
    pub lhs_stride_factor: u8,
    /// Same as `lhs_pow2_factor`, for the rhs operand.
    pub rhs_pow2_factor: u8,
    /// Same as `lhs_stride_factor`, for the rhs operand.
    pub rhs_stride_factor: u8,
    /// Storage type of the lhs operand.
    pub elem_lhs: StorageType,
    /// Storage type of the rhs operand.
    pub elem_rhs: StorageType,
    /// Storage type of the output.
    pub elem_out: StorageType,
    /// Batch layout (contiguous / mildly permuted / highly permuted) of lhs.
    pub matrix_layout_lhs: MatrixBatchLayout,
    /// Batch layout of rhs.
    pub matrix_layout_rhs: MatrixBatchLayout,
}
/// Coarse classification of a matmul problem's overall size.
///
/// Thresholds are defined by [`MatmulGlobalScale::from_size`]: all of m/n/k
/// below 512 is `Small`, all below 2048 is `Medium`, anything larger `Large`.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum MatmulGlobalScale {
    /// At least one dimension is >= 2048.
    Large,
    /// All dimensions < 2048, at least one >= 512.
    Medium,
    /// All dimensions < 512.
    Small,
}
/// Derived characteristics of a matmul problem, computed from its sizes.
///
/// Used alongside [`MatmulProblemDefinition`] to prune the set of kernel
/// variants considered during autotuning.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct MatmulAutotuneAnalysis {
    /// Coarse size classification (see [`MatmulGlobalScale::from_size`]).
    pub scale_global: MatmulGlobalScale,
    /// Problem kind derived from the m/n/k sizes.
    pub kind: MatmulKind,
}
impl MatmulGlobalScale {
pub fn from_size(m: usize, n: usize, k: usize) -> Self {
if m < 512 && k < 512 && n < 512 {
MatmulGlobalScale::Small
} else if m < 2048 && k < 2048 && n < 2048 {
MatmulGlobalScale::Medium
} else {
MatmulGlobalScale::Large
}
}
}
/// Returns whether double-buffering kernel variants are worth benchmarking
/// for the given autotune key.
///
/// Only `General` matmuls qualify at all. Among those, medium and large
/// problems always qualify, while small problems qualify only when the matmul
/// is fused.
pub fn should_tune_double_buffering(fused: bool, key: &MatmulAutotuneKey) -> bool {
    if !matches!(key.analysis.kind, MatmulKind::General) {
        return false;
    }
    match key.analysis.scale_global {
        MatmulGlobalScale::Small => fused,
        MatmulGlobalScale::Medium | MatmulGlobalScale::Large => true,
    }
}
impl MatmulAutotuneKey {
#[allow(clippy::too_many_arguments)]
pub fn generate<R: Runtime>(
_client: &ComputeClient<R>,
lhs_shape: &Shape,
rhs_shape: &Shape,
lhs_strides: &Strides,
rhs_strides: &Strides,
elem_lhs: StorageType,
elem_rhs: StorageType,
elem_out: StorageType,
lhs_scheme: Option<&QuantScheme>,
rhs_scheme: Option<&QuantScheme>,
) -> MatmulAutotuneKey {
let ndims = lhs_shape.len();
let m = lhs_shape[ndims - 2];
let k = lhs_shape[ndims - 1];
let n = rhs_shape[ndims - 1];
let matrix_layout_lhs = matrix_batch_layout(lhs_strides, lhs_scheme);
let matrix_layout_rhs = matrix_batch_layout(rhs_strides, rhs_scheme);
let kind = MatmulKind::from(MatmulProblemSize {
m: m as u32,
n: n as u32,
k: k as u32,
});
let lhs_pow2_factor = match matrix_layout_lhs {
MatrixBatchLayout::Contiguous => pow2_factor(k),
MatrixBatchLayout::MildlyPermuted { transposed, .. } => match transposed {
true => pow2_factor(m),
false => pow2_factor(k),
},
MatrixBatchLayout::HighlyPermuted => 0,
};
let rhs_pow2_factor = match matrix_layout_rhs {
MatrixBatchLayout::Contiguous => pow2_factor(n),
MatrixBatchLayout::MildlyPermuted { transposed, .. } => match transposed {
true => pow2_factor(k),
false => pow2_factor(n),
},
MatrixBatchLayout::HighlyPermuted => 0,
};
let lhs_stride_factor = match matrix_layout_lhs {
MatrixBatchLayout::Contiguous => stride_align(lhs_strides, ndims - 1, elem_lhs),
MatrixBatchLayout::MildlyPermuted {
transposed: true,
batch_swap: false,
} => stride_align(lhs_strides, ndims - 2, elem_lhs),
_ => 0,
};
let rhs_stride_factor = match matrix_layout_rhs {
MatrixBatchLayout::Contiguous => stride_align(rhs_strides, ndims - 1, elem_rhs),
MatrixBatchLayout::MildlyPermuted {
transposed: true,
batch_swap: false,
} => stride_align(rhs_strides, ndims - 2, elem_rhs),
_ => 0,
};
let definition = MatmulProblemDefinition::new(
m,
n,
k,
lhs_pow2_factor,
lhs_stride_factor,
rhs_pow2_factor,
rhs_stride_factor,
elem_lhs,
elem_rhs,
elem_out,
matrix_layout_lhs,
matrix_layout_rhs,
);
let analysis = MatmulAutotuneAnalysis {
scale_global: MatmulGlobalScale::from_size(m, n, k),
kind,
};
Self::new(definition, analysis)
}
}
/// Smallest trailing-zero count (power-of-two byte alignment) among all
/// strides except `exclude_dim`, capped at `MAX_STRIDE_FACTOR`.
///
/// `exclude_dim` is the dimension expected to be contiguous, whose stride
/// carries no alignment information. When every dimension is excluded (e.g. a
/// 1-D input), the cap itself is returned.
///
/// NOTE(review): `stride * size_bits / 8` truncates for sub-byte element
/// types whose bit count is not a multiple of 8 — confirm callers never pass
/// such types with odd strides.
fn stride_align(strides: &[usize], exclude_dim: usize, elem: StorageType) -> u8 {
    let mut best: Option<u32> = None;
    for (dim, stride) in strides.iter().enumerate() {
        if dim == exclude_dim {
            continue;
        }
        // Convert the element-count stride to bytes before measuring alignment.
        let bytes = stride * elem.size_bits() / 8;
        let zeros = bytes.trailing_zeros();
        best = Some(match best {
            Some(current) => current.min(zeros),
            None => zeros,
        });
    }
    best.unwrap_or(MAX_STRIDE_FACTOR).min(MAX_STRIDE_FACTOR) as u8
}
/// Exponent of the largest power of two dividing `axis`, capped at 4.
///
/// For `axis == 0`, `trailing_zeros()` returns the full bit width, so the cap
/// (4) is reported.
fn pow2_factor(axis: usize) -> u8 {
    let exponent = axis.trailing_zeros();
    if exponent > 4 {
        4
    } else {
        exponent as u8
    }
}