use cubecl::{ir::DeviceProperties, prelude::*};
use cubek_std::stage::StageMemoryConfig;
use crate::components::global::{
GlobalWriterConfig, InputLoadFlow, LoadFlows, PlaneFlowConfig, SpecializedLoadingSides,
};
use crate::{
components::global::multi_stage::EventLoadingMode, components::global::read::ReaderMode,
};
use crate::{
components::stage::StageConfig,
components::{global::memory::GlobalMemoryConfig, stage::NumStages},
definition::StageIdent,
definition::TilingBlueprint,
definition::{AccG, MatmulSetupError},
definition::{LhsG, MatmulElems, MatmulVectorSizes, RhsG},
definition::{MatmulProblem, MatmulTypes},
{components::CubeDimResource, launch::RuntimeConfig},
};
use cubecl::std::tensor::{View, layout::Coords2d};
use std::{fmt::Debug, hash::Hash};
/// A family of global matmul implementations for a given runtime config `RC`.
///
/// The family selects a concrete [`GlobalMatmul`] for each set of matmul types
/// `MP`; every member of the family shares the same [`GlobalConfig`] type.
pub trait GlobalMatmulFamily<RC: RuntimeConfig>: Send + Sync + 'static {
/// The concrete global matmul for the matmul types `MP`, using `Self::Config`.
type Matmul<MP: MatmulTypes>: GlobalMatmul<RC, MP, Config = Self::Config>;
/// The configuration type shared by all matmuls of this family.
type Config: GlobalConfig;
/// Builds the family's configuration from the device properties, the tiling
/// blueprint, the element types and the vector sizes.
///
/// # Errors
///
/// Returns a [`MatmulSetupError`] when no configuration can be produced.
fn expand_config(
device_props: &DeviceProperties,
blueprint: &TilingBlueprint,
dtypes: &MatmulElems,
vector_sizes: &MatmulVectorSizes,
) -> Result<Self::Config, MatmulSetupError>;
/// Returns the [`NumStages`] value associated with this family.
// NOTE(review): presumably the number of stages per input — confirm against implementors.
fn num_stages() -> NumStages;
/// Computes the [`CubeDimResource`] for the given blueprint, element types and
/// vector sizes.
///
/// # Errors
///
/// Returns a [`MatmulSetupError`] when the resource cannot be determined.
fn cubedim_resource(
blueprint: &TilingBlueprint,
dtypes: &MatmulElems,
vector_sizes: &MatmulVectorSizes,
) -> Result<CubeDimResource, MatmulSetupError>;
/// Validates the blueprint against the client, the problem, the element types
/// and the vector sizes.
///
/// # Errors
///
/// Returns a [`MatmulSetupError`] when the blueprint is rejected.
fn validate_blueprint<R: Runtime>(
client: &ComputeClient<R>,
blueprint: &TilingBlueprint,
problem: &MatmulProblem,
dtypes: &MatmulElems,
vector_sizes: &MatmulVectorSizes,
) -> Result<(), MatmulSetupError>;
}
/// A global-level matmul: readers for the lhs, rhs and accumulator inputs, a
/// writer for the output, and an `execute` entry point that consumes them.
///
/// Every method receives the implementation's `Config` as a `#[comptime]`
/// argument.
#[cube]
pub trait GlobalMatmul<RC: RuntimeConfig, MP: MatmulTypes>: 'static + Send + Sync {
/// Compile-time configuration of this matmul.
type Config: GlobalConfig;
/// Reader built from the lhs view by `init_lhs_global_reader`.
type LhsGlobalReader: CubeType;
/// Reader built from the rhs view by `init_rhs_global_reader`.
type RhsGlobalReader: CubeType;
/// Reader built from the optional accumulator view by `init_acc_global_reader`.
type AccGlobalReader: CubeType;
/// Writer built from the output view by `init_global_writer`.
type GlobalWriter: CubeType;
/// Accumulators created by `init_accumulators`.
type Accumulators: CubeType;
/// Runs the matmul using the given readers and writer.
// NOTE(review): `k_range` is presumably the (start, end) range covered along
// the k dimension — confirm against implementors.
fn execute(
lhs_reader: Self::LhsGlobalReader,
rhs_reader: Self::RhsGlobalReader,
acc_reader: Self::AccGlobalReader,
writer: Self::GlobalWriter,
k_range: (u32, u32),
#[comptime] config: Self::Config,
);
/// Creates the lhs reader from a 2D view of the lhs input.
fn init_lhs_global_reader(
lhs: View<LhsG<MP>, Coords2d>,
runtime_config: RC,
#[comptime] config: Self::Config,
) -> Self::LhsGlobalReader;
/// Creates the rhs reader from a 2D view of the rhs input.
fn init_rhs_global_reader(
rhs: View<RhsG<MP>, Coords2d>,
runtime_config: RC,
#[comptime] config: Self::Config,
) -> Self::RhsGlobalReader;
/// Creates the accumulator reader from an optional 2D view; whether the view
/// is present is known at compile time (`ComptimeOption`).
fn init_acc_global_reader(
acc: ComptimeOption<View<AccG<MP>, Coords2d>>,
runtime_config: RC,
#[comptime] config: Self::Config,
) -> Self::AccGlobalReader;
/// Creates the accumulators from the configuration alone.
fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
/// Creates the writer from a read-write 2D view of the output.
fn init_global_writer(
out: View<AccG<MP>, Coords2d, ReadWrite>,
#[comptime] config: Self::Config,
) -> Self::GlobalWriter;
}
/// Global matmul configuration that wraps a stage configuration `S` together
/// with the reader and writer configurations.
///
/// Implements [`GlobalConfig`] for any `S: StageConfig`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SharedGlobalMatmulConfig<S: StageConfig> {
/// The wrapped stage-level configuration.
pub stage_config: S,
/// Number of planes; used as the second dimension of the cube dim.
pub num_planes: u32,
/// Reader configuration for the lhs input.
pub lhs_reader_config: GlobalReaderConfig,
/// Reader configuration for the rhs input.
pub rhs_reader_config: GlobalReaderConfig,
/// Reader configuration for the accumulator input.
pub acc_reader_config: GlobalReaderConfig,
/// Configuration of the output writer.
pub writer_config: GlobalWriterConfig,
// NOTE(review): presumably whether planes must be synchronized once execution
// finishes — confirm against the matmul implementations that read it.
pub must_sync_plane_after_execution: bool,
}
impl<S: StageConfig> SharedGlobalMatmulConfig<S> {
    /// Returns whether bounds must be checked along the `k` dimension.
    ///
    /// The flag is read from the lhs global memory config (`check_col_bounds`)
    /// and from the rhs global memory config (`check_row_bounds`); both are
    /// required to carry the same value.
    ///
    /// # Panics
    ///
    /// Panics if the lhs and rhs flags disagree, which indicates an
    /// inconsistently built configuration.
    pub fn check_k_bounds(&self) -> bool {
        let from_lhs = self.lhs_reader_config.gmem_config.check_col_bounds;
        let from_rhs = self.rhs_reader_config.gmem_config.check_row_bounds;
        // Name the violated invariant so a broken config is easy to diagnose.
        assert!(
            from_lhs == from_rhs,
            "inconsistent k bounds check: lhs `check_col_bounds` and rhs `check_row_bounds` must be equal"
        );
        from_lhs
    }

