cubecl_matmul/components/stage/
base.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand, tensor::layout::Coords2d};
4
5use crate::components::{AccS, global::MaxGlobalReaderPlanes};
6use crate::components::{
7 AvailableLineSizes, LhsS, MatmulLineSizes, MatmulSelection, RhsS, StageIdent,
8};
9use crate::components::{
10 MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
11 global::{self, PlaneRoleConfig, RoleRuleConfig},
12 tile::TileConfig,
13};
14use crate::components::{
15 error::MatmulSetupError, global::WriteEventListener, stage::StageMemoryConfig,
16};
17use crate::components::{
18 stage::{NumStages, PartitionScheduler, PartitionSchedulerScheme},
19 tile::io::TileKind,
20};
21use std::{fmt::Debug, hash::Hash};
22
23use super::{StageEventListener, TilingLayout};
24
25pub trait StageMatmulFamily: Send + Sync + 'static {
27 type Matmul<MP: MatmulPrecision, TL: TilingLayout, TR: TilingLayout, TA: TilingLayout, TO: TilingLayout>: StageMatmul<
29 MP,
30 Config = Self::Config,
31 LhsStage = <Self::LhsStage as StageFamily>::Stage<LhsS<MP>, TL>,
32 RhsStage = <Self::RhsStage as StageFamily>::Stage<RhsS<MP>, TR>,
33 AccStage = <Self::AccStage as StageFamily>::Stage<AccS<MP>, TA>,
34 OutStage = <Self::OutStage as StageFamily<ReadWrite>>::Stage<AccS<MP>, TO>,
35 >;
36
37 type LhsStage: StageFamily;
39 type RhsStage: StageFamily;
41 type AccStage: StageFamily;
43 type OutStage: StageFamily<ReadWrite>;
45
46 type Config: StageConfig;
48
49 fn setup<MP: MatmulPrecision, R: Runtime>(
54 client: &ComputeClient<R::Server>,
55 problem: &MatmulProblem,
56 selection: &MatmulSelection,
57 line_sizes: &MatmulLineSizes,
58 num_stages: NumStages,
59 max_global_readers: Option<MaxGlobalReaderPlanes>,
60 ordered: bool,
61 ) -> Result<Self::Config, MatmulSetupError>;
62
63 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
67 available_line_sizes
68 }
69}
70
/// Stage-level matrix multiplication for precision `MP`.
///
/// All operations are associated functions (no `self`); state is carried by the stage,
/// tile and accumulator values passed in. The item goes through the `#[cube]` macro, so
/// every `config` argument is a `#[comptime]` value.
#[cube]
pub trait StageMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Configuration type taken by every method below.
    type Config: StageConfig;

    /// Accumulator storage created by `init_accumulators` and consumed by `write_results`.
    type Accumulators: CubeType;

    /// Stage type the left-hand side is read from.
    type LhsStage: CubeType;
    /// Stage type the right-hand side is read from.
    type RhsStage: CubeType;
    /// Stage type the accumulators are loaded from (see `load_accumulators`).
    type AccStage: CubeType;
    /// Stage type results are written into (see `write_results`).
    type OutStage: CubeType;

    /// Tile-level input for the left-hand side, created by `init_tile_inputs`.
    type LhsTile: CubeType;
    /// Tile-level input for the right-hand side, created by `init_tile_inputs`.
    type RhsTile: CubeType;

    /// Runs the stage matmul over `lhs` and `rhs`.
    ///
    /// `instruction_lhs`, `instruction_rhs` and `acc` are taken mutably.
    /// NOTE(review): presumably the tiles are scratch inputs and the product is accumulated
    /// into `acc` — confirm against the implementations.
    fn execute(
        lhs: &Self::LhsStage,
        rhs: &Self::RhsStage,
        instruction_lhs: &mut Self::LhsTile,
        instruction_rhs: &mut Self::RhsTile,
        acc: &mut Self::Accumulators,
        #[comptime] config: Self::Config,
        partition_scheduler: &PartitionScheduler,
    );

    /// Same arguments as `execute`, plus a `listener` implementing
    /// [`StageEventListener`] for this matmul's `Config`.
    ///
    /// NOTE(review): when and how often the listener is notified is up to the
    /// implementation — not visible from this trait.
    fn execute_with_listener<SEL: StageEventListener<Self::Config>>(
        lhs: &Self::LhsStage,
        rhs: &Self::RhsStage,
        instruction_lhs: &mut Self::LhsTile,
        instruction_rhs: &mut Self::RhsTile,
        acc: &mut Self::Accumulators,
        #[comptime] config: Self::Config,
        listener: SEL,
        partition_scheduler: &PartitionScheduler,
    );

    /// Creates the pair of tile inputs (`LhsTile`, `RhsTile`) expected by `execute`.
    fn init_tile_inputs(#[comptime] config: Self::Config) -> (Self::LhsTile, Self::RhsTile);

    /// Creates the accumulators expected by `execute`, `load_accumulators` and
    /// `write_results`.
    fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;

    /// Fills `acc` from the accumulator stage `reader`.
    fn load_accumulators(
        reader: &Self::AccStage,
        acc: &mut Self::Accumulators,
        #[comptime] config: Self::Config,
    );

    /// Writes the contents of `acc` into the output `stage`.
    ///
    /// `listener` is a [`WriteEventListener`] taken mutably; both the stage-level and the
    /// global-level configuration are supplied as comptime values.
    fn write_results<W: WriteEventListener, G: global::GlobalConfig>(
        acc: &Self::Accumulators,
        stage: &mut Self::OutStage,
        listener: &mut W,
        partition_scheduler: &PartitionScheduler,
        #[comptime] stage_config: Self::Config,
        #[comptime] global_config: G,
    );

    /// Creates the [`PartitionScheduler`] that the other methods take by reference.
    fn init_scheduler(#[comptime] config: Self::Config) -> PartitionScheduler;
}
159
160pub trait StageConfig:
162 Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
163{
164 type TileConfig: TileConfig;
166
167 fn tile_config(self) -> Self::TileConfig;
169
170 fn stage_memory_config(self, ident: StageIdent) -> StageMemoryConfig {
172 let tiling = self.tiling_scheme();
173 StageMemoryConfig {
174 num_main_flow_planes: self.num_main_flow_planes(),
175 elements_in_tile_row: tiling.elements_in_tile_row(ident),
176 elements_in_tile_col: tiling.elements_in_tile_col(ident),
177 tiles_in_stage_row: tiling.tiles_in_stage_row(ident),
178 tiles_in_stage_col: tiling.tiles_in_stage_col(ident),
179 stage_line_size: self.stage_line_size(ident),
180 matrix_layout: self.matrix_layout(ident),
181 num_stages: self.num_stages(ident),
182 }
183 }
184
185 fn stage_line_size(&self, ident: StageIdent) -> u32;
187
188 fn global_line_size(&self, ident: StageIdent) -> u32;
190
191 fn matrix_layout(&self, ident: StageIdent) -> MatrixLayout;
193
194 fn plane_dim(&self) -> u32;
196
197 fn partition_buffering(&self) -> PartitionBuffering;
199
200 fn tiling_scheme(&self) -> TilingScheme;
202
203 fn plane_role_config(&self) -> PlaneRoleConfig;
205
206 fn role_rule_config(&self) -> RoleRuleConfig;
208
209 fn num_main_flow_planes(&self) -> u32;
211
212 fn quantized(&self) -> bool;
214
215 fn must_sync_plane_after_execution(&self) -> bool;
218
219 fn partition_schedule_scheme(&self) -> PartitionSchedulerScheme;
220
221 fn num_stages(&self, ident: StageIdent) -> u32;
223}
224
/// Buffering mode reported by `StageConfig::partition_buffering`.
///
/// The default is [`PartitionBuffering::Double`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PartitionBuffering {
    /// Single buffering.
    Single,
    /// Double buffering (the default).
    #[default]
    Double,
}
231
/// A stage holding elements of type `ES`, from which tiles can be obtained.
///
/// `IO` selects the slice visibility and defaults to `ReadOnly`.
#[cube]
pub trait Stage<ES: Numeric, IO: SliceVisibility = ReadOnly>:
    CubeType + Send + Sync + 'static
{
    /// Kind of tile this stage hands out; it determines the return type of `tile`.
    type TileKind: TileKind<IO>;

    /// Returns the tile of `this` at the 2D coordinates `tile`.
    fn tile(this: &Self, tile: Coords2d) -> <Self::TileKind as TileKind<IO>>::Tile<ES>;
}
244
245pub trait StageFamily<IO: SliceVisibility = ReadOnly>: Send + Sync + 'static {
247 type TileKind: TileKind<IO>;
249 type Stage<ES: Numeric, T: TilingLayout>: Stage<ES, IO, TileKind = Self::TileKind>;
251}
252
/// An optional stage is itself a stage whose tiles are optional.
#[cube]
impl<ES: Numeric, IO: SliceVisibility, Inner: Stage<ES, IO>> Stage<ES, IO> for CubeOption<Inner> {
    /// The tile kind is the inner stage's tile kind wrapped in `CubeOption`.
    type TileKind = CubeOption<Inner::TileKind>;

    /// Returns `Some(tile)` taken from the inner stage when one is present, `None` otherwise.
    fn tile(this: &Self, tile: Coords2d) -> <Self::TileKind as TileKind<IO>>::Tile<ES> {
        match this {
            // A stage is present: forward to it and wrap the tile it returns.
            CubeOption::Some(stage) => CubeOption::new_Some(Inner::tile(stage, tile)),
            // No stage: there is no tile either.
            CubeOption::None => CubeOption::new_None(),
        }
    }
}
264
265impl<IO: SliceVisibility, Inner: StageFamily<IO>> StageFamily<IO> for Option<Inner> {
266 type TileKind = CubeOption<Inner::TileKind>;
267 type Stage<ES: Numeric, T: TilingLayout> = CubeOption<Inner::Stage<ES, T>>;
268}