cubek-std 0.2.0-pre.5

CubeK: Standard Library
Documentation
use cubecl::{
    cmma::{self},
    prelude::*,
};

use crate::{
    MatrixLayout, StageIdent, SwizzleModes, TileSize, as_cmma_layout,
    tile::{
        CmmaFragmentReader, CmmaStageReader, CmmaStageWriter, CmmaTile, Strided, StridedTile, Tile,
        scope::Scope,
    },
};

/// Configuration for a CMMA-based tile matmul.
///
/// Groups the tile shape, the plane dimension and the swizzle modes into one
/// `Copy` value that can be compared and hashed.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct CmmaMatmul {
    /// Tile shape (`m`, `n`, `k`) used for the CMMA fragments.
    pub tile_size: TileSize,
    /// Plane dimension. This is presumably the number of units per plane; confirm against callers.
    pub plane_dim: u32,
    /// Swizzle modes carried with this configuration.
    pub swizzle_modes: SwizzleModes,
}

impl CmmaMatmul {
    /// Creates a [`CmmaMatmul`] that stores the three given values unchanged.
    pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
        CmmaMatmul {
            swizzle_modes,
            plane_dim,
            tile_size,
        }
    }
}

/// Allocates an uninitialized CMMA fragment for the left-hand operand
/// (`cmma::MatrixIdent::A`) and wraps it in a [`Tile`].
///
/// `layout` is used twice: converted to the fragment's CMMA layout, and stored
/// as the tile's `matrix_layout`. `tile_size` provides the `m`/`n`/`k` shape.
#[cube]
pub fn cmma_allocate_lhs<L: Numeric, Sc: Scope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] tile_size: TileSize,
) -> Tile<L, Sc, ReadWrite> {
    // SAFETY: the fragment's contents are unspecified after allocation. It must be
    // loaded or filled before it is read.
    // NOTE(review): confirm that every caller does this.
    let matrix = unsafe {
        cmma::Matrix::<L>::uninitialized(
            cmma::MatrixIdent::A,
            tile_size.m as usize,
            tile_size.n as usize,
            tile_size.k as usize,
            as_cmma_layout(layout),
        )
    };
    let cmma_tile = CmmaTile::<L> {
        matrix,
        matrix_layout: layout,
        tile_size,
    };
    Tile::new_Cmma(cmma_tile)
}

/// Allocates an uninitialized CMMA fragment for the right-hand operand
/// (`cmma::MatrixIdent::B`) and wraps it in a [`Tile`].
///
/// `layout` is used twice: converted to the fragment's CMMA layout, and stored
/// as the tile's `matrix_layout`. `tile_size` provides the `m`/`n`/`k` shape.
#[cube]
pub fn cmma_allocate_rhs<R: Numeric, Sc: Scope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] tile_size: TileSize,
) -> Tile<R, Sc, ReadWrite> {
    // SAFETY: the fragment's contents are unspecified after allocation. It must be
    // loaded or filled before it is read.
    // NOTE(review): confirm that every caller does this.
    let matrix = unsafe {
        cmma::Matrix::<R>::uninitialized(
            cmma::MatrixIdent::B,
            tile_size.m as usize,
            tile_size.n as usize,
            tile_size.k as usize,
            as_cmma_layout(layout),
        )
    };
    let cmma_tile = CmmaTile::<R> {
        matrix,
        matrix_layout: layout,
        tile_size,
    };
    Tile::new_Cmma(cmma_tile)
}

