use cubecl::{CubeDim, Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};
use cubek_std::{
MatrixLayout, SwizzleModes,
cube_count::{Count3d, CubeCountPlan, HypercubeBlueprint},
};
use crate::{
components::{
CubeDimResource,
global::{LoadFlows, memory::GlobalLayoutConfig, read::ReaderMode},
stage::PartitionBuffering,
tile::TileMatmulKind,
},
definition::{MatmulElems, MatmulProblem, MatmulSetupError, TilingScheme},
routines::DeviceSettings,
};
use std::{fmt::Debug, hash::Hash};
pub trait Blueprint: Debug + Clone + Eq + PartialEq + Hash {
fn lhs_global_layout_config(&self) -> GlobalLayoutConfig;
fn rhs_global_layout_config(&self) -> GlobalLayoutConfig;
fn out_global_layout_config(&self) -> GlobalLayoutConfig;
fn tiling_scheme(&self) -> TilingScheme;
fn swizzle_modes(&self) -> SwizzleModes;
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TilingBlueprint {
pub plane_dim: u32,
pub tile_matmul: TileMatmulKind,
pub tiling_scheme: TilingScheme,
pub swizzle_modes: SwizzleModes,
pub partition_buffering: PartitionBuffering,
pub loading_precompute_strategy: LoadingPrecomputeStrategy,
pub reader_mode: ReaderMode,
pub load_flows: LoadFlows,
pub hypercube_blueprint: HypercubeBlueprint,
pub lhs_layout: MatrixLayout,
pub rhs_layout: MatrixLayout,
pub check_m_bounds: bool,
pub check_n_bounds: bool,
pub check_k_bounds: bool,
}
impl Blueprint for TilingBlueprint {
fn lhs_global_layout_config(&self) -> GlobalLayoutConfig {
GlobalLayoutConfig {
matrix_layout: self.lhs_layout,
check_row_bounds: self.check_m_bounds,
check_col_bounds: self.check_k_bounds,
}
}
fn rhs_global_layout_config(&self) -> GlobalLayoutConfig {
GlobalLayoutConfig {
matrix_layout: self.rhs_layout,
check_row_bounds: self.check_k_bounds,
check_col_bounds: self.check_n_bounds,
}
}
fn out_global_layout_config(&self) -> GlobalLayoutConfig {
GlobalLayoutConfig {
matrix_layout: MatrixLayout::RowMajor,
check_row_bounds: self.check_m_bounds,
check_col_bounds: self.check_n_bounds,
}
}
fn tiling_scheme(&self) -> TilingScheme {
self.tiling_scheme
}
fn swizzle_modes(&self) -> SwizzleModes {
self.swizzle_modes
}
}
pub fn adjust_dtypes<R: Runtime>(
client: &ComputeClient<R>,
dtypes: &mut MatmulElems,
requires_accelerator: bool,
) {
let f32_dtype = f32::as_type_native_unchecked().storage_type();
let flex_dtype = flex32::as_type_native_unchecked().storage_type();
let tf32_dtype = tf32::as_type_native_unchecked().storage_type();
let f16_dtype = half::f16::as_type_native_unchecked().storage_type();
if requires_accelerator {
if dtypes.lhs_global == f32_dtype
&& dtypes.rhs_global == f32_dtype
&& client.properties().supports_type(tf32_dtype)
{
dtypes.lhs_stage = tf32_dtype;
dtypes.rhs_stage = tf32_dtype;
dtypes.lhs_register = tf32_dtype;
dtypes.rhs_register = tf32_dtype;
} else if dtypes.lhs_global == flex_dtype
&& dtypes.rhs_global == flex_dtype
&& client.properties().supports_type(f16_dtype)
{
dtypes.lhs_stage = f16_dtype;
dtypes.rhs_stage = f16_dtype;
dtypes.lhs_register = f16_dtype;
dtypes.rhs_register = f16_dtype;
}
}
}
impl TilingBlueprint {
pub fn builder(
tile_matmul: TileMatmulKind,
tiling_scheme: TilingScheme,
plane_dim: u32,
problem: &MatmulProblem,
) -> TilingBlueprintBuilder {
let hypercube_blueprint = HypercubeBlueprint::builder().build();
let check_m_bounds =
!(problem.m as u32).is_multiple_of(tiling_scheme.elements_per_stage_along_m());
let check_n_bounds =
!(problem.n as u32).is_multiple_of(tiling_scheme.elements_per_stage_along_n());
let check_k_bounds =
!(problem.k as u32).is_multiple_of(tiling_scheme.elements_per_stage_along_k());
TilingBlueprintBuilder {
plane_dim,
tile_matmul,
tiling_scheme,
hypercube_blueprint,
check_m_bounds,
check_n_bounds,
check_k_bounds,
lhs_layout: problem.lhs_layout,
rhs_layout: problem.rhs_layout,
shared_swizzle: Default::default(),
partition_buffering: PartitionBuffering::default(),
loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
reader_mode: ReaderMode::default(),
load_specialization_config: LoadFlows::default(),
}
}
pub fn cube_launch_info<R: Runtime>(
&self,
cubedim_resource: CubeDimResource,
problem: &MatmulProblem,
device_settings: &DeviceSettings<R>,
) -> Result<(CubeDim, CubeCountPlan), MatmulSetupError> {
let plane_dim = device_settings.plane_dim;
let cube_dim = cubedim_resource.to_cube_dim(plane_dim)?;
let target_cube_count = Count3d {
x: (problem.m as u32)
.div_ceil(self.tiling_scheme.elements_per_global_partition_along_m()),
y: (problem.n as u32)
.div_ceil(self.tiling_scheme.elements_per_global_partition_along_n()),
z: (problem.num_batches() as u32)
.div_ceil(self.tiling_scheme.global_partition_size.batches),
};
let cube_count_plan = CubeCountPlan::from_blueprint(
&self.hypercube_blueprint,
target_cube_count,
&device_settings.max_cube_count,
);
Ok((cube_dim, cube_count_plan))
}
}
pub struct TilingBlueprintBuilder {
plane_dim: u32,
tile_matmul: TileMatmulKind,
tiling_scheme: TilingScheme,
check_m_bounds: bool,
check_n_bounds: bool,
check_k_bounds: bool,
lhs_layout: MatrixLayout,
rhs_layout: MatrixLayout,
hypercube_blueprint: HypercubeBlueprint,
shared_swizzle: SwizzleModes,
partition_buffering: PartitionBuffering,
loading_precompute_strategy: LoadingPrecomputeStrategy,
reader_mode: ReaderMode,
load_specialization_config: LoadFlows,
}
impl TilingBlueprintBuilder {
pub fn hypercube_blueprint(mut self, hypercube_blueprint: HypercubeBlueprint) -> Self {
self.hypercube_blueprint = hypercube_blueprint;
self
}
pub fn shared_swizzle(mut self, swizzle: SwizzleModes) -> Self {
self.shared_swizzle = swizzle;
self
}
pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
self.partition_buffering = partition_buffering;
self
}
pub fn loading_precompute_strategy(
mut self,
loading_precompute_strategy: LoadingPrecomputeStrategy,
) -> Self {
self.loading_precompute_strategy = loading_precompute_strategy;
self
}
pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
self.reader_mode = reader_mode;
self
}
pub fn load_specialization_config(mut self, load_specialization_config: LoadFlows) -> Self {
self.load_specialization_config = load_specialization_config;
self
}
pub fn build(self) -> TilingBlueprint {
TilingBlueprint {
plane_dim: self.plane_dim,
tile_matmul: self.tile_matmul,
tiling_scheme: self.tiling_scheme,
swizzle_modes: self.shared_swizzle,
hypercube_blueprint: self.hypercube_blueprint,
partition_buffering: self.partition_buffering,
loading_precompute_strategy: self.loading_precompute_strategy,
reader_mode: self.reader_mode,
load_flows: self.load_specialization_config,
lhs_layout: self.lhs_layout,
rhs_layout: self.rhs_layout,
check_m_bounds: self.check_m_bounds,
check_n_bounds: self.check_n_bounds,
check_k_bounds: self.check_k_bounds,
}
}
}
#[derive(Debug, Clone, Copy, Default)]
pub enum MultiRowStrategy {
#[default]
Never,
Always(u32),
Adaptive { minimum_stage_count: u32 },
}
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingPrecomputeStrategy {
#[default]
Never,
Always,
}
impl From<LoadingPrecomputeStrategy> for bool {
fn from(strategy: LoadingPrecomputeStrategy) -> Self {
match strategy {
LoadingPrecomputeStrategy::Always => true,
LoadingPrecomputeStrategy::Never => false,
}
}
}