use crate::{
components::{
global::{GlobalReaderConfig, SharedGlobalMatmulConfig, memory::GlobalIterator},
stage::{StageConfig, StageFamily, TilingLayout},
},
definition::{MatmulElems, MatmulProblem, MatmulTypes, StageIdent},
};
use cubecl::{
ir::{BarrierLevel, DeviceProperties, OpaqueType, SemanticType},
prelude::*,
};
use cubek_std::{
stage::{StageMemoryConfig, SwizzleMode},
{InvalidConfigError, MatrixLayout},
};
/// A unit of global-memory-to-stage loading work, split into a
/// compile-time-known number of tasks.
///
/// Type parameters: `EG`/`NG` are the element type and line size on the
/// global-memory side (see the `GlobalIterator<Vector<EG, NG>>` argument of
/// [`LoadingJob::execute_task`]), `ES`/`NS` are their counterparts on the
/// stage side, `TL` is the stage tiling layout, and `S` supplies the barrier
/// type used for synchronization.
#[cube]
pub trait LoadingJob<
    EG: Numeric,
    NG: Size,
    ES: Numeric,
    NS: Size,
    TL: TilingLayout,
    S: SyncStrategy,
>: CubeType + Clone
{
    /// The stage family this job writes into.
    type Stage: StageFamily;

    /// Executes the single task identified by the comptime `task_id`,
    /// reading through `global_iter` and writing into `stage`, using
    /// `barrier` for synchronization.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Vector<EG, NG>>,
        stage: &mut <Self::Stage as StageFamily>::Stage<ES, NS, TL>,
        barrier: &mut S::Barrier,
        #[comptime] config: GlobalReaderConfig,
    );

    /// Number of tasks this job is split into (a comptime value; valid
    /// `task_id`s are `0..task_count`).
    fn task_count(this: &Self) -> comptime_type!(u32);
}
/// Synchronization strategy threaded through loading jobs.
///
/// Supplies the barrier handle passed to [`LoadingJob::execute_task`] and a
/// synchronization point parameterized by the shared global matmul config.
#[cube]
pub trait SyncStrategy {
    /// Barrier handle created once and reused across tasks.
    type Barrier: CubeType + Clone;

    /// Creates the barrier used by subsequent `sync` calls.
    fn create_barrier() -> Self::Barrier;

    /// Synchronization point for work guarded by `barrier`.
    /// NOTE(review): exact wait/arrive semantics depend on the implementor —
    /// not visible in this file.
    fn sync<MP: MatmulTypes, S: StageConfig>(
        barrier: &mut Self::Barrier,
        #[comptime] config: SharedGlobalMatmulConfig<S>,
    );
}
/// Validation hooks for a loading strategy.
///
/// Implementors check, ahead of kernel launch, that a loading strategy is
/// usable with a given device/reader configuration and with a given matmul
/// problem.
pub trait LoadingValidation {
    /// Checks the device capabilities and reader configuration against the
    /// requirements of this loading strategy.
    fn validate_with_config(
        device_props: &DeviceProperties,
        config: &GlobalReaderConfig,
    ) -> Result<(), InvalidConfigError>;

