cubecl_matmul/kernels/layered/algorithm/
simple_tma.rs1use core::marker::PhantomData;
2
3use cubecl_core::{Runtime, client::ComputeClient};
4
5use crate::{
6 components::{
7 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9 global::{PlaneWriterFamily, single_stage::tma::SimpleTmaMatmulFamily},
10 stage::{FilledStageFamily, PlaneMatmulFamily, StridedStageFamily},
11 tile::{
12 TileMatmulFamily,
13 io::{Filled, Strided},
14 },
15 },
16 kernels::layered::{Algorithm, selector::plane_matmul_selection},
17};
18
/// Layered matmul algorithm pairing the tile matmul family `TMM` with the
/// single-stage TMA global matmul (see its `Algorithm` implementation).
///
/// The struct holds no data; `TMM` is only carried at the type level.
pub struct SimpleTmaAlgorithm<TMM> {
    /// Zero-sized marker that binds the `TMM` type parameter.
    pub _tmm: PhantomData<TMM>,
}
23
24impl<TMM> Algorithm for SimpleTmaAlgorithm<TMM>
25where
26 TMM:
27 TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = Filled, OutTile = Strided>,
28{
29 type SelectionArgs = ();
30 type TileMatmul = TMM;
31 type StageMatmul = PlaneMatmulFamily<
32 Self::TileMatmul,
33 StridedStageFamily,
34 StridedStageFamily,
35 FilledStageFamily,
36 >;
37 type GlobalMatmul = SimpleTmaMatmulFamily<Self::StageMatmul, PlaneWriterFamily>;
38 type BatchMatmul =
39 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
40
41 fn selection<R: Runtime>(
42 client: &ComputeClient<R::Server>,
43 problem: &MatmulProblem,
44 plane_dim: u32,
45 _line_sizes: &MatmulLineSizes,
46 elems: MatmulElems,
47 _args: &Self::SelectionArgs,
48 ) -> Result<MatmulSelection, MatmulSetupError> {
49 plane_matmul_selection::<TMM, R>(client, problem, plane_dim, elems, Default::default())
50 }
51}