cubek-std 0.2.0

CubeK: Standard Library
Documentation
use cubecl::{cmma::MmaDefinition, ir::MatrixIdent, prelude::*};

use crate::{
    MatrixLayout, TileSize,
    tile::{
        compute::matmul::mma::{MmaFragmentReader as _, MmaStageReader, MmaStageWriter},
        data::{Filled, MmaMatmul, NA, NL, NR, Strided, StridedTile},
    },
};

#[cube]
fn make_mma_definition<L: Numeric, R: Numeric, A: Numeric>(
    #[comptime] config: MmaMatmul,
) -> MmaDefinition<L, R, A> {
    MmaDefinition::new(
        config.tile_size.m() as usize,
        config.tile_size.n() as usize,
        config.tile_size.k() as usize,
    )
}

#[cube]
pub fn mma_execute<L: Numeric, R: Numeric, A: Numeric>(
    lhs: &Array<Vector<L, NL>>,
    rhs: &Array<Vector<R, NR>>,
    acc: &mut Array<Vector<A, NA>>,
    #[comptime] _matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    let def = MmaDefinition::<L, R, A>::new(
        config.tile_size.m() as usize,
        config.tile_size.n() as usize,
        config.tile_size.k() as usize,
    );
    let out_arr = def.execute(lhs, rhs, acc);
    let num_vectors = def.vectors_per_lane(MatrixIdent::Accumulator);
    #[unroll]
    for i in 0..num_vectors {
        acc[i] = out_arr[i];
    }
}

/// Loads the LHS fragment (`MatrixIdent::A`) from a strided tile.
///
/// The tile is first converted with `to_read_only()`, then handed to
/// `MmaStageReader::<Strided>::load_fragment` together with the MMA
/// definition and tile size derived from `config`.
///
/// # Parameters
/// - `shared`: source tile to read from.
/// - `fragment`: destination fragment that `load_fragment` fills.
/// - `matrix_layout`: comptime layout forwarded to the reader.
/// - `config`: comptime matmul configuration (tile size and MMA I/O config).
#[cube]
pub fn mma_load_lhs_from_shared<
    E: Numeric,
    ES: Size,
    L: Numeric,
    R: Numeric,
    A: Numeric,
    IO: SliceVisibility,
>(
    shared: &StridedTile<E, ES, IO>,
    fragment: &mut Array<Vector<L, NL>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    // Comptime tile size rebuilt from the configured m/n/k.
    let tile_dims = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));

    let source = shared.to_read_only();
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Strided>::load_fragment(
        &source,
        fragment,
        definition,
        MatrixIdent::A,
        matrix_layout,
        tile_dims,
        config.mma_io_config,
    );
}

/// Loads the RHS fragment (`MatrixIdent::B`) from a strided tile.
///
/// The tile is first converted with `to_read_only()`, then handed to
/// `MmaStageReader::<Strided>::load_fragment` together with the MMA
/// definition and tile size derived from `config`.
///
/// # Parameters
/// - `shared`: source tile to read from.
/// - `fragment`: destination fragment that `load_fragment` fills.
/// - `matrix_layout`: comptime layout forwarded to the reader.
/// - `config`: comptime matmul configuration (tile size and MMA I/O config).
#[cube]
pub fn mma_load_rhs_from_shared<
    E: Numeric,
    ES: Size,
    R: Numeric,
    L: Numeric,
    A: Numeric,
    IO: SliceVisibility,
>(
    shared: &StridedTile<E, ES, IO>,
    fragment: &mut Array<Vector<R, NR>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    // Comptime tile size rebuilt from the configured m/n/k.
    let tile_dims = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));

    let source = shared.to_read_only();
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Strided>::load_fragment(
        &source,
        fragment,
        definition,
        MatrixIdent::B,
        matrix_layout,
        tile_dims,
        config.mma_io_config,
    );
}

/// Loads the accumulator fragment (`MatrixIdent::Accumulator`) from a
/// strided tile.
///
/// The tile is first converted with `to_read_only()`, then handed to
/// `MmaStageReader::<Strided>::load_fragment` together with the MMA
/// definition and tile size derived from `config`.
///
/// # Parameters
/// - `shared`: source tile to read from.
/// - `fragment`: destination accumulator fragment that `load_fragment` fills.
/// - `matrix_layout`: comptime layout forwarded to the reader.
/// - `config`: comptime matmul configuration (tile size and MMA I/O config).
#[cube]
pub fn mma_load_acc_from_shared<
    E: Numeric,
    ES: Size,
    A: Numeric,
    L: Numeric,
    R: Numeric,
    IO: SliceVisibility,
>(
    shared: &StridedTile<E, ES, IO>,
    fragment: &mut Array<Vector<A, NA>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    // Comptime tile size rebuilt from the configured m/n/k.
    let tile_dims = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));

    let source = shared.to_read_only();
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageReader::<Strided>::load_fragment(
        &source,
        fragment,
        definition,
        MatrixIdent::Accumulator,
        matrix_layout,
        tile_dims,
        config.mma_io_config,
    );
}

/// Loads the accumulator fragment from a zero fill value instead of a tile.
///
/// Uses the `Filled` variant of `MmaStageReader`, passing `E::from_int(0)`
/// as the source, with `MatrixIdent::Accumulator` as the target identity.
///
/// # Parameters
/// - `fragment`: destination accumulator fragment that `load_fragment` fills.
/// - `matrix_layout`: comptime layout forwarded to the reader.
/// - `config`: comptime matmul configuration (tile size and MMA I/O config).
#[cube]
pub fn mma_load_acc_zeros<E: Numeric, ES: Size, A: Numeric, L: Numeric, R: Numeric>(
    fragment: &mut Array<Vector<A, NA>>,
    #[comptime] matrix_layout: MatrixLayout,
    #[comptime] config: MmaMatmul,
) {
    // Comptime tile size rebuilt from the configured m/n/k.
    let tile_dims = comptime!(TileSize::new(
        config.tile_size.m(),
        config.tile_size.n(),
        config.tile_size.k(),
    ));

    let definition = make_mma_definition::<L, R, A>(config);
    let fill_value = E::from_int(0);

    MmaStageReader::<Filled>::load_fragment::<A, NA, E, ES, L, R, A>(
        &fill_value,
        fragment,
        definition,
        MatrixIdent::Accumulator,
        matrix_layout,
        tile_dims,
        config.mma_io_config,
    );
}

/// Stores an accumulator fragment into a writable strided tile.
///
/// The destination layout is read from `shared.layout` at comptime and
/// forwarded to `MmaStageWriter::store_fragment`, along with the MMA
/// definition built from `config`.
///
/// # Parameters
/// - `shared`: destination tile to write into.
/// - `fragment`: accumulator fragment to store.
/// - `config`: comptime matmul configuration (tile size and MMA I/O config).
#[cube]
pub fn mma_write_to_shared<E: Numeric, ES: Size, A: Numeric, L: Numeric, R: Numeric>(
    shared: &mut StridedTile<E, ES, ReadWrite>,
    fragment: &Array<Vector<A, NA>>,
    #[comptime] config: MmaMatmul,
) {
    // Capture the layout before `shared` is handed to the writer.
    let destination_layout = comptime!(shared.layout);
    let definition = make_mma_definition::<L, R, A>(config);

    MmaStageWriter::store_fragment(
        shared,
        fragment,
        definition,
        MatrixIdent::Accumulator,
        destination_layout,
        config.tile_size.m(),
        config.mma_io_config,
    );
}