use cubecl::{
std::tensor::layout::Coords2d,
{ir::DeviceProperties, prelude::*},
};
use cubek_std::{
InvalidConfigError,
stage::StageMemoryConfig,
tile::{Tile, TileScope},
};
use crate::{
components::{
CubeDimResource,
global::{PlaneFlowConfig, WriteEventListener},
stage::{NumStages, PartitionScheduler},
},
definition::{
Acc, Lhs, MatmulElems, MatmulSetupError, MatmulTypes, MatmulVectorSizes, Rhs,
TilingBlueprint,
},
};
use std::{fmt::Debug, hash::Hash};
use super::{StageEventListener, TilingLayout};
// Private shorthand for `crate::definition::Stage<T>`. In this file it is
// used to spell the element-type argument (`ES`) of `StageFamily::Stage` in
// the bounds of `StageMatmulFamily::Matmul`.
type Ty<T> = crate::definition::Stage<T>;
// Private shorthand for `crate::definition::StageSize<T>`. In this file it is
// used to spell the size argument (`NS`) of `StageFamily::Stage` in the
// bounds of `StageMatmulFamily::Matmul`.
type Sz<T> = crate::definition::StageSize<T>;
pub trait StageMatmulFamily: Send + Sync + 'static {
type Scope: TileScope;
type Matmul<MP: MatmulTypes, TL: TilingLayout, TR: TilingLayout, TA: TilingLayout, TO: TilingLayout>: StageMatmul<
MP,
Config = Self::Config,
Scope = Self::Scope,
LhsStage = <Self::LhsStage as StageFamily>::Stage<Ty<Lhs<MP>>, Sz<Lhs<MP>>, TL>,
RhsStage = <Self::RhsStage as StageFamily>::Stage<Ty<Rhs<MP>>, Sz<Rhs<MP>>, TR>,
AccStage = <Self::AccStage as StageFamily>::Stage<Ty<Acc<MP>>, Sz<Acc<MP>>, TA>,
OutStage = <Self::OutStage as StageFamily<ReadWrite>>::Stage<Ty<Acc<MP>>, Sz<Acc<MP>>, TO>,
>;
type LhsStage: StageFamily;
type RhsStage: StageFamily;
type AccStage: StageFamily;
type OutStage: StageFamily<ReadWrite>;
type Config: StageConfig;
#[allow(clippy::too_many_arguments)]
fn expand_config(
device_props: &DeviceProperties,
blueprint: &TilingBlueprint,
plane_flow_config: PlaneFlowConfig,
num_stages: NumStages,
dtypes: &MatmulElems,
vector_sizes: &MatmulVectorSizes,
) -> Result<Self::Config, MatmulSetupError>;
fn cubedim_resource(blueprint: &TilingBlueprint)
-> Result<CubeDimResource, InvalidConfigError>;
fn validate_blueprint<R: Runtime>(
client: &ComputeClient<R>,
blueprint: &TilingBlueprint,
dtypes: &MatmulElems,
vector_sizes: &MatmulVectorSizes,
) -> Result<(), MatmulSetupError>;
}
#[cube]
/// Stage-level matrix multiplication, generic over the matmul types `MP`.
///
/// The associated types name the four stages (lhs, rhs, accumulator source,
/// output), the per-tile input containers and the accumulators. Every method
/// is an associated function (there is no `self`) and receives the
/// configuration as a `#[comptime]` argument.
///
/// NOTE(review): the exact contracts (who fills the stages, synchronization,
/// how work is partitioned) are defined by implementors and are not visible in
/// this file.
pub trait StageMatmul<MP: MatmulTypes>: 'static + Send + Sync {
/// Configuration passed as a `#[comptime]` argument to the methods below.
type Config: StageConfig;
/// Tile scope associated with this matmul.
type Scope: TileScope;
/// Accumulator storage, mutated by `execute*` and `load_accumulators` and
/// handed to `write_results`.
type Accumulators: CubeType;
/// Stage type read as the left-hand side.
type LhsStage: CubeType;
/// Stage type read as the right-hand side.
type RhsStage: CubeType;
/// Stage type accumulators are loaded from.
type AccStage: CubeType;
/// Stage type results are written to.
type OutStage: CubeType;
/// Container for the left-hand-side tile input, created by `init_tile_inputs`.
type LhsTile: CubeType;
/// Container for the right-hand-side tile input, created by `init_tile_inputs`.
type RhsTile: CubeType;
/// Runs the stage matmul over `lhs` and `rhs`.
///
/// `instruction_lhs` / `instruction_rhs` are mutable tile containers and
/// `acc` is the mutable accumulator storage.
///
/// NOTE(review): presumably accumulates `lhs · rhs` into `acc` and behaves
/// like `execute_with_listener` with a no-op listener — confirm against
/// implementors.
fn execute(
lhs: &Self::LhsStage,
rhs: &Self::RhsStage,
instruction_lhs: &mut Self::LhsTile,
instruction_rhs: &mut Self::RhsTile,
acc: &mut Self::Accumulators,
#[comptime] config: Self::Config,
partition_scheduler: &PartitionScheduler,
);
/// Same inputs as `execute`, plus a `listener` implementing
/// [`StageEventListener`] that is passed by value.
///
/// NOTE(review): when and how often the listener is notified is decided by
/// the implementor — not visible here.
fn execute_with_listener<SEL: StageEventListener>(
lhs: &Self::LhsStage,
rhs: &Self::RhsStage,
instruction_lhs: &mut Self::LhsTile,
instruction_rhs: &mut Self::RhsTile,
acc: &mut Self::Accumulators,
#[comptime] config: Self::Config,
listener: SEL,
partition_scheduler: &PartitionScheduler,
);
/// Creates the `(LhsTile, RhsTile)` pair later passed to `execute*`.
fn init_tile_inputs(#[comptime] config: Self::Config) -> (Self::LhsTile, Self::RhsTile);
/// Creates the accumulator storage for `config`.
fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
/// Loads `acc` from the accumulator stage `reader`.
fn load_accumulators(
reader: &Self::AccStage,
acc: &mut Self::Accumulators,
partition_scheduler: &PartitionScheduler,
#[comptime] config: Self::Config,
);
/// Writes the contents of `acc` into the output `stage`, with access to a
/// mutable [`WriteEventListener`].
///
/// NOTE(review): `acc` is taken mutably; whether it is modified or merely
/// read is implementor-defined.
fn write_results<W: WriteEventListener>(
acc: &mut Self::Accumulators,
stage: &mut Self::OutStage,
listener: &mut W,
partition_scheduler: &PartitionScheduler,
#[comptime] stage_config: Self::Config,
);
/// Creates the [`PartitionScheduler`] that the other methods take by reference.
fn init_scheduler(#[comptime] config: Self::Config) -> PartitionScheduler;
}
/// Configuration contract for stage-level matmul.
///
/// Implementors must be cheap, comparable, hashable values (`Copy`, `Eq`,
/// `Hash`, `Debug`) that are `Send + Sync + 'static`.
///
/// NOTE(review): the per-method descriptions below are read off the method
/// names and return types; this file contains no implementation to confirm
/// the exact definitions.
pub trait StageConfig:
    'static + Send + Sync + Copy + Clone + PartialEq + Eq + Hash + Debug
{
    /// Element count of the stage along `m`.
    fn elements_in_stage_m(&self) -> u32;

    /// Element count of the stage along `n`.
    fn elements_in_stage_n(&self) -> u32;

    /// Element count of the stage along `k`.
    fn elements_in_stage_k(&self) -> u32;

    /// Element count of a single tile along `k`.
    fn elements_in_tile_k(&self) -> u32;

    /// Tile count of a partition over `m` and `n`.
    fn tiles_in_partition_mn(&self) -> u32;

    /// Number of planes in the main flow.
    fn num_main_flow_planes(&self) -> u32;

    /// Plane dimension.
    fn plane_dim(&self) -> u32;

    /// The [`PlaneFlowConfig`] carried by this configuration.
    fn plane_flow_config(&self) -> PlaneFlowConfig;

    /// [`StageMemoryConfig`] for the left-hand-side stage.
    fn lhs_smem_config(&self) -> StageMemoryConfig;

    /// [`StageMemoryConfig`] for the right-hand-side stage.
    fn rhs_smem_config(&self) -> StageMemoryConfig;

    /// [`StageMemoryConfig`] for the accumulator stage.
    fn acc_smem_config(&self) -> StageMemoryConfig;

    /// [`StageMemoryConfig`] for the output stage.
    fn out_smem_config(&self) -> StageMemoryConfig;
}
/// Buffering mode of a partition: single or double.
///
/// [`PartitionBuffering::Double`] is the value returned by `Default`.
///
/// NOTE(review): what exactly is buffered once or twice is decided by the
/// code consuming this enum, which is not part of this file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum PartitionBuffering {
    /// Single buffering.
    Single,
    /// Double buffering (the default).
    #[default]
    Double,
}
#[cube]
/// A stage holding elements of type `ES`, from which tiles can be obtained.
///
/// `IO` is the slice visibility carried by the returned tiles and defaults to
/// `ReadOnly`.
pub trait Stage<ES: Numeric, IO: SliceVisibility = ReadOnly>:
CubeType + Clone + Send + Sync + 'static
{
/// Returns the [`Tile`] of `this` at the 2D coordinates `tile`, for the tile
/// scope `Sc`.
fn tile<Sc: TileScope>(this: &Self, tile: Coords2d) -> Tile<ES, Sc, IO>;
}
/// Type-level family of [`Stage`] types with slice visibility `IO`
/// (`ReadOnly` unless specified).
pub trait StageFamily<IO: SliceVisibility = ReadOnly>: Send + Sync + 'static {
    /// The concrete stage for element type `ES`, size type `NS` and tiling
    /// layout `T`; it implements [`Stage`] with this family's `IO`.
    type Stage<ES, NS, T>: Stage<ES, IO>
    where
        ES: Numeric,
        NS: Size,
        T: TilingLayout;
}
#[cube]
/// A [`StageFamily`] whose stages can be created, re-targeted to a buffer
/// index and freed.
///
/// NOTE(review): the supertrait is `StageFamily` with its default visibility
/// (`ReadOnly`), not `StageFamily<IO>`, so `Self::Stage` in the signatures
/// below is the `ReadOnly` family's stage regardless of `IO`. Confirm this is
/// intended.
pub trait LoadStageFamily<IO: SliceVisibility = ReadOnly>: StageFamily {
/// Creates a stage from a comptime `alignment` and a comptime
/// [`StageMemoryConfig`].
fn create<ES: Numeric, NS: Size, T: TilingLayout>(
#[comptime] alignment: usize,
#[comptime] config: StageMemoryConfig,
) -> Self::Stage<ES, NS, T>;
/// Returns a stage of the same type derived from `stage` and `buffer_index`.
///
/// NOTE(review): presumably a view of the same stage selecting buffer
/// `buffer_index` — confirm against implementors.
fn with_buffer_index<ES: Numeric, NS: Size, T: TilingLayout>(
stage: &Self::Stage<ES, NS, T>,
buffer_index: u32,
) -> Self::Stage<ES, NS, T>;
/// Frees `stage`.
///
/// NOTE(review): what is released, and whether `stage` may be used
/// afterwards, is implementor-defined and not visible here.
fn free<ES: Numeric, NS: Size, T: TilingLayout>(stage: &Self::Stage<ES, NS, T>);
}
#[cube]
/// Lets an optional stage (`ComptimeOption<Inner>`) be used wherever a
/// [`Stage`] is expected.
impl<ES: Numeric, IO: SliceVisibility, Inner: Stage<ES, IO>> Stage<ES, IO>
for ComptimeOption<Inner>
{
/// Forwards to `Inner::tile` when `this` is `Some`; otherwise returns
/// `Tile::new_None()`.
fn tile<Sc: TileScope>(this: &Self, tile: Coords2d) -> Tile<ES, Sc, IO> {
// The branch is marked `#[comptime]` for the `#[cube]` macro; presumably it
// is resolved during kernel expansion rather than at run time.
#[comptime]
if let ComptimeOption::Some(inner) = this {
Inner::tile::<Sc>(inner, tile)
} else {
// No inner stage: hand back the "none" tile.
Tile::new_None()
}
}
}
#[cube]
/// Lifts a [`LoadStageFamily`] `S` to `Option<S>`, whose stage type is
/// `ComptimeOption` of `S`'s stage.
impl<IO: SliceVisibility, S: LoadStageFamily<IO>> LoadStageFamily<IO> for Option<S> {
/// Creates the inner stage with `S::create` and wraps it in `Some`.
///
/// This path never produces a `None` stage.
fn create<ES: Numeric, NS: Size, T: TilingLayout>(
#[comptime] alignment: usize,
#[comptime] config: StageMemoryConfig,
) -> Self::Stage<ES, NS, T> {
ComptimeOption::new_Some(S::create(alignment, config))
}
/// Applies `S::with_buffer_index` to the inner stage when there is one; a
/// `None` stage stays `None`.
fn with_buffer_index<ES: Numeric, NS: Size, T: TilingLayout>(
stage: &Self::Stage<ES, NS, T>,
index: u32,
) -> Self::Stage<ES, NS, T> {
stage.as_ref().map(|s| S::with_buffer_index(s, index))
}
/// Calls `S::free` on the inner stage when there is one; does nothing for a
/// `None` stage.
fn free<ES: Numeric, NS: Size, T: TilingLayout>(stage: &Self::Stage<ES, NS, T>) {
// Branch marked `#[comptime]` for the `#[cube]` macro; there is
// deliberately no `else` arm.
#[comptime]
if let ComptimeOption::Some(inner) = stage {
S::free(inner)
}
}
}
/// Lifts any stage family to an optional one: the stage of `Option<Inner>` is
/// `ComptimeOption` of `Inner`'s stage, with the same visibility `IO`.
impl<IO, Inner> StageFamily<IO> for Option<Inner>
where
    IO: SliceVisibility,
    Inner: StageFamily<IO>,
{
    type Stage<ES: Numeric, NS: Size, T: TilingLayout> =
        ComptimeOption<<Inner as StageFamily<IO>>::Stage<ES, NS, T>>;
}