use std::marker::PhantomData;
use crate::{
components::global::read::validate_swizzle_atom_size,
components::global::read::{FullLoadingStrategy, sync::Synchronous},
components::global::{PlaneFlowPartition, read::tiled::TiledLayout},
components::stage::StridedStageFamily,
components::stage::{StridedStageMemory, TilingOrder},
components::{global::memory::GlobalIterator, stage::ContiguousTilingLayout},
components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation},
definition::{MatmulElems, MatmulProblem, StageIdent},
{components::global::GlobalReaderConfig, launch::RuntimeConfig},
};
use cubecl::{
std::tensor::layout::Coords2d,
{ir::DeviceProperties, prelude::*},
};
use cubek_std::{
tile::Strided,
{FormattedConfigError, InvalidConfigError},
};
use super::{LoadingJob, LoadingValidation};
#[derive(CubeType, Clone, Copy)]
pub struct SyncFullTilewiseLoading<T: TilingOrder> {
#[cube(comptime)]
tiling_order: PhantomData<T>,
}
impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncFullTilewiseLoading<TO> {
fn max_round_plane_count(
_elements_per_tile: u32,
tiles_per_stage: u32,
_vector_size: VectorSize,
_plane_dim: u32,
_dtype: StorageType,
) -> u32 {
tiles_per_stage
}
}
impl<T: TilingOrder> LoadingValidation for SyncFullTilewiseLoading<T> {
fn validate_with_config(
_device_props: &DeviceProperties,
config: &GlobalReaderConfig,
) -> Result<(), InvalidConfigError> {
let vector_size = config.gmem_config.vector_size;
let num_planes = config.loading_planes_count();
let num_tiles = config.smem_config.tiles_per_stage();
if !num_tiles.is_multiple_of(num_planes) {
return Err(FormattedConfigError::new(move || {
format!(
"Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading.",
)
}));
}
let num_tiles_per_plane = num_tiles / num_planes;
let num_vectors_per_tile = config.smem_config.elements_per_tile() / vector_size as u32;
let num_vectors_per_plane = num_vectors_per_tile * num_tiles_per_plane;
let plane_dim = config.plane_dim;
if !num_vectors_per_plane.is_multiple_of(plane_dim) {
return Err(FormattedConfigError::new(move || {
format!(
"Plane dimension {plane_dim:?} must divide number of vectors per plane {num_vectors_per_plane:?} for tilewise loading.",
)
}));
}
validate_swizzle_atom_size(config.smem_config)?;
ContiguousTilingLayout::<T>::check(config.smem_config)?;
Ok(())
}
fn validate_with_problem(
_problem: &MatmulProblem,
_dtypes: &MatmulElems,
_ident: StageIdent,
) -> Result<(), InvalidConfigError> {
Ok(())
}
}
#[cube]
impl<TO: TilingOrder, RC: RuntimeConfig> FullLoadingStrategy<RC> for SyncFullTilewiseLoading<TO> {
type TilingLayout = ContiguousTilingLayout<TO>;
type SyncStrategy = Synchronous;
type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = SyncFullTilewiseJob;
type Stage = StridedStageFamily;
type TileKind = Strided;
fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
_runtime_config: RC,
#[comptime] config: GlobalReaderConfig,
) -> Self::Job<EG, NG, ES, NS> {
let vector_size = NG::value().comptime() as u32;
let num_planes = config.loading_planes_count();
let num_tiles = config.smem_config.tiles_per_stage();
let num_tiles_per_plane = num_tiles / num_planes;
let num_vectors_per_tile = config.smem_config.elements_per_tile() / vector_size;
let num_vectors_per_plane = num_vectors_per_tile * num_tiles_per_plane;
let num_vectors_per_unit = num_vectors_per_plane / config.plane_dim;
let num_tiles_to_skip = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
.load_index(config.input_load_flow)
* num_tiles_per_plane;
let num_vectors_to_skip = num_tiles_to_skip * num_vectors_per_tile;
SyncFullTilewiseJob {
num_tiles_to_skip,
num_vectors_to_skip,
num_vectors_per_tile,
num_vectors_per_unit,
plane_dim: config.plane_dim,
}
}
}
#[derive(CubeType, Clone, Copy)]
pub struct SyncFullTilewiseJob {
pub num_tiles_to_skip: u32,
pub num_vectors_to_skip: u32,
#[cube(comptime)]
pub num_vectors_per_tile: u32,
#[cube(comptime)]
pub num_vectors_per_unit: u32,
#[cube(comptime)]
pub plane_dim: u32,
}
#[cube]
impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>
LoadingJob<EG, NG, ES, NS, ContiguousTilingLayout<TO>, Synchronous> for SyncFullTilewiseJob
{
type Stage = StridedStageFamily;
fn execute_task(
this: &mut Self,
#[comptime] task_id: u32,
global_iter: &GlobalIterator<Vector<EG, NG>>,
stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
_barrier: &mut (),
#[comptime] config: GlobalReaderConfig,
) {
let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
let nth_tile_for_this_plane = pos_across_tiles / this.num_vectors_per_tile;
let vector_index_within_tile = pos_across_tiles % this.num_vectors_per_tile;
let nth_tile_global = nth_tile_for_this_plane + this.num_tiles_to_skip;
let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile_global, config.smem_config);
SyncFullTilewiseJob::load_and_store_vector::<EG, NG, ES, NS, TO>(
this,
tile,
vector_index_within_tile,
nth_tile_for_this_plane * this.num_vectors_per_tile,
global_iter,
stage,
config,
);
}
fn task_count(this: &Self) -> comptime_type!(u32) {
this.num_vectors_per_unit
}
}
#[cube]
impl SyncFullTilewiseJob {
#[allow(clippy::too_many_arguments)]
fn load_and_store_vector<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>(
this: &Self,
tile: Coords2d,
vector_index_within_tile: u32,
num_vectors_to_skip_local: u32,
global_iter: &GlobalIterator<Vector<EG, NG>>,
stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
#[comptime] config: GlobalReaderConfig,
) {
let vector_size = NG::value().comptime() as u32;
let layout = TiledLayout::new(config.stage_ident, config.smem_config);
let view = global_iter.view().view(layout);
let vector_read = view.read_checked((tile, vector_index_within_tile * vector_size));
let offset =
this.num_vectors_to_skip + vector_index_within_tile + num_vectors_to_skip_local;
let type_size = Vector::<ES, NS>::type_size();
let offset = stage.swizzle.apply(offset, type_size);
stage.as_slice_mut::<NS>()[offset as usize] = Vector::cast_from(vector_read);
}
}