cubecl_matmul/components/stage/
base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::{CubeOption, CubeOptionExpand, tensor::layout::Coords2d};
4
5use crate::components::{AccS, global::MaxGlobalReaderPlanes};
6use crate::components::{
7    AvailableLineSizes, LhsS, MatmulLineSizes, MatmulSelection, RhsS, StageIdent,
8};
9use crate::components::{
10    MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
11    global::{self, PlaneRoleConfig, RoleRuleConfig},
12    tile::TileConfig,
13};
14use crate::components::{
15    error::MatmulSetupError, global::WriteEventListener, stage::StageMemoryConfig,
16};
17use crate::components::{
18    stage::{NumStages, PartitionScheduler, PartitionSchedulerScheme},
19    tile::io::TileKind,
20};
21use std::{fmt::Debug, hash::Hash};
22
23use super::{StageEventListener, TilingLayout};
24
25/// A family of [StageMatmul] implementations that operate with any [precision](MatmulPrecision).
26pub trait StageMatmulFamily: Send + Sync + 'static {
27    /// The specific [TileMatmul] implementation associated with this family.
28    type Matmul<MP: MatmulPrecision, TL: TilingLayout, TR: TilingLayout, TA: TilingLayout, TO: TilingLayout>: StageMatmul<
29            MP,
30            Config = Self::Config,
31            LhsStage = <Self::LhsStage as StageFamily>::Stage<LhsS<MP>, TL>,
32            RhsStage = <Self::RhsStage as StageFamily>::Stage<RhsS<MP>, TR>,
33            AccStage = <Self::AccStage as StageFamily>::Stage<AccS<MP>, TA>,
34            OutStage = <Self::OutStage as StageFamily<ReadWrite>>::Stage<AccS<MP>, TO>,
35        >;
36
37    /// Stage family for Lhs
38    type LhsStage: StageFamily;
39    /// Stage family for Rhs
40    type RhsStage: StageFamily;
41    /// Stage family for Acc
42    type AccStage: StageFamily;
43    /// Stage family for Out
44    type OutStage: StageFamily<ReadWrite>;
45
46    /// The configuration type associated with this matmul family.
47    type Config: StageConfig;
48
49    /// Constructs the configuration based on the matmul problem, selection, line sizes,
50    /// number of stages, maximum of tasks per plane, and whether the algorithm is an ordered variant
51    ///
52    /// This function may return an error if the configuration cannot be supported on the current runtime.
53    fn setup<MP: MatmulPrecision, R: Runtime>(
54        client: &ComputeClient<R::Server>,
55        problem: &MatmulProblem,
56        selection: &MatmulSelection,
57        line_sizes: &MatmulLineSizes,
58        num_stages: NumStages,
59        max_global_readers: Option<MaxGlobalReaderPlanes>,
60        ordered: bool,
61    ) -> Result<Self::Config, MatmulSetupError>;
62
63    /// Filters out line sizes that are incompatible with this matmul family.
64    ///
65    /// By default, returns the input unchanged.
66    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
67        available_line_sizes
68    }
69}
70
71#[cube]
72/// Provides matrix multiplication operations at the stage level.
73///
74/// At the stage level,
75///  - Inputs are assumed to be already staged into a shared memory.
76///  - All main flow planes within a Cube are used to solve the problem
77///  - Dimensions M, N and K are fixed to an integer, and the
78///    matrix multiplication works only for size (M, K) ยท (K, N) = (M, N).
79///    These integers are multiples of the underlying Tile matmul,
80///    corresponding to the number of tiles in each dimension.
81///
82/// Assumptions:
83///  - Data given as inputs by stage readers must always be valid. If the actual matrix multiplication
84///    should be done on smaller sizes than M, N and K, padding with zeros must be done beforehand.
85///  - Enough planes/units are launched to perform the whole computation
86pub trait StageMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
87    /// The configuration type associated with this Matmul.
88    type Config: StageConfig;
89
90    /// Contains the matrix multiplication output, that can be shared across the different planes of the cube.
91    /// The same Accumulator will be added to across multiple executions of the Stage Matmul.
92    type Accumulators: CubeType;
93
94    /// Stage for Lhs
95    type LhsStage: CubeType;
96    /// Stage for Rhs
97    type RhsStage: CubeType;
98    /// Stage for Accumulator
99    type AccStage: CubeType;
100    /// Stage for Out
101    type OutStage: CubeType;
102
103    /// Lhs input of the underlying Tile Matmul
104    type LhsTile: CubeType;
105    /// Rhs input of the underlying Tile Matmul
106    type RhsTile: CubeType;
107
108    /// Executes the matrix multiplication of Lhs and Rhs, adding the result to the accumulator
109    ///
110    /// Equivalent to execute_with_listener with SEL:=NoEvent
111    fn execute(
112        lhs: &Self::LhsStage,
113        rhs: &Self::RhsStage,
114        instruction_lhs: &mut Self::LhsTile,
115        instruction_rhs: &mut Self::RhsTile,
116        acc: &mut Self::Accumulators,
117        #[comptime] config: Self::Config,
118        partition_scheduler: &PartitionScheduler,
119    );
120
121    /// Executes the matrix multiplication of Lhs and Rhs, with the addition of injected
122    /// [event listener](StageEventListener).
123    fn execute_with_listener<SEL: StageEventListener<Self::Config>>(
124        lhs: &Self::LhsStage,
125        rhs: &Self::RhsStage,
126        instruction_lhs: &mut Self::LhsTile,
127        instruction_rhs: &mut Self::RhsTile,
128        acc: &mut Self::Accumulators,
129        #[comptime] config: Self::Config,
130        listener: SEL,
131        partition_scheduler: &PartitionScheduler,
132    );
133
134    /// Inits inputs of the underlying Tile Matmul
135    fn init_tile_inputs(#[comptime] config: Self::Config) -> (Self::LhsTile, Self::RhsTile);
136
137    /// Create an instance of the accumulators, without data
138    fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
139
140    /// Load all accumulators in the stage from data
141    fn load_accumulators(
142        reader: &Self::AccStage,
143        acc: &mut Self::Accumulators,
144        #[comptime] config: Self::Config,
145    );
146
147    /// Reads the result of the accumulator and hands it to the stage writer
148    fn write_results<W: WriteEventListener, G: global::GlobalConfig>(
149        acc: &Self::Accumulators,
150        stage: &mut Self::OutStage,
151        listener: &mut W,
152        partition_scheduler: &PartitionScheduler,
153        #[comptime] stage_config: Self::Config,
154        #[comptime] global_config: G,
155    );
156
157    fn init_scheduler(#[comptime] config: Self::Config) -> PartitionScheduler;
158}
159
160/// Configuration for the Stage matmul (SMM) level
161pub trait StageConfig:
162    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
163{
164    /// Underlying Tile matmul config
165    type TileConfig: TileConfig;
166
167    /// Converts itself to the underlying Tile Matmul config
168    fn tile_config(self) -> Self::TileConfig;
169
170    /// Converts itself to the underlying Stage Memory config
171    fn stage_memory_config(self, ident: StageIdent) -> StageMemoryConfig {
172        let tiling = self.tiling_scheme();
173        StageMemoryConfig {
174            num_main_flow_planes: self.num_main_flow_planes(),
175            elements_in_tile_row: tiling.elements_in_tile_row(ident),
176            elements_in_tile_col: tiling.elements_in_tile_col(ident),
177            tiles_in_stage_row: tiling.tiles_in_stage_row(ident),
178            tiles_in_stage_col: tiling.tiles_in_stage_col(ident),
179            stage_line_size: self.stage_line_size(ident),
180            matrix_layout: self.matrix_layout(ident),
181            num_stages: self.num_stages(ident),
182        }
183    }
184
185    /// Returns the line size for the given ident
186    fn stage_line_size(&self, ident: StageIdent) -> u32;
187
188    /// Returns the line size for the given ident
189    fn global_line_size(&self, ident: StageIdent) -> u32;
190
191    /// Returns the [MatrixLayout] for the given ident
192    fn matrix_layout(&self, ident: StageIdent) -> MatrixLayout;
193
194    /// Returns how many units are in a plane
195    fn plane_dim(&self) -> u32;
196
197    /// Returns whether we must perform partition buffering
198    fn partition_buffering(&self) -> PartitionBuffering;
199
200    /// Returns the [TilingScheme]
201    fn tiling_scheme(&self) -> TilingScheme;
202
203    /// Indicates the specialization roles for the planes
204    fn plane_role_config(&self) -> PlaneRoleConfig;
205
206    /// How to identify the role of the plane depending on its index
207    fn role_rule_config(&self) -> RoleRuleConfig;
208
209    /// Number of planes participating in the main computation flow
210    fn num_main_flow_planes(&self) -> u32;
211
212    /// Whether the Matmul is quantized
213    fn quantized(&self) -> bool;
214
215    /// Whether we must sync planes after execution because the execution
216    /// is not sync by itself (depends on the runtime/compiler)
217    fn must_sync_plane_after_execution(&self) -> bool;
218
219    fn partition_schedule_scheme(&self) -> PartitionSchedulerScheme;
220
221    /// Number of stages in the stage
222    fn num_stages(&self, ident: StageIdent) -> u32;
223}
224
/// Partition buffering mode, as returned by `StageConfig::partition_buffering`.
#[derive(Default, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum PartitionBuffering {
    /// Single buffering
    Single,
    /// Double buffering; this is the default mode
    #[default]
    Double,
}
231
232/// Stage that can be divided into tiles, with the same kind used by the
233/// tile matmul readers.
234#[cube]
235pub trait Stage<ES: Numeric, IO: SliceVisibility = ReadOnly>:
236    CubeType + Send + Sync + 'static
237{
238    /// The kind (or family) of the tiles contained in this stage
239    type TileKind: TileKind<IO>;
240
241    /// Slices a tile with offset (`row`, `col`) from the stage and returns it
242    fn tile(this: &Self, tile: Coords2d) -> <Self::TileKind as TileKind<IO>>::Tile<ES>;
243}
244
245/// Stage family for any precision
246pub trait StageFamily<IO: SliceVisibility = ReadOnly>: Send + Sync + 'static {
247    /// The tile kind (family) contained in the stage
248    type TileKind: TileKind<IO>;
249    /// The concrete stage type of this family, instantiated with the type and layout
250    type Stage<ES: Numeric, T: TilingLayout>: Stage<ES, IO, TileKind = Self::TileKind>;
251}
252
253#[cube]
254impl<ES: Numeric, IO: SliceVisibility, Inner: Stage<ES, IO>> Stage<ES, IO> for CubeOption<Inner> {
255    type TileKind = CubeOption<Inner::TileKind>;
256
257    fn tile(this: &Self, tile: Coords2d) -> <Self::TileKind as TileKind<IO>>::Tile<ES> {
258        match this {
259            CubeOption::Some(stage) => CubeOption::new_Some(Inner::tile(stage, tile)),
260            CubeOption::None => CubeOption::new_None(),
261        }
262    }
263}
264
265impl<IO: SliceVisibility, Inner: StageFamily<IO>> StageFamily<IO> for Option<Inner> {
266    type TileKind = CubeOption<Inner::TileKind>;
267    type Stage<ES: Numeric, T: TilingLayout> = CubeOption<Inner::Stage<ES, T>>;
268}