use std::fmt::Display;
use cubecl::{Runtime, client::ComputeClient, prelude::TensorBinding};
use cubek_std::InputBinding;
use crate::{
components::{
global::read::{
async_full_cooperative::AsyncFullCooperativeLoading,
async_full_cyclic::AsyncFullCyclicLoading,
},
stage::ColMajorTilingOrder,
tile::TileMatmulKind,
},
definition::{MatmulElems, MatmulSetupError},
launch::launch_tiling,
routines::{
BlueprintStrategy, Routine, TilingArgs, interleaved::InterleavedAlgorithm,
simple::SimpleBarrierAlgorithm,
},
};
#[allow(clippy::type_complexity)]
#[derive(Clone)]
pub enum TestStrategy {
SimpleBarrierCooperativeCmma(
BlueprintStrategy<(), SimpleBarrierAlgorithm<AsyncFullCooperativeLoading>>,
),
SimpleBarrierCooperativeMma(
BlueprintStrategy<(), SimpleBarrierAlgorithm<AsyncFullCooperativeLoading>>,
),
SimpleBarrierCyclicCmma(
BlueprintStrategy<(), SimpleBarrierAlgorithm<AsyncFullCyclicLoading<ColMajorTilingOrder>>>,
),
SimpleBarrierCyclicMma(
BlueprintStrategy<(), SimpleBarrierAlgorithm<AsyncFullCyclicLoading<ColMajorTilingOrder>>>,
),
Interleaved(BlueprintStrategy<(), InterleavedAlgorithm>),
}
impl Display for TestStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SimpleBarrierCooperativeCmma(s) => {
write!(f, "matmul_simple_barrier_cooperative_cmma{}", s)
}
Self::SimpleBarrierCooperativeMma(s) => {
write!(f, "matmul_simple_barrier_cooperative_mma{}", s)
}
Self::SimpleBarrierCyclicCmma(s) => {
write!(f, "matmul_simple_barrier_cyclic_cmma{}", s)
}
Self::SimpleBarrierCyclicMma(s) => {
write!(f, "matmul_simple_barrier_cyclic_mma{}", s)
}
Self::Interleaved(s) => write!(f, "matmul_interleaved{}", s),
}
}
}
fn with_kind<RC, A>(
sel: &BlueprintStrategy<RC, A>,
kind: TileMatmulKind,
) -> BlueprintStrategy<RC, A>
where
RC: crate::launch::RuntimeConfig,
A: Routine<RC>,
A::Strategy: TilingArgs,
{
let mut sel = sel.clone();
if let BlueprintStrategy::Inferred(args) = &mut sel {
args.set_tile_matmul(kind);
}
sel
}
#[allow(clippy::result_large_err)]
impl TestStrategy {
pub fn launch_ref<R: Runtime>(
&self,
client: &ComputeClient<R>,
lhs: InputBinding<R>,
rhs: InputBinding<R>,
out: TensorBinding<R>,
dtypes: &mut MatmulElems,
) -> Result<(), MatmulSetupError> {
use TileMatmulKind::{Cmma, Mma};
match self {
Self::SimpleBarrierCooperativeCmma(sel) => {
launch_tiling::launch_ref(client, lhs, rhs, out, &with_kind(sel, Cmma), dtypes)
}
Self::SimpleBarrierCooperativeMma(sel) => {
launch_tiling::launch_ref(client, lhs, rhs, out, &with_kind(sel, Mma), dtypes)
}
Self::SimpleBarrierCyclicCmma(sel) => {
launch_tiling::launch_ref(client, lhs, rhs, out, &with_kind(sel, Cmma), dtypes)
}
Self::SimpleBarrierCyclicMma(sel) => {
launch_tiling::launch_ref(client, lhs, rhs, out, &with_kind(sel, Mma), dtypes)
}
Self::Interleaved(sel) => launch_tiling::launch_ref(client, lhs, rhs, out, sel, dtypes),
}
}
}