//! cubek-std 0.2.0
//!
//! CubeK: Standard Library — documentation.
use cubecl::prelude::*;

use crate::{
    MatrixLayout, SwizzleModes, TileSize,
    tile::{Tile, TileScope},
};

/// Per-unit register-style tile used by the interleaved matmul.
///
/// The backing `data` array is created by the `interleaved_allocate_*`
/// functions in this module; its length depends on which operand
/// (lhs / rhs / accumulator) the tile represents.
#[derive(CubeType)]
pub struct InterleavedTile<N: Numeric> {
    /// Elements owned by this unit for the tile.
    pub data: Array<N>,
    /// Layout of the tile, fixed at kernel-expansion (comptime) time.
    #[cube(comptime)]
    pub matrix_layout: MatrixLayout,
    /// Interleaved matmul configuration, fixed at comptime.
    #[cube(comptime)]
    pub config: InterleavedMatmul,
}

/// Configuration of the interleaved matmul.
///
/// The `k` dimension of the tile is split evenly across `plane_dim` units,
/// while `m` and `n` are held in full by every unit.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct InterleavedMatmul {
    /// Full (m, n, k) size of the tile.
    pub tile_size: TileSize,
    /// Number of units the `k` dimension is divided across; `k` must be a
    /// multiple of it (presumably the plane/warp width — confirm with callers).
    pub plane_dim: u32,
    /// Swizzle configuration carried with the config; not read by the code in this file.
    pub swizzle_modes: SwizzleModes,
}

impl InterleavedMatmul {
    /// Builds a configuration from a tile size, plane width and swizzle modes.
    pub fn new(tile_size: TileSize, plane_dim: u32, swizzle_modes: SwizzleModes) -> Self {
        Self {
            tile_size,
            plane_dim,
            swizzle_modes,
        }
    }

    /// Number of elements along `m` held by each unit.
    ///
    /// `m` is not split across units, so this is the full tile `m`.
    pub fn elements_per_unit_m(&self) -> usize {
        self.tile_size.m() as usize
    }

    /// Number of elements along `n` held by each unit.
    ///
    /// `n` is not split across units, so this is the full tile `n`.
    pub fn elements_per_unit_n(&self) -> usize {
        self.tile_size.n() as usize
    }

    /// Tile size rebuilt from the configured `m`, `n` and `k`.
    ///
    /// Note that `k` here is the full tile `k`, not the per-unit slice
    /// returned by [`Self::elements_per_unit_k`].
    pub fn local_tile_size(&self) -> TileSize {
        TileSize {
            m: self.tile_size.m(),
            n: self.tile_size.n(),
            k: self.tile_size.k(),
        }
    }

    /// Number of elements along `k` held by each unit: `k / plane_dim`.
    ///
    /// # Panics
    ///
    /// - if `plane_dim` is zero;
    /// - if `k` is not a multiple of `plane_dim`.
    pub fn elements_per_unit_k(&self) -> usize {
        let k = self.tile_size.k() as usize;
        let plane_dim = self.plane_dim as usize;
        // Checked first: `0.is_multiple_of(0)` is true, so without this a
        // zero plane_dim with k == 0 would reach the division and panic with
        // an opaque divide-by-zero message.
        assert!(plane_dim != 0, "plane_dim must be non-zero. Got k={k}");
        assert!(
            k.is_multiple_of(plane_dim),
            "k must be divisible by plane_dim. Got k={k}, plane_dim={plane_dim}"
        );

        k / plane_dim
    }
}

/// Allocates the lhs operand tile for the interleaved matmul.
///
/// Each unit holds every `m` row but only its `k / plane_dim` slice of the
/// `k` dimension, so the backing array has `m * (k / plane_dim)` elements.
///
/// # Panics
///
/// While the kernel is expanded (the config is comptime), if `plane_dim` is
/// zero or `k` is not a multiple of `plane_dim`. Previously a non-divisible
/// `k` silently truncated and under-allocated the array.
#[cube]
pub fn interleaved_allocate_lhs<L: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: InterleavedMatmul,
) -> Tile<L, Sc, ReadWrite> {
    // Reuse the config helpers so the divisibility invariant is enforced in
    // one place and the size is computed in usize rather than u32.
    let rows = config.elements_per_unit_m();
    let k_per_unit = config.elements_per_unit_k();
    Tile::new_Interleaved(InterleavedTile::<L> {
        data: Array::new(rows * k_per_unit),
        matrix_layout: layout,
        config,
    })
}

/// Allocates the rhs operand tile for the interleaved matmul.
///
/// Each unit holds its `k / plane_dim` slice of the `k` dimension for every
/// `n` column, so the backing array has `(k / plane_dim) * n` elements.
///
/// # Panics
///
/// While the kernel is expanded (the config is comptime), if `plane_dim` is
/// zero or `k` is not a multiple of `plane_dim`. Previously a non-divisible
/// `k` silently truncated and under-allocated the array.
#[cube]
pub fn interleaved_allocate_rhs<R: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: InterleavedMatmul,
) -> Tile<R, Sc, ReadWrite> {
    // Same per-unit helpers as the lhs allocation, so both operands agree on
    // the k-slice size and share the divisibility check.
    let k_per_unit = config.elements_per_unit_k();
    let cols = config.elements_per_unit_n();
    Tile::new_Interleaved(InterleavedTile::<R> {
        data: Array::new(k_per_unit * cols),
        matrix_layout: layout,
        config,
    })
}

/// Allocates the accumulator tile for the interleaved matmul.
///
/// The accumulator is not split along any dimension, so every unit's backing
/// array holds the full `m * n` output tile.
#[cube]
pub fn interleaved_allocate_acc<A: Numeric, Sc: TileScope>(
    #[comptime] layout: MatrixLayout,
    #[comptime] config: InterleavedMatmul,
) -> Tile<A, Sc, ReadWrite> {
    let num_elements = (config.tile_size.m() * config.tile_size.n()) as usize;
    Tile::new_Interleaved(InterleavedTile::<A> {
        data: Array::new(num_elements),
        matrix_layout: layout,
        config,
    })
}