// cubecl_matmul/kernels/layered/algorithm/double_buffering.rs
use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5
6use crate::components::global::read::sync_partial_cyclic::SyncPartialCyclicLoading;
7use crate::components::global::{
8 PlaneWriterFamily, read::sync_partial_tilewise::SyncPartialTilewiseLoading,
9};
10use crate::components::stage::{ColMajorTilingOrder, PlaneMatmulFamily, RowMajorTilingOrder};
11use crate::components::{MatmulElems, MatmulLineSizes, MatmulSelection, MatmulSetupError};
12use crate::components::{MatmulProblem, MultiRowStrategy, tile};
13use crate::components::{
14 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
15 tile::io::{Filled, Strided},
16};
17use crate::components::{
18 global::multi_stage::double_buffering::DoubleBufferingMatmulFamily,
19 stage::{FilledStageFamily, StridedStageFamily},
20};
21use crate::kernels::layered::Algorithm;
22use crate::kernels::layered::algorithm::base;
23use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
24
/// Marker type for the double-buffering algorithm whose `GlobalMatmul` (see the
/// `Algorithm` impl in this file) passes `SyncPartialCyclicLoading` as both
/// loading type arguments.
///
/// The struct stores no data; `TMM` is only constrained by the impl, where it
/// must implement `tile::TileMatmulFamily`.
pub struct CyclicDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the type.
    pub _phantom: PhantomData<TMM>,
}
29
/// Marker type for the double-buffering algorithm whose `GlobalMatmul` (see the
/// `Algorithm` impl in this file) passes `SyncPartialTilewiseLoading` as both
/// loading type arguments.
///
/// The struct stores no data; `TMM` is only constrained by the impl, where it
/// must implement `tile::TileMatmulFamily`.
pub struct TilewiseDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the type.
    pub _phantom: PhantomData<TMM>,
}
34
/// Marker type for the double-buffering algorithm whose `GlobalMatmul` (see the
/// `Algorithm` impl in this file) mixes the two loaders:
/// `SyncPartialTilewiseLoading` as the first loading type argument and
/// `SyncPartialCyclicLoading` as the second.
///
/// The struct stores no data; `TMM` is only constrained by the impl, where it
/// must implement `tile::TileMatmulFamily`.
pub struct HybridDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker that ties the otherwise unused `TMM` parameter to the type.
    pub _phantom: PhantomData<TMM>,
}
39
/// Selection arguments used as `SelectionArgs` by every double-buffering
/// algorithm in this file.
///
/// The derived `Default` yields `specialized: false`.
#[derive(Default, Debug, Clone, Copy)]
pub struct DoubleBufferingArgs {
    /// Forwarded unchanged as `PlaneMatmulSelectionOptions::specialized` by the
    /// `selection` implementations in this file.
    pub specialized: bool,
}
44
45impl<TMM> base::Algorithm for CyclicDoubleBufferingAlgorithm<TMM>
46where
47 TMM: tile::TileMatmulFamily<
48 LhsTile = Strided,
49 RhsTile = Strided,
50 AccTile = Filled,
51 OutTile = Strided,
52 >,
53{
54 type SelectionArgs = DoubleBufferingArgs;
55 type TileMatmul = TMM;
56 type StageMatmul = PlaneMatmulFamily<
57 Self::TileMatmul,
58 StridedStageFamily,
59 StridedStageFamily,
60 FilledStageFamily,
61 >;
62 type GlobalMatmul = DoubleBufferingMatmulFamily<
63 Self::StageMatmul,
64 SyncPartialCyclicLoading<RowMajorTilingOrder>,
65 SyncPartialCyclicLoading<RowMajorTilingOrder>,
66 PlaneWriterFamily,
67 >;
68 type BatchMatmul =
69 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
70
71 fn selection<R: Runtime>(
72 client: &ComputeClient<R::Server>,
73 problem: &MatmulProblem,
74 plane_dim: u32,
75 _line_sizes: &MatmulLineSizes,
76 elems: MatmulElems,
77 args: &Self::SelectionArgs,
78 ) -> Result<MatmulSelection, MatmulSetupError> {
79 plane_matmul_selection::<TMM, R>(
80 client,
81 problem,
82 plane_dim,
83 elems,
84 PlaneMatmulSelectionOptions {
85 specialized: args.specialized,
86 multi_row_strategy: MultiRowStrategy::Adaptive {
87 minimum_stage_count: 8,
88 },
89 ..Default::default()
90 },
91 )
92 }
93}
94
95impl<TMM> Algorithm for TilewiseDoubleBufferingAlgorithm<TMM>
96where
97 TMM: tile::TileMatmulFamily<
98 LhsTile = Strided,
99 RhsTile = Strided,
100 AccTile = Filled,
101 OutTile = Strided,
102 >,
103{
104 type SelectionArgs = DoubleBufferingArgs;
105 type TileMatmul = TMM;
106 type StageMatmul = PlaneMatmulFamily<
107 Self::TileMatmul,
108 StridedStageFamily,
109 StridedStageFamily,
110 FilledStageFamily,
111 >;
112 type GlobalMatmul = DoubleBufferingMatmulFamily<
113 Self::StageMatmul,
114 SyncPartialTilewiseLoading<RowMajorTilingOrder>,
116 SyncPartialTilewiseLoading<ColMajorTilingOrder>,
117 PlaneWriterFamily,
118 >;
119 type BatchMatmul =
120 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
121
122 fn selection<R: Runtime>(
123 client: &ComputeClient<R::Server>,
124 problem: &MatmulProblem,
125 plane_dim: u32,
126 _line_sizes: &MatmulLineSizes,
127 elems: MatmulElems,
128 args: &Self::SelectionArgs,
129 ) -> Result<MatmulSelection, MatmulSetupError> {
130 plane_matmul_selection::<TMM, R>(
131 client,
132 problem,
133 plane_dim,
134 elems,
135 PlaneMatmulSelectionOptions {
136 specialized: args.specialized,
137 multi_row_strategy: MultiRowStrategy::Adaptive {
138 minimum_stage_count: 8,
139 },
140 ..Default::default()
141 },
142 )
143 }
144}
145
146impl<TMM> base::Algorithm for HybridDoubleBufferingAlgorithm<TMM>
147where
148 TMM: tile::TileMatmulFamily<
149 LhsTile = Strided,
150 RhsTile = Strided,
151 AccTile = Filled,
152 OutTile = Strided,
153 >,
154{
155 type SelectionArgs = DoubleBufferingArgs;
156 type TileMatmul = TMM;
157 type StageMatmul = PlaneMatmulFamily<
158 Self::TileMatmul,
159 StridedStageFamily,
160 StridedStageFamily,
161 FilledStageFamily,
162 >;
163 type GlobalMatmul = DoubleBufferingMatmulFamily<
164 Self::StageMatmul,
165 SyncPartialTilewiseLoading<RowMajorTilingOrder>,
166 SyncPartialCyclicLoading<RowMajorTilingOrder>,
167 PlaneWriterFamily,
168 >;
169 type BatchMatmul =
170 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
171
172 fn selection<R: Runtime>(
173 client: &ComputeClient<R::Server>,
174 problem: &MatmulProblem,
175 plane_dim: u32,
176 _line_sizes: &MatmulLineSizes,
177 elems: MatmulElems,
178 args: &Self::SelectionArgs,
179 ) -> Result<MatmulSelection, MatmulSetupError> {
180 plane_matmul_selection::<TMM, R>(
181 client,
182 problem,
183 plane_dim,
184 elems,
185 PlaneMatmulSelectionOptions {
186 specialized: args.specialized,
187 multi_row_strategy: MultiRowStrategy::Adaptive {
188 minimum_stage_count: 8,
189 },
190 ..Default::default()
191 },
192 )
193 }
194}