cubek-matmul 0.2.0

CubeK: Matrix Multiplication Kernels
Documentation
use std::marker::PhantomData;

use crate::components::{
    global::memory::GlobalIterator,
    stage::{ContiguousTilingLayout, TilingOrder},
};
use crate::{
    components::global::read::validate_swizzle_atom_size,
    components::global::read::{PartialLoadingStrategy, sync::Synchronous},
    components::global::{PlaneFlowPartition, read::tiled::TiledLayout},
    components::stage::StridedStageFamily,
    components::stage::StridedStageMemory,
    components::stage::TilingOrderEnum,
};
use crate::{
    components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation},
    definition::{MatmulElems, MatmulProblem, StageIdent},
    {components::global::GlobalReaderConfig, launch::RuntimeConfig},
};
use cubecl::{
    std::tensor::layout::Coords2d,
    {ir::DeviceProperties, prelude::*},
};
use cubek_std::{
    tile::Strided,
    {FormattedConfigError, InvalidConfigError},
};

use super::{LoadingJob, LoadingValidation};

#[derive(CubeType, Clone, Copy)]
/// Each tile is guaranteed to be loaded entirely by the same plane.
/// Each plane can load multiple tiles, provided the number of planes evenly divides the number of tiles.
/// In this case, a plane loads contiguous tiles following the `TilingOrder`,
/// until it would otherwise write to the opposite stage. At that point, it continues on the next
/// row or column of the same stage, skipping over the memory region of the other stage.
///
/// Only supports RowMajorTilingOrder for Lhs and ColMajorTilingOrder for Rhs
pub struct SyncPartialTilewiseLoading<T: TilingOrder> {
    #[cube(comptime)]
    tiling_order: PhantomData<T>,
}

impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialTilewiseLoading<TO> {
    fn max_round_plane_count(
        _elements_per_tile: u32,
        tiles_per_stage: u32,
        _vector_size: VectorSize,
        _plane_dim: u32,
        _dtype: StorageType,
    ) -> u32 {
        tiles_per_stage
    }
}

impl<T: TilingOrder> LoadingValidation for SyncPartialTilewiseLoading<T> {
    fn validate_with_config(
        _device_props: &DeviceProperties,
        config: &GlobalReaderConfig,
    ) -> Result<(), InvalidConfigError> {
        let vector_size = config.gmem_config.vector_size;
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();

        if !num_tiles.is_multiple_of(num_planes) {
            return Err(FormattedConfigError::new(move || {
                "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading.".to_string()
            }));
        }

        let num_tiles_per_plane = num_tiles / num_planes;
        let num_vectors_per_tile = config.smem_config.elements_per_tile() / vector_size as u32;
        let num_vectors_per_plane = num_vectors_per_tile * num_tiles_per_plane;
        let num_planes = config.plane_dim;

        if !num_vectors_per_plane.is_multiple_of(num_planes) {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Number of planes {num_planes:?} must divide number of vectors per plane {num_vectors_per_plane:?} for tilewise loading."
                )
            }));
        }

        match config.stage_ident {
            StageIdent::Lhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::RowMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Lhs is only supported with RowMajor tiling order"
                            .to_string()
                    }));
                }
            }
            StageIdent::Rhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::ColMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Rhs is only supported with ColMajor tiling order"
                            .to_string()
                    }));
                }
            }
            _ => unreachable!(),
        }

        validate_swizzle_atom_size(config.smem_config)?;
        ContiguousTilingLayout::<T>::check(config.smem_config)?;

        Ok(())
    }

    fn validate_with_problem(
        _problem: &MatmulProblem,
        _dtypes: &MatmulElems,
        _ident: StageIdent,
    ) -> Result<(), InvalidConfigError> {
        Ok(())
    }
}

