// cubek-std 0.2.0
//
// CubeK: Standard Library
// Documentation
use cubecl::{
    cmma::Matrix as CubeMatrix,
    cmma::{self},
    prelude::*,
};

use crate::{
    MatrixLayout, SwizzleModes, TileSize, as_cmma_layout,
    tile::{Tile, TileScope},
};

/// A tile stored as a cubecl `cmma::Matrix`, bundled with the compile-time
/// metadata (layout and tile size) it was created with.
#[derive(CubeType)]
pub struct CmmaTile<N: Numeric> {
    /// The underlying cmma matrix with element type `N`.
    pub matrix: CubeMatrix<N>,
    /// Matrix layout recorded for this tile; a comptime value.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Tile dimensions recorded for this tile; a comptime value.
    #[cube(comptime)]
    pub tile_size: TileSize,
}

/// Settings for a cmma-based matmul.
///
/// A plain value type: it can be copied, compared for equality, hashed and
/// debug-printed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CmmaMatmul {
    /// Tile dimensions used by the matmul.
    pub tile_size: TileSize,
    /// Plane dimension.
    // NOTE(review): presumably the number of units in a plane — confirm.
    pub plane_dim: u32,
    /// Swizzle modes used by the matmul.
    pub swizzle_modes: SwizzleModes,
}

impl CmmaMatmul {
    pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
        Self {
            tile_size,
            plane_dim,
            swizzle_modes,
        }
    }
}

/// Allocates the left-hand-side cmma tile.
///
/// Creates a `cmma::Matrix` with identity `MatrixIdent::A`, using the `m`,
/// `n` and `k` dimensions of `tile_size` and the cmma layout obtained from
/// `as_cmma_layout(layout)`. The matrix is then stored in a [`CmmaTile`]
/// (together with `layout` and `tile_size`) and wrapped via `Tile::new_Cmma`.
///
/// The matrix is created with `uninitialized`, so its contents are not set
/// by this function.
#[cube]
pub fn cmma_allocate_lhs<L: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] tile_size: TileSize,
) -> Tile<L, Sc, ReadWrite> {
    // SAFETY: NOTE(review): `uninitialized` is an unsafe constructor; presumably
    // the matrix must be filled or loaded before it is read — confirm that
    // every caller does so.
    let matrix = unsafe {
        cmma::Matrix::<L>::uninitialized(
            cmma::MatrixIdent::A,
            tile_size.m as usize,
            tile_size.n as usize,
            tile_size.k as usize,
            as_cmma_layout(layout),
        )
    };
    let tile = CmmaTile::<L> {
        matrix,
        matrix_layout: layout,
        tile_size,
    };
    Tile::new_Cmma(tile)
}

/// Allocates the right-hand-side cmma tile.
///
/// Creates a `cmma::Matrix` with identity `MatrixIdent::B`, using the `m`,
/// `n` and `k` dimensions of `tile_size` and the cmma layout obtained from
/// `as_cmma_layout(layout)`. The matrix is then stored in a [`CmmaTile`]
/// (together with `layout` and `tile_size`) and wrapped via `Tile::new_Cmma`.
///
/// The matrix is created with `uninitialized`, so its contents are not set
/// by this function.
#[cube]
pub fn cmma_allocate_rhs<R: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] tile_size: TileSize,
) -> Tile<R, Sc, ReadWrite> {
    // SAFETY: NOTE(review): `uninitialized` is an unsafe constructor; presumably
    // the matrix must be filled or loaded before it is read — confirm that
    // every caller does so.
    let matrix = unsafe {
        cmma::Matrix::<R>::uninitialized(
            cmma::MatrixIdent::B,
            tile_size.m as usize,
            tile_size.n as usize,
            tile_size.k as usize,
            as_cmma_layout(layout),
        )
    };
    let tile = CmmaTile::<R> {
        matrix,
        matrix_layout: layout,
        tile_size,
    };
    Tile::new_Cmma(tile)
}

/// Allocates the accumulator cmma tile.
///
/// Creates a `cmma::Matrix` with identity `MatrixIdent::Accumulator`, using
/// the `m`, `n` and `k` dimensions of `tile_size`. Unlike the lhs/rhs
/// allocators, the cmma matrix itself is created with
/// `cmma::MatrixLayout::Undefined`; `layout` is only recorded in the
/// resulting [`CmmaTile`]'s `matrix_layout` field. The tile is wrapped via
/// `Tile::new_Cmma`.
///
/// The matrix is created with `uninitialized`, so its contents are not set
/// by this function.
#[cube]
pub fn cmma_allocate_acc<A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] tile_size: TileSize,
) -> Tile<A, Sc, ReadWrite> {
    // SAFETY: NOTE(review): `uninitialized` is an unsafe constructor; presumably
    // the accumulator must be filled or loaded before it is read — confirm
    // that every caller does so.
    let matrix = unsafe {
        cmma::Matrix::<A>::uninitialized(
            cmma::MatrixIdent::Accumulator,
            tile_size.m as usize,
            tile_size.n as usize,
            tile_size.k as usize,
            // The accumulator matrix carries no cmma layout at allocation time.
            cmma::MatrixLayout::Undefined,
        )
    };
    let tile = CmmaTile::<A> {
        matrix,
        matrix_layout: layout,
        tile_size,
    };
    Tile::new_Cmma(tile)
}