// cubecl_matmul/kernels/layered/algorithm/simple_barrier.rs
use cubecl_core::{Runtime, client::ComputeClient};
2
3use std::marker::PhantomData;
4
5use crate::{
6 components::{
7 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9 global::{
10 PlaneWriterFamily, read::AsyncFullLoadingStrategy,
11 single_stage::barrier::SimpleBarrierMatmulFamily,
12 },
13 stage::{FilledStageFamily, PlaneMatmulFamily, StridedStageFamily},
14 tile::{
15 self,
16 io::{Filled, Strided},
17 },
18 },
19 kernels::layered::{Algorithm, selector::plane_matmul_selection},
20};
21
22pub struct SimpleBarrierAlgorithm<TMM, L: AsyncFullLoadingStrategy> {
24 pub _tmm: PhantomData<TMM>,
25 pub _l: PhantomData<L>,
26}
27
28impl<TMM, L> Algorithm for SimpleBarrierAlgorithm<TMM, L>
29where
30 TMM: tile::TileMatmulFamily<
31 LhsTile = Strided,
32 RhsTile = Strided,
33 AccTile = Filled,
34 OutTile = Strided,
35 >,
36 L: AsyncFullLoadingStrategy,
37{
38 type SelectionArgs = ();
39 type TileMatmul = TMM;
40 type StageMatmul = PlaneMatmulFamily<
41 Self::TileMatmul,
42 StridedStageFamily,
43 StridedStageFamily,
44 FilledStageFamily,
45 >;
46 type GlobalMatmul = SimpleBarrierMatmulFamily<Self::StageMatmul, L, L, PlaneWriterFamily>;
47
48 type BatchMatmul =
49 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
50
51 fn selection<R: Runtime>(
52 client: &ComputeClient<R::Server>,
53 problem: &MatmulProblem,
54 plane_dim: u32,
55 _line_sizes: &MatmulLineSizes,
56 elems: MatmulElems,
57 _args: &Self::SelectionArgs,
58 ) -> Result<MatmulSelection, MatmulSetupError> {
59 plane_matmul_selection::<TMM, R>(client, problem, plane_dim, elems, Default::default())
60 }
61}