cubecl_matmul/kernels/layered/algorithm/
ordered_double_buffering.rs1use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5
6use crate::components::stage::{PlaneMatmulFamily, RowMajorTilingOrder};
7use crate::components::{
8 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
9 global::PlaneWriterFamily,
10};
11use crate::components::{MultiRowStrategy, tile};
12use crate::components::{
13 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
14 stage::{FilledStageFamily, StridedStageFamily},
15};
16use crate::components::{
17 global::multi_stage::ordered::OrderedDoubleBufferingMatmulFamily, tile::io::Filled,
18};
19use crate::components::{
20 global::read::sync_partial_cyclic::SyncPartialCyclicLoading, tile::io::Strided,
21};
22use crate::kernels::layered::Algorithm;
23use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
24
/// Layered matmul algorithm that pairs a plane-level stage matmul with an ordered,
/// double-buffering global matmul.
///
/// The type holds no runtime data. `TMM` only picks the tile matmul family used at the
/// innermost level; the full stage/global/batch composition lives in its [`Algorithm`] impl.
pub struct OrderedDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker binding the algorithm to its tile matmul family `TMM`.
    pub _phantom: PhantomData<TMM>,
}
29
/// Optional user overrides consumed by `OrderedDoubleBufferingAlgorithm::selection`.
///
/// `Default` leaves every override unset (`None`).
#[derive(Clone, Debug, Default)]
pub struct OrderedSelectionArgs {
    /// Passed through unchanged as the `partition_k` option of the plane matmul selector.
    pub partition_k: Option<u32>,
    /// Passed through unchanged as the `row_count` option of the plane matmul selector.
    pub row_count: Option<u32>,
    /// `Some(n)` selects `MultiRowStrategy::Always(n)`; `None` selects
    /// `MultiRowStrategy::Adaptive` with `minimum_stage_count: 8`.
    pub rows_per_plane: Option<u32>,
}
36
37impl<TMM> Algorithm for OrderedDoubleBufferingAlgorithm<TMM>
38where
39 TMM: tile::TileMatmulFamily<
40 LhsTile = Strided,
41 RhsTile = Strided,
42 AccTile = Filled,
43 OutTile = Strided,
44 >,
45{
46 type SelectionArgs = OrderedSelectionArgs;
47 type TileMatmul = TMM;
48 type StageMatmul = PlaneMatmulFamily<
49 Self::TileMatmul,
50 StridedStageFamily,
51 StridedStageFamily,
52 FilledStageFamily,
53 >;
54 type GlobalMatmul = OrderedDoubleBufferingMatmulFamily<
55 Self::StageMatmul,
56 SyncPartialCyclicLoading<RowMajorTilingOrder>,
57 PlaneWriterFamily,
58 >;
59 type BatchMatmul =
60 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
61
62 fn selection<R: Runtime>(
63 client: &ComputeClient<R::Server>,
64 problem: &MatmulProblem,
65 plane_dim: u32,
66 _line_sizes: &MatmulLineSizes,
67 elems: MatmulElems,
68 args: &Self::SelectionArgs,
69 ) -> Result<MatmulSelection, MatmulSetupError> {
70 plane_matmul_selection::<TMM, R>(
71 client,
72 problem,
73 plane_dim,
74 elems,
75 PlaneMatmulSelectionOptions {
76 partition_k: args.partition_k,
77 row_count: args.row_count,
78 multi_row_strategy: args
79 .rows_per_plane
80 .map(MultiRowStrategy::Always)
81 .unwrap_or_else(|| MultiRowStrategy::Adaptive {
82 minimum_stage_count: 8,
83 }),
84 ..Default::default()
85 },
86 )
87 }
88}