// cubek-std 0.2.0-pre.5
//
// CubeK: Standard Library
// Documentation
use cubecl::{cmma::MmaDefinition, define_size, ir::MatrixIdent, prelude::*};

use crate::{
    MatrixLayout, SwizzleModes, TileSize,
    tile::{
        Filled, MmaAccTile, MmaLhsTile, MmaRhsTile, Strided, StridedTile, Tile,
        mma::{MmaFragmentReader as _, MmaIOConfig, MmaStageReader, MmaStageWriter},
        scope::Scope,
    },
};

/// Comptime configuration shared by the MMA tile helpers in this file.
///
/// Every helper receives it as a `#[comptime]` argument.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct MmaMatmul {
    /// Tile dimensions; its `m()`, `n()` and `k()` values are what the helpers
    /// pass to `MmaDefinition::new`.
    pub tile_size: TileSize,
    /// Not read by any helper visible in this file.
    /// NOTE(review): presumably the number of units in a plane — confirm.
    pub plane_dim: u32,
    /// Not read by any helper visible in this file.
    pub swizzle_modes: SwizzleModes,
    /// Forwarded unchanged to `MmaStageReader::load_fragment` and
    /// `MmaStageWriter::store_fragment`.
    pub mma_io_config: MmaIOConfig,
}

// Fragment inner vector sizes for the three MMA roles. Bound at allocation time
// via `mma_register_vector_sizes` to match the hardware's `def.vector_size(...)`
// for each role — these are independent of the outer Tile enum's stage vector `V`.
//
// `NL` is registered from `MatrixIdent::A` (Lhs fragments), `NR` from
// `MatrixIdent::B` (Rhs fragments) and `NA` from `MatrixIdent::Accumulator`.
define_size!(pub NL);
define_size!(pub NR);
define_size!(pub NA);

#[cube]
fn make_mma_definition<L: Numeric, R: Numeric, A: Numeric>(
    #[comptime] config: MmaMatmul,
) -> MmaDefinition<L, R, A> {
    MmaDefinition::new(
        config.tile_size.m() as usize,
        config.tile_size.n() as usize,
        config.tile_size.k() as usize,
    )
}

/// Registers the per-role fragment vector sizes reported by `def`.
///
/// Reads `def.vector_size(...)` for `MatrixIdent::A`, `MatrixIdent::B` and
/// `MatrixIdent::Accumulator`, then binds those values to the `NL`, `NR` and
/// `NA` size markers on the scope through `register_size`.
#[cube]
// NOTE(review): presumably required because the three locals below are only
// referenced inside the `intrinsic!` closure — confirm.
#[allow(unused_variables)]
pub fn mma_register_vector_sizes<L: Numeric, R: Numeric, A: Numeric>(def: MmaDefinition<L, R, A>) {
    let vector_size_a = def.vector_size(MatrixIdent::A);
    let vector_size_b = def.vector_size(MatrixIdent::B);
    let vector_size_acc = def.vector_size(MatrixIdent::Accumulator);
    // The registration runs against the scope handed to the intrinsic closure.
    intrinsic!(|scope| {
        scope.register_size::<NL>(vector_size_a);
        scope.register_size::<NR>(vector_size_b);
        scope.register_size::<NA>(vector_size_acc);
    });
}

/// Allocates the Lhs fragment and wraps it in a `Tile`.
///
/// Builds the MMA definition from `config`, registers the fragment vector
/// sizes, and creates an array with `vectors_per_lane(MatrixIdent::A)` entries.
/// `layout` and `config` are stored on the resulting `MmaLhsTile`.
#[cube]
pub fn mma_allocate_lhs<L: Numeric, R: Numeric, A: Numeric, Sc: Scope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<L, Sc, ReadWrite> {
    let definition = make_mma_definition::<L, R, A>(config);
    // Size registration happens before the fragment array is created.
    mma_register_vector_sizes(definition);
    let fragment_len = definition.vectors_per_lane(MatrixIdent::A);

    Tile::new_MmaLhs(MmaLhsTile::<L> {
        fragment: Array::new(fragment_len),
        matrix_layout: layout,
        config,
    })
}

/// Allocates the Rhs fragment and wraps it in a `Tile`.
///
/// Builds the MMA definition from `config`, registers the fragment vector
/// sizes, and creates an array with `vectors_per_lane(MatrixIdent::B)` entries.
/// `layout` and `config` are stored on the resulting `MmaRhsTile`.
#[cube]
pub fn mma_allocate_rhs<R: Numeric, L: Numeric, A: Numeric, Sc: Scope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<R, Sc, ReadWrite> {
    let definition = make_mma_definition::<L, R, A>(config);
    // Size registration happens before the fragment array is created.
    mma_register_vector_sizes(definition);
    let fragment_len = definition.vectors_per_lane(MatrixIdent::B);

    Tile::new_MmaRhs(MmaRhsTile::<R> {
        fragment: Array::new(fragment_len),
        matrix_layout: layout,
        config,
    })
}

/// Allocates the accumulator fragment and wraps it in a `Tile`.
///
/// Builds the MMA definition from `config`, registers the fragment vector
/// sizes, and creates an array with `vectors_per_lane(MatrixIdent::Accumulator)`
/// entries. `layout` and `config` are stored on the resulting `MmaAccTile`.
#[cube]
pub fn mma_allocate_acc<A: Numeric, L: Numeric, R: Numeric, Sc: Scope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) -> Tile<A, Sc, ReadWrite> {
    let definition = make_mma_definition::<L, R, A>(config);
    // Size registration happens before the fragment array is created.
    mma_register_vector_sizes(definition);
    let fragment_len = definition.vectors_per_lane(MatrixIdent::Accumulator);

    Tile::new_MmaAcc(MmaAccTile::<A> {
        fragment: Array::new(fragment_len),
        matrix_layout: layout,
        config,
    })
}

