cubecl_matmul/components/global/
base.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::{AccG, error::MatmulSetupError, stage::TilingLayoutEnum};
5use crate::components::{
6 AvailableLineSizes, MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
7 global::{PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode},
8 stage::StageConfig,
9};
10use crate::components::{LhsG, MatmulElems, MatmulIdent, MatmulLineSizes, MatmulSelection, RhsG};
11use crate::components::{global::RoleRuleConfig, stage::StageMemoryConfig};
12use crate::components::{global::memory::GlobalMemoryConfig, stage::SwizzleMode};
13use cubecl_std::{
14 CubeOption,
15 tensor::{View, layout::Coords2d},
16};
17use std::{fmt::Debug, hash::Hash};
18
19use super::read::ReaderMode;
20
21pub trait GlobalMatmulFamily: Send + Sync + 'static {
23 type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
25
26 type Config: GlobalConfig;
28
29 fn setup<R: Runtime>(
33 client: &ComputeClient<R::Server>,
34 problem: &MatmulProblem,
35 selection: &MatmulSelection,
36 matmul_line_sizes: &MatmulLineSizes,
37 dtypes: &MatmulElems,
38 ) -> Result<Self::Config, MatmulSetupError>;
39
40 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
44 available_line_sizes
45 }
46}
47
48#[cube]
49pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
68 type Config: GlobalConfig;
69
70 type LhsGlobalReader: CubeType;
72 type RhsGlobalReader: CubeType;
74 type AccGlobalReader: CubeType;
76 type GlobalWriter: CubeType;
78
79 type Accumulators: CubeType;
81
82 fn execute(
89 lhs_reader: Self::LhsGlobalReader,
90 rhs_reader: Self::RhsGlobalReader,
91 acc_reader: Self::AccGlobalReader,
92 writer: Self::GlobalWriter,
93 k_range: (u32, u32),
94 #[comptime] config: Self::Config,
95 );
96
97 fn init_lhs_global_reader(
99 lhs: View<Line<LhsG<MP>>, Coords2d>,
100 #[comptime] config: Self::Config,
101 ) -> Self::LhsGlobalReader;
102
103 fn init_rhs_global_reader(
105 rhs: View<Line<RhsG<MP>>, Coords2d>,
106 #[comptime] config: Self::Config,
107 ) -> Self::RhsGlobalReader;
108
109 fn init_acc_global_reader(
111 acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
112 #[comptime] config: Self::Config,
113 ) -> Self::AccGlobalReader;
114
115 fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
117
118 fn init_global_writer(
120 out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
121 #[comptime] config: Self::Config,
122 ) -> Self::GlobalWriter;
123}
124
125pub trait GlobalConfig:
127 Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
128{
129 type StageConfig: StageConfig;
131
132 fn stage_config(&self) -> Self::StageConfig;
134
135 fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
136 self.stage_config().stage_memory_config(ident.into_stage())
137 }
138
139 fn global_memory_config(&self, ident: MatmulIdent) -> GlobalMemoryConfig {
140 GlobalMemoryConfig::new(
141 self.tiling_scheme().elements_in_tile_row(ident),
142 self.tiling_scheme().elements_in_tile_col(ident),
143 self.tiling_scheme().elements_in_stage_row(ident),
144 self.tiling_scheme().elements_in_stage_col(ident),
145 self.global_line_size(ident),
146 self.check_row_bounds(ident),
147 self.check_col_bounds(ident),
148 self.matrix_layout(ident),
149 self.swizzle_mode(ident),
150 )
151 }
152
153 fn global_line_size(&self, ident: MatmulIdent) -> u32;
155
156 fn tiling_scheme(&self) -> TilingScheme {
158 self.stage_config().tiling_scheme()
159 }
160
161 fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout;
163
164 fn swizzle_mode(&self, ident: MatmulIdent) -> SwizzleMode;
166
167 fn tiling_layout(&self, ident: MatmulIdent) -> TilingLayoutEnum;
169
170 fn num_loading_planes(&self, ident: MatmulIdent) -> u32;
172
173 fn plane_role_config(&self) -> PlaneRoleConfig;
175
176 fn specialized_loading_sides(&self) -> SpecializedLoadingSides;
178
179 fn role_rule_config(&self) -> RoleRuleConfig {
181 self.plane_role_config().rule
182 }
183
184 fn plane_dim(&self) -> u32;
186
187 fn check_row_bounds(&self, ident: MatmulIdent) -> bool;
189
190 fn check_col_bounds(&self, ident: MatmulIdent) -> bool;
192
193 fn check_k_bounds(&self) -> bool;
195
196 fn precompute_job(&self) -> bool;
198
199 fn num_stages(&self, ident: MatmulIdent) -> u32;
201
202 fn reader_mode(&self) -> ReaderMode;
206
207 fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode;
209
210 fn quantized(&self) -> bool {
212 self.stage_config().quantized()
213 }
214
215 fn cube_dim(&self) -> CubeDim;
217}