cubecl_matmul/kernels/layered/algorithm/
simple_barrier.rs

1use cubecl_core::{Runtime, client::ComputeClient};
2
3use std::marker::PhantomData;
4
5use crate::{
6    components::{
7        MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8        batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9        global::{
10            PlaneWriterFamily, read::AsyncFullLoadingStrategy,
11            single_stage::barrier::SimpleBarrierMatmulFamily,
12        },
13        stage::{FilledStageFamily, PlaneMatmulFamily, StridedStageFamily},
14        tile::{
15            self,
16            io::{Filled, Strided},
17        },
18    },
19    kernels::layered::{Algorithm, selector::plane_matmul_selection},
20};
21
22/// Plane accelerated single stage matmul with async barrier loading
23pub struct SimpleBarrierAlgorithm<TMM, L: AsyncFullLoadingStrategy> {
24    pub _tmm: PhantomData<TMM>,
25    pub _l: PhantomData<L>,
26}
27
28impl<TMM, L> Algorithm for SimpleBarrierAlgorithm<TMM, L>
29where
30    TMM: tile::TileMatmulFamily<
31            LhsTile = Strided,
32            RhsTile = Strided,
33            AccTile = Filled,
34            OutTile = Strided,
35        >,
36    L: AsyncFullLoadingStrategy,
37{
38    type SelectionArgs = ();
39    type TileMatmul = TMM;
40    type StageMatmul = PlaneMatmulFamily<
41        Self::TileMatmul,
42        StridedStageFamily,
43        StridedStageFamily,
44        FilledStageFamily,
45    >;
46    type GlobalMatmul = SimpleBarrierMatmulFamily<Self::StageMatmul, L, L, PlaneWriterFamily>;
47
48    type BatchMatmul =
49        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
50
51    fn selection<R: Runtime>(
52        client: &ComputeClient<R::Server>,
53        problem: &MatmulProblem,
54        plane_dim: u32,
55        _line_sizes: &MatmulLineSizes,
56        elems: MatmulElems,
57        _args: &Self::SelectionArgs,
58    ) -> Result<MatmulSelection, MatmulSetupError> {
59        plane_matmul_selection::<TMM, R>(client, problem, plane_dim, elems, Default::default())
60    }
61}