use cubecl::prelude::*;
use crate::{
MatrixLayout, SwizzleModes, TileSize,
tile::{Tile, TileScope},
};
/// A matmul tile whose elements are held in a cubecl `Array`.
///
/// Instances are built by `register_allocate_lhs`, `register_allocate_rhs` and
/// `register_allocate_acc`, which size `data` from the tile dimensions in `config`.
///
/// NOTE(review): the name suggests the array is register-resident; the actual
/// memory placement is decided by cubecl and is not visible here — confirm.
#[derive(CubeType)]
pub struct RegisterTile<N: Numeric> {
/// Flat element storage for the tile.
pub data: Array<N>,
/// Comptime matrix layout associated with this tile.
#[cube(comptime)]
pub matrix_layout: MatrixLayout,
/// Comptime matmul configuration the tile was allocated with.
#[cube(comptime)]
pub config: RegisterMatmul,
}
/// Product formulation selected for a register matmul.
///
/// The value is derived from the operand layouts and tile size by
/// `ProductType::from_layouts`.
///
/// NOTE(review): presumably `Inner` / `Outer` denote inner-product and
/// outer-product accumulation strategies — confirm against the code that
/// consumes this enum.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ProductType {
/// Formulation preferred by a row-major lhs and a column-major rhs.
Inner,
/// Formulation preferred by a column-major lhs and a row-major rhs; also the
/// fallback when the operands disagree and neither `m` nor `n` is 1.
Outer,
}
impl ProductType {
    /// Picks the product formulation for a pair of operand layouts.
    ///
    /// Each operand "votes" according to its layout:
    /// - lhs: row-major votes [`ProductType::Inner`], column-major votes
    ///   [`ProductType::Outer`];
    /// - rhs: row-major votes [`ProductType::Outer`], column-major votes
    ///   [`ProductType::Inner`].
    ///
    /// If both votes agree, that formulation is returned. Otherwise the tie is
    /// broken by the tile size: when `tile_size.m() == 1` the rhs vote wins,
    /// else when `tile_size.n() == 1` the lhs vote wins, and in every other
    /// case the result is [`ProductType::Outer`].
    pub fn from_layouts(
        lhs_layout: MatrixLayout,
        rhs_layout: MatrixLayout,
        tile_size: TileSize,
    ) -> Self {
        // Tie-break used only when the two operands vote differently.
        let settle = |lhs_vote: Self, rhs_vote: Self| -> Self {
            if tile_size.m() == 1 {
                rhs_vote
            } else if tile_size.n() == 1 {
                lhs_vote
            } else {
                Self::Outer
            }
        };

        match (lhs_layout, rhs_layout) {
            // Both operands vote the same way: no tie-break needed.
            (MatrixLayout::RowMajor, MatrixLayout::ColMajor) => Self::Inner,
            (MatrixLayout::ColMajor, MatrixLayout::RowMajor) => Self::Outer,
            // lhs votes Inner, rhs votes Outer.
            (MatrixLayout::RowMajor, MatrixLayout::RowMajor) => {
                settle(Self::Inner, Self::Outer)
            }
            // lhs votes Outer, rhs votes Inner.
            (MatrixLayout::ColMajor, MatrixLayout::ColMajor) => {
                settle(Self::Outer, Self::Inner)
            }
        }
    }
}
/// Configuration for a register matmul.
///
/// Built by `RegisterMatmul::new`, which stores the caller-supplied values and
/// derives `product_type` from the operand layouts and tile size.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct RegisterMatmul {
/// Tile dimensions; its `m()`, `n()` and `k()` accessors are used to size the
/// lhs (`m * k`), rhs (`n * k`) and accumulator (`m * n`) arrays.
pub tile_size: TileSize,
/// Plane dimension, stored as provided by the caller.
// NOTE(review): presumably the number of units per plane — confirm.
pub plane_dim: u32,
/// Swizzle modes, stored as provided by the caller.
pub swizzle_modes: SwizzleModes,
/// Product formulation chosen by `ProductType::from_layouts`.
pub product_type: ProductType,
}
impl RegisterMatmul {
    /// Creates a register matmul configuration.
    ///
    /// `tile_size`, `plane_dim` and `swizzle_modes` are stored unchanged.
    /// `lhs_layout` and `rhs_layout` are not stored; together with
    /// `tile_size` they are only used to derive `product_type` through
    /// [`ProductType::from_layouts`].
    pub fn new(
        lhs_layout: MatrixLayout,
        rhs_layout: MatrixLayout,
        tile_size: TileSize,
        plane_dim: u32,
        swizzle_modes: SwizzleModes,
    ) -> Self {
        // Resolve the derived field first so the struct literal is plain shorthand.
        let product_type = ProductType::from_layouts(lhs_layout, rhs_layout, tile_size);

        Self {
            tile_size,
            plane_dim,
            swizzle_modes,
            product_type,
        }
    }
}
/// Allocates a register tile for the lhs operand.
///
/// The backing array is created with `tile_size.m() * tile_size.k()` elements
/// of type `L`, and the resulting `RegisterTile` is wrapped with
/// `Tile::new_Register`. Nothing is written into the array here.
///
/// `layout` and `config` are comptime values and are stored on the tile as-is.
#[cube]
pub fn register_allocate_lhs<L: Numeric, Sc: TileScope>(
#[comptime] layout: MatrixLayout,
#[comptime] config: RegisterMatmul,
) -> Tile<L, Sc, ReadWrite> {
Tile::new_Register(RegisterTile::<L> {
// lhs holds m * k elements.
data: Array::new((config.tile_size.m() * config.tile_size.k()) as usize),
matrix_layout: layout,
config,
})
}
/// Allocates a register tile for the rhs operand.
///
/// The backing array is created with `tile_size.n() * tile_size.k()` elements
/// of type `R`, and the resulting `RegisterTile` is wrapped with
/// `Tile::new_Register`. Nothing is written into the array here.
///
/// `layout` and `config` are comptime values and are stored on the tile as-is.
#[cube]
pub fn register_allocate_rhs<R: Numeric, Sc: TileScope>(
#[comptime] layout: MatrixLayout,
#[comptime] config: RegisterMatmul,
) -> Tile<R, Sc, ReadWrite> {
Tile::new_Register(RegisterTile::<R> {
// rhs holds n * k elements.
data: Array::new((config.tile_size.n() * config.tile_size.k()) as usize),
matrix_layout: layout,
config,
})
}
/// Allocates a register tile for the accumulator.
///
/// The backing array is created with `tile_size.m() * tile_size.n()` elements
/// of type `A`, and the resulting `RegisterTile` is wrapped with
/// `Tile::new_Register`. Nothing is written into the array here, so any
/// initialisation of the accumulator must happen elsewhere.
///
/// `layout` and `config` are comptime values and are stored on the tile as-is.
#[cube]
pub fn register_allocate_acc<A: Numeric, Sc: TileScope>(
#[comptime] layout: MatrixLayout,
#[comptime] config: RegisterMatmul,
) -> Tile<A, Sc, ReadWrite> {
Tile::new_Register(RegisterTile::<A> {
// The accumulator holds m * n elements.
data: Array::new((config.tile_size.m() * config.tile_size.n()) as usize),
matrix_layout: layout,
config,
})
}