// cubecl_matmul/kernels/layered/algorithm/double_buffering.rs
use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5use cubecl_core::ir::Elem;
6
7use crate::components::MatmulSelection;
8use crate::components::batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul};
9use crate::components::global::load::sync_partial_cyclic::SyncPartialCyclicLoading;
10use crate::components::global::load::sync_partial_tilewise::SyncPartialTilewiseLoading;
11use crate::components::global::multi_stage::double_buffering::DoubleBufferingMatmulFamily;
12use crate::components::stage::{
13 ColMajorTilingOrder, PartialReaderFamily, PlaneMatmulFamily, RowMajorTilingOrder,
14};
15use crate::components::{MatmulProblem, MultiRowStrategy, tile};
16use crate::kernels::layered::Algorithm;
17use crate::kernels::layered::algorithm::base;
18use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
19
/// Marker type for the double-buffering matmul algorithm whose two loader
/// slots are both filled with `SyncPartialCyclicLoading` (see its `Algorithm`
/// impl in this file).
///
/// `TMM` is never stored; the `Algorithm` impl requires it to implement
/// `tile::TileMatmulFamily`.
pub struct CyclicDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the type.
    pub _phantom: PhantomData<TMM>,
}
24
/// Marker type for the double-buffering matmul algorithm whose two loader
/// slots are both filled with `SyncPartialTilewiseLoading` (see its
/// `Algorithm` impl in this file).
///
/// `TMM` is never stored; the `Algorithm` impl requires it to implement
/// `tile::TileMatmulFamily`.
pub struct TilewiseDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the type.
    pub _phantom: PhantomData<TMM>,
}
29
/// Marker type for the double-buffering matmul algorithm that mixes loaders:
/// its `Algorithm` impl in this file puts `SyncPartialTilewiseLoading` in the
/// first loader slot and `SyncPartialCyclicLoading` in the second.
///
/// `TMM` is never stored; the `Algorithm` impl requires it to implement
/// `tile::TileMatmulFamily`.
pub struct HybridDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the type.
    pub _phantom: PhantomData<TMM>,
}
34
/// Selection arguments used as `SelectionArgs` by every double-buffering
/// algorithm in this file.
#[derive(Clone, Copy, Debug, Default)]
pub struct DoubleBufferingArgs {
    /// Copied as-is into `PlaneMatmulSelectionOptions::specialized` by the
    /// `selection` functions below. The derived `Default` sets it to `false`.
    pub specialized: bool,
}
39
40impl<TMM> base::Algorithm for CyclicDoubleBufferingAlgorithm<TMM>
41where
42 TMM: tile::TileMatmulFamily,
43{
44 type SelectionArgs = DoubleBufferingArgs;
45 type TileMatmul = TMM;
46 type StageMatmul =
47 PlaneMatmulFamily<Self::TileMatmul, PartialReaderFamily, PartialReaderFamily>;
48 type GlobalMatmul = DoubleBufferingMatmulFamily<
49 Self::StageMatmul,
50 SyncPartialCyclicLoading<RowMajorTilingOrder>,
51 SyncPartialCyclicLoading<RowMajorTilingOrder>,
52 >;
53 type BatchMatmul =
54 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
55
56 fn selection<R: Runtime>(
57 client: &ComputeClient<R::Server, R::Channel>,
58 problem: &MatmulProblem,
59 plane_dim: u32,
60 elem_stage: Elem,
61 elem_acc: Elem,
62 args: &Self::SelectionArgs,
63 ) -> MatmulSelection {
64 plane_matmul_selection::<TMM, R>(
65 client,
66 problem,
67 plane_dim,
68 elem_stage,
69 elem_acc,
70 PlaneMatmulSelectionOptions {
71 specialized: args.specialized,
72 multi_row_strategy: MultiRowStrategy::Adaptive {
73 minimum_stage_count: 8,
74 },
75 ..Default::default()
76 },
77 )
78 }
79}
80
81impl<TMM> Algorithm for TilewiseDoubleBufferingAlgorithm<TMM>
82where
83 TMM: tile::TileMatmulFamily,
84{
85 type SelectionArgs = DoubleBufferingArgs;
86 type TileMatmul = TMM;
87 type StageMatmul =
88 PlaneMatmulFamily<Self::TileMatmul, PartialReaderFamily, PartialReaderFamily>;
89 type GlobalMatmul = DoubleBufferingMatmulFamily<
90 Self::StageMatmul,
91 SyncPartialTilewiseLoading<RowMajorTilingOrder>,
93 SyncPartialTilewiseLoading<ColMajorTilingOrder>,
94 >;
95 type BatchMatmul =
96 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
97
98 fn selection<R: Runtime>(
99 client: &ComputeClient<R::Server, R::Channel>,
100 problem: &MatmulProblem,
101 plane_dim: u32,
102 elem_stage: Elem,
103 elem_acc: Elem,
104 args: &Self::SelectionArgs,
105 ) -> MatmulSelection {
106 plane_matmul_selection::<TMM, R>(
107 client,
108 problem,
109 plane_dim,
110 elem_stage,
111 elem_acc,
112 PlaneMatmulSelectionOptions {
113 specialized: args.specialized,
114 multi_row_strategy: MultiRowStrategy::Adaptive {
115 minimum_stage_count: 8,
116 },
117 ..Default::default()
118 },
119 )
120 }
121}
122
123impl<TMM> base::Algorithm for HybridDoubleBufferingAlgorithm<TMM>
124where
125 TMM: tile::TileMatmulFamily,
126{
127 type SelectionArgs = DoubleBufferingArgs;
128 type TileMatmul = TMM;
129 type StageMatmul =
130 PlaneMatmulFamily<Self::TileMatmul, PartialReaderFamily, PartialReaderFamily>;
131 type GlobalMatmul = DoubleBufferingMatmulFamily<
132 Self::StageMatmul,
133 SyncPartialTilewiseLoading<RowMajorTilingOrder>,
134 SyncPartialCyclicLoading<RowMajorTilingOrder>,
135 >;
136 type BatchMatmul =
137 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
138
139 fn selection<R: Runtime>(
140 client: &ComputeClient<R::Server, R::Channel>,
141 problem: &MatmulProblem,
142 plane_dim: u32,
143 elem_stage: Elem,
144 elem_acc: Elem,
145 args: &Self::SelectionArgs,
146 ) -> MatmulSelection {
147 plane_matmul_selection::<TMM, R>(
148 client,
149 problem,
150 plane_dim,
151 elem_stage,
152 elem_acc,
153 PlaneMatmulSelectionOptions {
154 specialized: args.specialized,
155 multi_row_strategy: MultiRowStrategy::Adaptive {
156 minimum_stage_count: 8,
157 },
158 ..Default::default()
159 },
160 )
161 }
162}