    /// Returns the plane dimension reported by the stage configuration.
    pub fn plane_dim(&self) -> u32 {
        self.stage_config.plane_dim()
    }

    /// Returns the plane flow configuration reported by the stage configuration.
    pub fn plane_flow_config(&self) -> PlaneFlowConfig {
        self.stage_config.plane_flow_config()
    }

    /// Converts the lhs and rhs input load flows into the corresponding
    /// [`SpecializedLoadingSides`].
    pub fn specialized_loading_sides(&self) -> SpecializedLoadingSides {
        LoadFlows {
            lhs: self.lhs_reader_config.input_load_flow,
            rhs: self.rhs_reader_config.input_load_flow,
        }
        .into()
    }
}
impl<S: StageConfig> GlobalConfig for SharedGlobalMatmulConfig<S> {
    type StageConfig = S;

    /// Copy of the wrapped stage-level configuration.
    fn stage_config(&self) -> Self::StageConfig {
        let Self { stage_config, .. } = self;
        *stage_config
    }

    /// Copy of the lhs reader configuration.
    fn lhs_reader_config(&self) -> GlobalReaderConfig {
        let Self {
            lhs_reader_config, ..
        } = self;
        *lhs_reader_config
    }

    /// Copy of the rhs reader configuration.
    fn rhs_reader_config(&self) -> GlobalReaderConfig {
        let Self {
            rhs_reader_config, ..
        } = self;
        *rhs_reader_config
    }

