use crate::{
components::stage::NumStages,
definition::{
AccG, Blueprint, LhsG, MatmulElems, MatmulProblem, MatmulSetupError, MatmulTypes,
MatmulVectorSizes, RhsG,
},
};
use crate::{
definition::{CubeMapping, CubeMappingLaunch},
launch::{InputRuntimeArg, MatmulArgs, OutputRuntimeArg},
{components::CubeDimResource, launch::RuntimeConfig},
{components::global::memory::GlobalLayoutConfig, launch::ConfigRuntimeArg},
};
use cubecl::{ir::DeviceProperties, prelude::*};
use std::{fmt::Debug, hash::Hash};
/// Family-level description of a batch matmul for a runtime configuration `RC`.
///
/// An implementor ties together three associated types: the concrete
/// [`BatchMatmul`] (generic over the matmul element types), the
/// [`BatchConfig`] it is executed with, and the blueprint from which that
/// configuration is expanded.
pub trait BatchMatmulFamily<RC: RuntimeConfig>: Send + Sync + 'static {
    /// Concrete batch matmul for the element types `MP`; it must use
    /// [`Self::Config`] as its configuration.
    type Matmul<MP: MatmulTypes>: BatchMatmul<RC, MP, Config = Self::Config>;

    /// Configuration type shared by every [`Self::Matmul`] of this family.
    type Config: BatchConfig;

    /// Blueprint type consumed by the setup, validation and launch functions.
    type Blueprint: Blueprint;

    /// Expands `blueprint` into a [`Self::Config`].
    ///
    /// # Errors
    ///
    /// Returns a [`MatmulSetupError`] when no configuration can be produced
    /// from the given device properties, blueprint, element types and vector
    /// sizes.
    fn expand_config(
        device_props: &DeviceProperties,
        blueprint: &Self::Blueprint,
        dtypes: &MatmulElems,
        vector_sizes: &MatmulVectorSizes,
    ) -> Result<Self::Config, MatmulSetupError>;

    /// Returns the [`NumStages`] value of this family.
    fn num_stages() -> NumStages;

    /// Launches the matmul kernel on `client` with the given cube dimensions,
    /// cube count, runtime arguments and blueprint.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact preconditions are not visible from this trait
    /// declaration. The `_unchecked` suffix suggests the launch itself performs
    /// no validation, so callers presumably have to guarantee that the launch
    /// parameters are valid for the problem being solved — confirm against the
    /// implementors before relying on this.
    ///
    /// # Errors
    ///
    /// Returns a `LaunchError` when the launch fails.
    // The launch needs eleven distinct inputs, which exceeds clippy's default
    // argument-count threshold.
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<MA: MatmulArgs<Config = RC>, R: Runtime>(
        client: &ComputeClient<R>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        address_type: AddressType,
        input: InputRuntimeArg<MA, R>,
        output: OutputRuntimeArg<MA, R>,
        config: ConfigRuntimeArg<MA, R>,
        cube_mapping: CubeMappingLaunch<R>,
        blueprint: Self::Blueprint,
        dtypes: &MatmulElems,
        vector_sizes: &MatmulVectorSizes,
    ) -> Result<(), LaunchError>;

    /// Computes the [`CubeDimResource`] associated with `blueprint`.
    ///
    /// # Errors
    ///
    /// Returns a [`MatmulSetupError`] when the resource cannot be determined
    /// for the given blueprint, element types and vector sizes.
    fn cubedim_resource(
        blueprint: &Self::Blueprint,
        dtypes: &MatmulElems,
        vector_sizes: &MatmulVectorSizes,
    ) -> Result<CubeDimResource, MatmulSetupError>;

    /// Checks `blueprint` against `problem` for the given client, element
    /// types and vector sizes.
    ///
    /// # Errors
    ///
    /// Returns a [`MatmulSetupError`] when the blueprint is rejected.
    fn validate_blueprint<R: Runtime>(
        client: &ComputeClient<R>,
        blueprint: &Self::Blueprint,
        problem: &MatmulProblem,
        dtypes: &MatmulElems,
        vector_sizes: &MatmulVectorSizes,
    ) -> Result<(), MatmulSetupError>;
}
#[cube]
/// Batch-level matmul, generic over the runtime configuration `RC` and the
/// matmul element types `MP`.
pub trait BatchMatmul<RC: RuntimeConfig, MP: MatmulTypes>: Send + Sync + 'static {
    /// Configuration consumed by [`BatchMatmul::execute`].
    type Config: BatchConfig;

    /// Runs the matmul.
    ///
    /// - `state`: mutable argument state defined by `Args`, instantiated with
    ///   the global lhs, rhs and accumulator types of `MP`.
    /// - `cube_mapping`: the [`CubeMapping`] handed to the kernel.
    /// - `config`: the configuration, passed as a comptime value.
    fn execute<Args: MatmulArgs<Config = RC>>(
        state: &mut Args::State<LhsG<MP>, RhsG<MP>, AccG<MP>>,
        cube_mapping: CubeMapping,
        #[comptime] config: Self::Config,
    );
}
/// Configuration of a batch-level matmul.
///
/// Implementors must be plain, thread-safe values: copyable, comparable,
/// hashable and debuggable, with no borrowed data.
pub trait BatchConfig:
    'static + Send + Sync + Copy + Clone + PartialEq + Eq + Hash + Debug
{
    /// Returns the [`GlobalLayoutConfig`] for the `lhs` (left-hand side) operand.
    fn lhs_global_layout_config(&self) -> GlobalLayoutConfig;

    /// Returns the [`GlobalLayoutConfig`] for the `rhs` (right-hand side) operand.
    fn rhs_global_layout_config(&self) -> GlobalLayoutConfig;

    /// Returns the [`GlobalLayoutConfig`] for the output.
    fn out_global_layout_config(&self) -> GlobalLayoutConfig;
}
/// Selects a bounds-checking mode.
///
/// NOTE(review): the meaning of each variant is not established in this part
/// of the file; the per-variant notes below are inferred from the names and
/// should be confirmed at the use sites.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CheckBounds {
    /// Presumably: no bounds checks are performed.
    None,
    /// Presumably: accesses are bounds-checked.
    Checked,
    /// Presumably: out-of-bounds work terminates instead of being checked per
    /// access — TODO confirm.
    Terminate,
}