use crate::{
components::{
global::{
GlobalReaderConfig,
memory::{GlobalIterator, load_window_in_stage},
multi_stage::LoadMaxRoundPlaneCount,
read::{
FullLoadingStrategy, LoadingJob, async_barrier::AsyncBarrier,
validate_async_barrier, validate_noswizzle,
},
},
stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout, TilingValidation},
},
definition::{MatmulElems, MatmulProblem, StageIdent},
launch::RuntimeConfig,
};
use cubecl::{
ir::DeviceProperties,
prelude::{barrier::Barrier, *},
};
use cubek_std::{InvalidConfigError, MatrixLayout, tile::Strided};
use super::LoadingValidation;
#[derive(CubeType, Clone, Copy)]
pub struct AsyncFullCooperativeLoading {}
impl LoadingValidation for AsyncFullCooperativeLoading {
fn validate_with_config(
device_props: &DeviceProperties,
config: &GlobalReaderConfig,
) -> Result<(), InvalidConfigError> {
StridedTilingLayout::check(config.smem_config)?;
validate_async_barrier(device_props)?;
validate_noswizzle(config.smem_config)?;
Ok(())
}
fn validate_with_problem(
_problem: &MatmulProblem,
_dtypes: &MatmulElems,
_ident: StageIdent,
) -> Result<(), InvalidConfigError> {
Ok(())
}
}
impl LoadMaxRoundPlaneCount for AsyncFullCooperativeLoading {
fn max_round_plane_count(
_elements_per_tile: u32,
_tiles_per_stage: u32,
_vector_size: VectorSize,
_plane_dim: u32,
_dtype: StorageType,
) -> u32 {
4
}
}
#[cube]
impl<RC: RuntimeConfig> FullLoadingStrategy<RC> for AsyncFullCooperativeLoading {
type TilingLayout = StridedTilingLayout;
type SyncStrategy = AsyncBarrier;
type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = AsyncFullCooperativeJob;
type Stage = StridedStageFamily;
type TileKind = Strided;
const SHOULD_CLEAR: bool = true;
fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
_runtime_config: RC,
#[comptime] config: GlobalReaderConfig,
) -> AsyncFullCooperativeJob {
let matrix_layout = config.gmem_config.matrix_layout;
let num_slices = match matrix_layout {
MatrixLayout::RowMajor => config.smem_config.elements_per_stage_along_row(),
MatrixLayout::ColMajor => config.smem_config.elements_per_stage_along_col(),
};
AsyncFullCooperativeJob { num_slices }
}
}
#[derive(CubeType, Clone, Copy)]
pub struct AsyncFullCooperativeJob {
#[cube(comptime)]
num_slices: u32,
}
#[cube]
impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size>
LoadingJob<EG, NG, ES, NS, StridedTilingLayout, AsyncBarrier> for AsyncFullCooperativeJob
{
type Stage = StridedStageFamily;
fn execute_task(
_this: &mut Self,
#[comptime] task_id: u32,
global_iter: &GlobalIterator<Vector<EG, NG>>,
stage: &mut StridedStageMemory<ES, NS, StridedTilingLayout>,
barrier: &mut Shared<Barrier>,
#[comptime] config: GlobalReaderConfig,
) {
let window = load_window_in_stage(
&global_iter.view(),
task_id,
config.smem_config,
config.gmem_config,
);
let mut destination: SliceMut<Vector<ES, NS>> =
StridedTilingLayout::nth_slice::<ES, NS>(stage, task_id, config.smem_config);
barrier.memcpy_async_cooperative(&window.downcast(), &mut destination);
}
fn task_count(this: &Self) -> comptime_type!(u32) {
this.num_slices
}
}