#[cube]
pub fn mma_execute<L: Numeric, R: Numeric, A: Numeric>(
    lhs: &Array<Vector<L, NL>>,
    rhs: &Array<Vector<R, NR>>,
    acc: &mut Array<Vector<A, NA>>,
    #[comptime] _matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    let def = MmaDefinition::<L, R, A>::new(
        config.tile_size.m() as usize,
        config.tile_size.n() as usize,
        config.tile_size.k() as usize,
    );
    let out_arr = def.execute(lhs, rhs, acc);
    let num_vectors = def.vectors_per_lane(MatrixIdent::Accumulator);
    #[unroll]
    for i in 0..num_vectors {
        acc[i] = out_arr[i];
    }
}

/// Loads the Lhs fragment from a strided tile.
///
/// Takes a read-only view of `shared` via `to_read_only`, then calls
/// `MmaStageReader::<Strided>::load_fragment` with `MatrixIdent::A`, the given
/// `matrix_layout`, a `TileSize` rebuilt from `config.tile_size`, and
/// `config.mma_io_config`.
#[cube]
pub fn mma_load_lhs_from_shared<
    E: Numeric,
    ES: Size,
    L: Numeric,
    R: Numeric,
    A: Numeric,
    IO: SliceVisibility,
>(
    shared: &StridedTile<E, ES, IO>,
    fragment: &mut Array<Vector<L, NL>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    let tile_size = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));
    let stage = shared.to_read_only();
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Strided>::load_fragment(
        &stage,
        fragment,
        definition,
        MatrixIdent::A,
        matrix_layout,
        tile_size,
        config.mma_io_config,
    );
}

/// Loads the Rhs fragment from a strided tile.
///
/// Takes a read-only view of `shared` via `to_read_only`, then calls
/// `MmaStageReader::<Strided>::load_fragment` with `MatrixIdent::B`, the given
/// `matrix_layout`, a `TileSize` rebuilt from `config.tile_size`, and
/// `config.mma_io_config`.
#[cube]
pub fn mma_load_rhs_from_shared<
    E: Numeric,
    ES: Size,
    R: Numeric,
    L: Numeric,
    A: Numeric,
    IO: SliceVisibility,
>(
    shared: &StridedTile<E, ES, IO>,
    fragment: &mut Array<Vector<R, NR>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    let tile_size = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));
    let stage = shared.to_read_only();
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Strided>::load_fragment(
        &stage,
        fragment,
        definition,
        MatrixIdent::B,
        matrix_layout,
        tile_size,
        config.mma_io_config,
    );
}

/// Loads the accumulator fragment from a strided tile.
///
/// Takes a read-only view of `shared` via `to_read_only`, then calls
/// `MmaStageReader::<Strided>::load_fragment` with `MatrixIdent::Accumulator`,
/// the given `matrix_layout`, a `TileSize` rebuilt from `config.tile_size`, and
/// `config.mma_io_config`.
#[cube]
pub fn mma_load_acc_from_shared<
    E: Numeric,
    ES: Size,
    A: Numeric,
    L: Numeric,
    R: Numeric,
    IO: SliceVisibility,
>(
    shared: &StridedTile<E, ES, IO>,
    fragment: &mut Array<Vector<A, NA>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    let tile_size = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));
    let stage = shared.to_read_only();
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Strided>::load_fragment(
        &stage,
        fragment,
        definition,
        MatrixIdent::Accumulator,
        matrix_layout,
        tile_size,
        config.mma_io_config,
    );
}

/// Loads the accumulator fragment from a zero value.
///
/// Calls `MmaStageReader::<Filled>::load_fragment` with `E::from_int(0)` as the
/// source, `MatrixIdent::Accumulator`, the given `matrix_layout`, a `TileSize`
/// rebuilt from `config.tile_size`, and `config.mma_io_config`.
#[cube]
pub fn mma_load_acc_zeros<E: Numeric, ES: Size, A: Numeric, L: Numeric, R: Numeric>(
    fragment: &mut Array<Vector<A, NA>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    let tile_size = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Filled>::load_fragment::<A, NA, E, ES, L, R, A>(
        &E::from_int(0),
        fragment,
        definition,
        MatrixIdent::Accumulator,
        matrix_layout,
        tile_size,
        config.mma_io_config,
    );
}

/// Stores the accumulator fragment into a strided tile.
///
/// Calls `MmaStageWriter::store_fragment` with `MatrixIdent::Accumulator`, the
/// layout read from `shared.layout`, the `m` dimension of `config.tile_size`,
/// and `config.mma_io_config`.
#[cube]
pub fn mma_write_to_shared<E: Numeric, ES: Size, A: Numeric, L: Numeric, R: Numeric>(
    shared: &mut StridedTile<E, ES, ReadWrite>,
    fragment: &Array<Vector<A, NA>>,
    #[comptime] config: MmaMatmul,
) {
    // The destination layout comes from the tile itself, not from a parameter.
    let layout = comptime!(shared.layout);
    let tile_m = comptime!(config.tile_size.m());
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageWriter::store_fragment(
        shared,
        fragment,
        definition,
        MatrixIdent::Accumulator,
        layout,
        tile_m,
        config.mma_io_config,
    );
}