    /// Checks the matmul problem and element types against the requirements
    /// of this loading strategy for the stage identified by `ident`.
    fn validate_with_problem(
        problem: &MatmulProblem,
        dtypes: &MatmulElems,
        ident: StageIdent,
    ) -> Result<(), InvalidConfigError>;
}
/// Checks that the device supports cube-level async barriers.
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] when the barrier type is not supported
/// by the current device.
pub fn validate_async_barrier(device_props: &DeviceProperties) -> Result<(), InvalidConfigError> {
    let barrier_ty = OpaqueType::Barrier(BarrierLevel::Cube);
    if device_props.features.supports_type(barrier_ty) {
        Ok(())
    } else {
        Err(Box::new(
            "Async barrier instructions are not available on the current device",
        ))
    }
}
/// Checks that async copy can move data of `dtype_global` into a stage of
/// `dtype_stage` on this device.
///
/// Requirements enforced here: the device exposes async copy, the global and
/// stage types have the same byte size, and a packed global type is not
/// unpacked on the stage side (no dequantization during the copy).
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] describing the first violated
/// requirement.
pub fn validate_async_copy(
    device_props: &DeviceProperties,
    dtype_global: &StorageType,
    dtype_stage: &StorageType,
) -> Result<(), InvalidConfigError> {
    if !device_props.features.copy_async {
        return Err(Box::new(
            "Async copy instructions are not available on the current device",
        ));
    }

    let same_size = dtype_global.size() == dtype_stage.size();
    if !same_size {
        return Err(Box::new(
            "Async copy requires stage and global types to be the same",
        ));
    }

    // A packed global type landing in an unpacked stage type would require
    // dequantizing during the copy, which async copy cannot do.
    let global_packed = matches!(dtype_global, StorageType::Packed(_, _));
    let stage_packed = matches!(dtype_stage, StorageType::Packed(_, _));
    if global_packed && !stage_packed {
        return Err(Box::new(
            "Async copy doesn't support dequantizing on global read",
        ));
    }

    Ok(())
}
/// Rejects any stage memory configuration that enables swizzling.
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] when `config.swizzle` is anything other
/// than [`SwizzleMode::None`].
pub fn validate_noswizzle(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
    match config.swizzle {
        SwizzleMode::None => Ok(()),
        _ => Err(Box::new("This loader doesn't support swizzling")),
    }
}
/// Checks that a single load (one vector of `vector_size` elements) fits
/// inside the swizzle atom.
///
/// A configuration without swizzling always passes.
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] when the load atom in bytes exceeds the
/// swizzle atom size.
pub fn validate_swizzle_atom_size(config: StageMemoryConfig) -> Result<(), InvalidConfigError> {
    match config.swizzle {
        SwizzleMode::None => Ok(()),
        swizzle => {
            let load_atom_bytes = config.dtype.size() * config.vector_size as usize;
            if load_atom_bytes > swizzle.atom_size() {
                Err(Box::new("Load atom can't be larger than swizzle atom"))
            } else {
                Ok(())
            }
        }
    }
}
/// Checks that TMA (tensor memory accelerator) loading is usable for the
/// given stage memory configuration and global type on this device.
///
/// Requirements enforced here: the device supports tensor maps, global and
/// stage types have the same byte size, a packed global type is not unpacked
/// on the stage side, and — when swizzling is enabled — the swizzle span
/// matches the byte size of one stage row.
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] describing the first violated
/// requirement.
pub fn validate_tma(
    device_props: &DeviceProperties,
    smem_config: &StageMemoryConfig,
    global_dtype: &StorageType,
) -> Result<(), InvalidConfigError> {
    if !device_props.features.supports_type(SemanticType::TensorMap) {
        return Err(Box::new(
            "Tensor memory accelerator features are not available on the current device",
        ));
    }

    let stage_dtype = smem_config.dtype;
    let same_size = global_dtype.size() == stage_dtype.size();
    if !same_size {
        return Err(Box::new(
            "TMA requires stage and global types to be the same",
        ));
    }

    // Packed global into unpacked stage would require dequantizing during
    // the transfer, which TMA cannot do.
    let global_packed = matches!(global_dtype, StorageType::Packed(_, _));
    let stage_packed = matches!(stage_dtype, StorageType::Packed(_, _));
    if global_packed && !stage_packed {
        return Err(Box::new("TMA doesn't support dequantizing on global read"));
    }

    // Without swizzling there is nothing left to check.
    if smem_config.swizzle == SwizzleMode::None {
        return Ok(());
    }

    // One "row" here is the contiguous line of the stage: columns for
    // row-major, rows for col-major.
    let line_elems = match smem_config.matrix_layout {
        MatrixLayout::RowMajor => smem_config.elements_per_stage_along_col(),
        MatrixLayout::ColMajor => smem_config.elements_per_stage_along_row(),
    };
    let line_bytes = line_elems * global_dtype.size() as u32;
    if line_bytes as usize != smem_config.swizzle.span_size() {
        return Err(Box::new("Swizzling size must be equal to row size for TMA"));
    }

    Ok(())
}
/// Checks that async copy can be used to load the tensor identified by
/// `ident` for the given matmul problem.
///
/// Rejects quantized (scheme-bearing) inputs, since async copy cannot
/// dequantize while reading from global memory, and rejects tensors whose
/// non-contiguous strides are not 16-byte aligned.
///
/// # Panics
///
/// Panics if `ident` is `Acc` or `Out`: only `Lhs`/`Rhs` are loadable
/// through this path.
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] describing the first violated
/// requirement.
pub fn validate_async_copy_with_problem(
    problem: &MatmulProblem,
    dtypes: &MatmulElems,
    ident: StageIdent,
) -> Result<(), InvalidConfigError> {
    // Only Lhs/Rhs can carry a quantization scheme.
    let is_quantized = match ident {
        StageIdent::Lhs => problem.lhs_scheme.is_some(),
        StageIdent::Rhs => problem.rhs_scheme.is_some(),
        StageIdent::Acc | StageIdent::Out => false,
    };
    if is_quantized {
        return Err(Box::new(
            "Async copy doesn't support dequantizing on global read",
        ));
    }
    let (strides, layout) = match ident {
        StageIdent::Lhs => (&problem.lhs_strides, &problem.lhs_layout),
        StageIdent::Rhs => (&problem.rhs_strides, &problem.rhs_layout),
        // Message fixed: previously read "a loadable tensors".
        _ => unreachable!("Should be a loadable tensor"),
    };
    // Fewer than 4 trailing zero bits means the strides are not multiples
    // of 2^4 = 16 bytes.
    if stride_align_bits(strides, layout, &dtypes.global(ident.into())) < 4 {
        return Err(Box::new(
            "Async copy requires strides to be aligned to 16 bytes",
        ));
    }
    Ok(())
}
/// Checks that TMA (tensor memory accelerator) loading can be used for the
/// tensor identified by `ident` in the given matmul problem.
///
/// Rejects quantized (scheme-bearing) inputs (TMA cannot dequantize while
/// reading from global memory), strides that are not 16-byte aligned, and
/// batch shapes that mix broadcast and non-broadcast dimensions.
///
/// # Panics
///
/// Panics if `ident` is `Acc` or `Out`: only `Lhs`/`Rhs` are loadable
/// through this path.
///
/// # Errors
///
/// Returns an [`InvalidConfigError`] describing the first violated
/// requirement.
pub fn validate_tma_with_problem(
    problem: &MatmulProblem,
    dtypes: &MatmulElems,
    ident: StageIdent,
) -> Result<(), InvalidConfigError> {
    // Only Lhs/Rhs can carry a quantization scheme.
    let is_quantized = match ident {
        StageIdent::Lhs => problem.lhs_scheme.is_some(),
        StageIdent::Rhs => problem.rhs_scheme.is_some(),
        StageIdent::Acc | StageIdent::Out => false,
    };
    if is_quantized {
        return Err(Box::new("TMA doesn't support dequantizing on global read"));
    }
    let (strides, layout) = match ident {
        StageIdent::Lhs => (&problem.lhs_strides, &problem.lhs_layout),
        StageIdent::Rhs => (&problem.rhs_strides, &problem.rhs_layout),
        // Message fixed: previously read "a loadable tensors".
        _ => unreachable!("Should be a loadable tensor"),
    };
    // Fewer than 4 trailing zero bits means the strides are not multiples
    // of 2^4 = 16 bytes.
    if stride_align_bits(strides, layout, &dtypes.global(ident.into())) < 4 {
        return Err(Box::new("TMA requires strides to be aligned to 16 bytes"));
    }
    // Mismatched batch shapes are only allowed when one side is fully
    // broadcast (its batch product is 1).
    if problem.lhs_batches != problem.rhs_batches
        && problem.lhs_batches.iter().product::<usize>() != 1
        && problem.rhs_batches.iter().product::<usize>() != 1
    {
        return Err(Box::new(
            "TMA doesn't support mixing broadcast and non-broadcast dims",
        ));
    }
    Ok(())
}
/// Returns the guaranteed power-of-two byte alignment (as a count of low
/// zero bits) over all strides except the contiguous one.
///
/// The stride along the contiguous dimension is skipped: the last dimension
/// for row-major, the second-to-last for col-major. A zero stride
/// (broadcast dimension) contributes `usize::BITS` trailing zeros and so
/// never lowers the minimum. When no strides remain after skipping, a
/// practically unbounded alignment of 31 bits is reported.
///
/// NOTE(review): assumes `strides.len() >= 2` for col-major (and >= 1 for
/// row-major); shorter slices would underflow the subtraction.
fn stride_align_bits(strides: &[usize], layout: &MatrixLayout, dtype: &StorageType) -> u32 {
    let skipped_dim = match layout {
        MatrixLayout::RowMajor => strides.len() - 1,
        MatrixLayout::ColMajor => strides.len() - 2,
    };

    let mut min_align: Option<u32> = None;
    for (dim, &stride) in strides.iter().enumerate() {
        if dim == skipped_dim {
            continue;
        }
        // Convert the stride to bytes before counting trailing zeros.
        let stride_bytes = (stride * dtype.size_bits()) / 8;
        let align = stride_bytes.trailing_zeros();
        min_align = Some(match min_align {
            Some(current) => current.min(align),
            None => align,
        });
    }

    min_align.unwrap_or(31)
}
/// A [`LoadingValidation`] implementation that accepts every configuration
/// and problem unconditionally.
pub struct NoLoadingValidation {}

impl LoadingValidation for NoLoadingValidation {
    /// Always succeeds: no device or config requirements apply.
    fn validate_with_config(
        _device: &DeviceProperties,
        _cfg: &GlobalReaderConfig,
    ) -> Result<(), InvalidConfigError> {
        Ok(())
    }

    /// Always succeeds: no problem-shape or dtype requirements apply.
    fn validate_with_problem(
        _prob: &MatmulProblem,
        _elems: &MatmulElems,
        _stage: StageIdent,
    ) -> Result<(), InvalidConfigError> {
        Ok(())
    }
}
/// Mode flag for global readers.
///
/// NOTE(review): the concrete behavioral difference between `Strict` and
/// `Relaxed` is defined by the readers consuming this flag and is not
/// visible in this file — confirm against the reader implementations.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ReaderMode {
    Strict,
    /// The default mode.
    #[default]
    Relaxed,
}