/// Allocates an uninitialized CMMA accumulator fragment
/// (`cmma::MatrixIdent::Accumulator`) and wraps it in a [`Tile`].
///
/// The fragment is created with `cmma::MatrixLayout::Undefined`. `layout` is only
/// recorded as the tile's `matrix_layout`. `cmma_load_from_shared` passes that
/// layout explicitly when loading an `Acc` tile. `tile_size` provides the `m`/`n`/`k` shape.
#[cube]
pub fn cmma_allocate_acc<A: Numeric, Sc: Scope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] tile_size: TileSize,
) -> Tile<A, Sc, ReadWrite> {
    // SAFETY: the fragment's contents are unspecified after allocation. It must be
    // loaded or filled (e.g. zeroed) before it is read.
    // NOTE(review): confirm that every caller does this.
    let matrix = unsafe {
        cmma::Matrix::<A>::uninitialized(
            cmma::MatrixIdent::Accumulator,
            tile_size.m as usize,
            tile_size.n as usize,
            tile_size.k as usize,
            cmma::MatrixLayout::Undefined,
        )
    };
    let cmma_tile = CmmaTile::<A> {
        matrix,
        matrix_layout: layout,
        tile_size,
    };
    Tile::new_Cmma(cmma_tile)
}

/// Performs one CMMA multiply-accumulate step.
///
/// `acc` is passed to [`cmma::execute`] as both the C (addend) operand and the
/// D (destination) operand. The product `lhs * rhs` is therefore accumulated
/// into `acc` in place.
#[cube]
pub fn cmma_execute<L: Numeric, R: Numeric, A: Numeric>(
    lhs: &cmma::Matrix<L>,
    rhs: &cmma::Matrix<R>,
    acc: &mut cmma::Matrix<A>,
) {
    // The same accumulator is used as input C and output D, so C and D share element type `A`.
    cmma::execute::<L, R, A, A>(lhs, rhs, acc, acc);
}

/// Loads the fragment `matrix` from the strided tile `shared` using
/// [`CmmaStageReader`] with the `Strided` reader.
///
/// - `Lhs` / `Rhs`: no layout is passed (`ComptimeOption::new_None()`).
///   NOTE(review): this presumably means the fragment's own layout is used; confirm
///   against `CmmaStageReader::load_fragment`.
/// - `Acc`: `matrix_layout` is converted with `as_cmma_layout` and passed explicitly.
///   This matches `cmma_allocate_acc`, which creates accumulators with an `Undefined` layout.
///
/// `matrix_layout` is only used in the `Acc` case.
///
/// # Panics
/// Panics with "Invalid ident for CMMA load" for any other `StageIdent`. Because
/// `ident` is comptime, this presumably triggers during kernel expansion rather than at runtime.
#[cube]
pub fn cmma_load_from_shared<E: Numeric, ES: Size, N: Numeric, IO: SliceVisibility>(
    shared: &StridedTile<E, ES, IO>,
    matrix: &mut cmma::Matrix<N>,
    #[comptime] ident: StageIdent,
    #[comptime] matrix_layout: MatrixLayout,
) {
    // Read-only view of the source tile: loading only reads from it.
    let shared = shared.to_read_only();
    match ident {
        StageIdent::Lhs | StageIdent::Rhs => {
            CmmaStageReader::<Strided>::load_fragment(&shared, matrix, ComptimeOption::new_None());
        }
        StageIdent::Acc => {
            // The accumulator layout must be supplied here, since the fragment itself has none.
            CmmaStageReader::<Strided>::load_fragment(
                &shared,
                matrix,
                ComptimeOption::new_Some(as_cmma_layout(matrix_layout)),
            );
        }
        _ => panic!("Invalid ident for CMMA load"),
    }
}

/// Sets every element of the fragment `matrix` to zero using [`cmma::fill`].
#[cube]
pub fn cmma_load_zeros<N: Numeric>(matrix: &mut cmma::Matrix<N>) {
    let zero = N::from_int(0);
    cmma::fill(matrix, zero);
}

/// Writes the fragment `matrix` into the strided tile `shared`.
///
/// The fragment is first converted from element type `A` to the tile's element
/// type `E` with [`cmma::cast`]. The converted copy is then stored through
/// [`CmmaStageWriter`].
#[cube]
pub fn cmma_write_to_shared<E: Numeric, ES: Size, A: Numeric>(
    shared: &mut StridedTile<E, ES, ReadWrite>,
    matrix: &cmma::Matrix<A>,
) {
    // `cast` produces a separate fragment whose element type matches the destination tile.
    let converted = cmma::cast::<A, E>(matrix);
    CmmaStageWriter::store_fragment(shared, &converted);
}