cubecl_matmul/kernels/layered/algorithm/
simple_tma.rs

1use core::marker::PhantomData;
2
3use cubecl_core::{Runtime, client::ComputeClient};
4
5use crate::{
6    components::{
7        MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8        batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9        global::{PlaneWriterFamily, single_stage::tma::SimpleTmaMatmulFamily},
10        stage::{FilledStageFamily, PlaneMatmulFamily, StridedStageFamily},
11        tile::{
12            TileMatmulFamily,
13            io::{Filled, Strided},
14        },
15    },
16    kernels::layered::{Algorithm, selector::plane_matmul_selection},
17};
18
/// Plane-accelerated, single-stage matmul algorithm that loads its inputs with TMA.
///
/// This type carries no runtime data. It only selects, at the type level, the
/// tile matmul family `TMM` that the algorithm is built on.
pub struct SimpleTmaAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the struct.
    pub _tmm: PhantomData<TMM>,
}
23
24impl<TMM> Algorithm for SimpleTmaAlgorithm<TMM>
25where
26    TMM:
27        TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = Filled, OutTile = Strided>,
28{
29    type SelectionArgs = ();
30    type TileMatmul = TMM;
31    type StageMatmul = PlaneMatmulFamily<
32        Self::TileMatmul,
33        StridedStageFamily,
34        StridedStageFamily,
35        FilledStageFamily,
36    >;
37    type GlobalMatmul = SimpleTmaMatmulFamily<Self::StageMatmul, PlaneWriterFamily>;
38    type BatchMatmul =
39        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
40
41    fn selection<R: Runtime>(
42        client: &ComputeClient<R::Server>,
43        problem: &MatmulProblem,
44        plane_dim: u32,
45        _line_sizes: &MatmulLineSizes,
46        elems: MatmulElems,
47        _args: &Self::SelectionArgs,
48    ) -> Result<MatmulSelection, MatmulSetupError> {
49        plane_matmul_selection::<TMM, R>(client, problem, plane_dim, elems, Default::default())
50    }
51}