#[cube]
impl<TO: TilingOrder, RC: RuntimeConfig> PartialLoadingStrategy<RC>
    for SyncPartialTilewiseLoading<TO>
{
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = Synchronous;
    type Stage = StridedStageFamily;
    type TileKind = Strided;

    type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = SyncPartialTilewiseJob;

    fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
        _runtime_config: RC,
        #[comptime] stage_index: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> SyncPartialTilewiseJob {
        let vector_size = NG::value().comptime() as u32;
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();
        let plane_dim = config.plane_dim;

        let num_tiles_per_plane = num_tiles / num_planes;
        let num_vectors_per_tile = config.smem_config.elements_per_tile() / vector_size;
        let num_vectors_per_plane = num_vectors_per_tile * num_tiles_per_plane;
        let num_vectors_per_unit = num_vectors_per_plane / plane_dim;

        let stage_width = match config.stage_ident {
            StageIdent::Lhs => config.smem_config.tiles_per_stage_along_col(),
            StageIdent::Rhs => config.smem_config.tiles_per_stage_along_row(),
            _ => unreachable!(),
        };

        let num_tiles_to_skip = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
            .load_index(config.input_load_flow)
            * num_tiles_per_plane;

        SyncPartialTilewiseJob {
            stage_index,
            num_tiles_to_skip,
            stage_width,
            num_vectors_per_tile,
            num_vectors_per_unit,
            plane_dim,
        }
    }
}

#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialTilewiseJob {
    num_tiles_to_skip: u32,
    stage_index: u32,

    #[cube(comptime)]
    stage_width: u32,
    #[cube(comptime)]
    num_vectors_per_tile: u32,
    #[cube(comptime)]
    num_vectors_per_unit: u32,
    #[cube(comptime)]
    plane_dim: u32,
}

#[cube]
impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>
    LoadingJob<EG, NG, ES, NS, ContiguousTilingLayout<TO>, Synchronous> for SyncPartialTilewiseJob
{
    type Stage = StridedStageFamily;

    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Vector<EG, NG>>,
        stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        let mut stage = stage.with_buffer_index(this.stage_index);
        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
        let nth_tile_for_this_plane = pos_across_tiles / this.num_vectors_per_tile;
        let vector_index_within_tile = pos_across_tiles % this.num_vectors_per_tile;

        let nth_tile_global = this.num_tiles_to_skip + nth_tile_for_this_plane;

        let tile = TO::to_row_col(
            nth_tile_global,
            config.smem_config.tiles_per_stage_along_row(),
            config.smem_config.tiles_per_stage_along_col(),
            config.smem_config,
        );

        let tile = match config.stage_ident {
            StageIdent::Lhs => (tile.0, tile.1 + this.stage_index * this.stage_width),
            StageIdent::Rhs => (tile.0 + this.stage_index * this.stage_width, tile.1),
            _ => tile,
        };

        let num_vectors_to_skip_global = nth_tile_global * this.num_vectors_per_tile;

        SyncPartialTilewiseJob::load_and_store_vector::<EG, NG, ES, NS, TO>(
            tile,
            vector_index_within_tile,
            num_vectors_to_skip_global,
            global_iter,
            &mut stage,
            config,
        );
    }

    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_vectors_per_unit
    }
}

#[cube]
impl SyncPartialTilewiseJob {
    #[allow(clippy::too_many_arguments)]
    fn load_and_store_vector<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>(
        tile: Coords2d,
        vector_index_within_tile: u32,
        num_vectors_to_skip_global: u32,
        global_iter: &GlobalIterator<Vector<EG, NG>>,
        stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let layout = TiledLayout::new(config.stage_ident, config.smem_config);
        let view = global_iter.view().view(layout);

        let vector_read = view.read_checked((tile, vector_index_within_tile * NG::value() as u32));

        let offset = vector_index_within_tile + num_vectors_to_skip_global;
        let type_size = Vector::<ES, NS>::type_size();
        let offset = stage.swizzle.apply(offset, type_size);

        stage.as_slice_mut::<NS>()[offset as usize] = Vector::cast_from(vector_read);
    }
}