    /// Copy of the output writer configuration.
    fn writer_config(&self) -> GlobalWriterConfig {
        let Self { writer_config, .. } = self;
        *writer_config
    }

    /// Two-dimensional cube dim: the plane dimension first, the number of
    /// planes second.
    fn cube_dim(&self) -> CubeDim {
        let first_extent = self.plane_dim();
        let second_extent = self.num_planes;
        CubeDim::new_2d(first_extent, second_extent)
    }

    /// Vector sizes gathered from the global memory configs of the lhs reader,
    /// the rhs reader and the writer.
    fn global_vector_sizes(&self) -> MatmulVectorSizes {
        let lhs = self.lhs_reader_config.gmem_config.vector_size;
        let rhs = self.rhs_reader_config.gmem_config.vector_size;
        let out = self.writer_config.gmem_config.vector_size;
        MatmulVectorSizes { lhs, rhs, out }
    }

    /// Value of the `must_sync_plane_after_execution` flag.
    fn must_sync_plane_after_execution(&self) -> bool {
        let Self {
            must_sync_plane_after_execution: flag,
            ..
        } = self;
        *flag
    }
}
/// Configuration of a global matmul.
///
/// Implementors must be cheap value types: the trait requires `Copy`, equality,
/// hashing and `Debug`, as well as `Send + Sync + 'static`.
pub trait GlobalConfig:
Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
/// The stage-level configuration type nested in this configuration.
type StageConfig: StageConfig;
/// Returns the nested stage-level configuration.
fn stage_config(&self) -> Self::StageConfig;
/// Returns the reader configuration for the lhs input.
fn lhs_reader_config(&self) -> GlobalReaderConfig;
/// Returns the reader configuration for the rhs input.
fn rhs_reader_config(&self) -> GlobalReaderConfig;
/// Returns the configuration of the output writer.
fn writer_config(&self) -> GlobalWriterConfig;
/// Returns the cube dim associated with this configuration.
fn cube_dim(&self) -> CubeDim;
/// Returns the vector sizes used for the lhs, rhs and output.
fn global_vector_sizes(&self) -> MatmulVectorSizes;
// NOTE(review): presumably whether planes must be synchronized once execution
// finishes — confirm against the callers of this method.
fn must_sync_plane_after_execution(&self) -> bool;
}
/// Configuration of a single global reader (one per input).
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct GlobalReaderConfig {
/// Configuration of the global memory the reader reads from.
pub gmem_config: GlobalMemoryConfig,
/// Configuration of the stage memory; its `num_planes` is reported by
/// `loading_planes_count`.
pub smem_config: StageMemoryConfig,
// NOTE(review): presumably whether the loading job is computed ahead of
// time — confirm against the readers that use this flag.
pub precompute_job: bool,
/// Plane dimension; multiplied by the number of loading planes to obtain
/// the number of loading units.
pub plane_dim: u32,
/// The reader mode.
pub reader_mode: ReaderMode,
/// The event loading mode.
pub event_loading_mode: EventLoadingMode,
/// The load flow of this input.
pub input_load_flow: InputLoadFlow,
/// The plane flow configuration.
pub plane_flow_config: PlaneFlowConfig,
/// Identifies which stage this reader is associated with.
pub stage_ident: StageIdent,
}
impl GlobalReaderConfig {
    /// Number of planes taking part in loading, as recorded in the stage
    /// memory configuration.
    pub fn loading_planes_count(&self) -> u32 {
        let Self { smem_config, .. } = self;
        smem_config.num_planes
    }

    /// Number of units taking part in loading: the plane dimension multiplied
    /// by the number of loading planes.
    pub fn loading_units_count(&self) -> u32 {
        let plane_count = self.loading_planes_count();
        self.plane_dim * plane_count
    }
}