// cubecl_matmul/components/global/base.rs
1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::global::memory::GlobalMemoryConfig;
5use crate::components::{AccG, error::MatmulSetupError};
6use crate::components::{
7 AvailableLineSizes, MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
8 global::{PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode},
9 stage::StageConfig,
10};
11use crate::components::{LhsG, MatmulIdent, MatmulLineSizes, MatmulSelection, RhsG};
12use crate::components::{global::RoleRuleConfig, stage::StageMemoryConfig};
13use cubecl_std::{
14 CubeOption,
15 tensor::{View, layout::Coords2d},
16};
17use std::{fmt::Debug, hash::Hash};
18
19use super::read::ReaderMode;
20
21pub trait GlobalMatmulFamily: Send + Sync + 'static {
23 type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
25
26 type Config: GlobalConfig;
28
29 fn setup<MP: MatmulPrecision, R: Runtime>(
33 client: &ComputeClient<R::Server>,
34 problem: &MatmulProblem,
35 selection: &MatmulSelection,
36 matmul_line_sizes: &MatmulLineSizes,
37 ) -> Result<Self::Config, MatmulSetupError>;
38
39 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
43 available_line_sizes
44 }
45}
46
// NOTE(review): the integer fused to the start of each code line below is
// line-number residue from the source listing, not part of the Rust code.

/// Contract for a global-level matmul over the precision `MP`.
///
/// An implementation names its reader, writer and accumulator types, provides
/// one `init_*` constructor for each of them, and runs the computation in
/// [`GlobalMatmul::execute`]. Every method receives the config as a
/// `#[comptime]` argument.
47#[cube]
48pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Configuration type passed (as `#[comptime]`) to every method of the trait.
67 type Config: GlobalConfig;
68
    /// Reader type returned by `init_lhs_global_reader`.
69 type LhsGlobalReader: CubeType;
    /// Reader type returned by `init_rhs_global_reader`.
71 type RhsGlobalReader: CubeType;
    /// Reader type returned by `init_acc_global_reader`.
73 type AccGlobalReader: CubeType;
    /// Writer type returned by `init_global_writer`.
75 type GlobalWriter: CubeType;
77
    /// Accumulator type returned by `init_accumulators` and mutated by `execute`.
78 type Accumulators: CubeType;
80
    /// Runs the matmul with the given readers, writer and accumulators.
    ///
    /// The readers and the writer are taken by value; `acc` is borrowed mutably.
    /// NOTE(review): `k_range` is presumably the (start, end) bounds along the
    /// k dimension — confirm against the implementations.
81 fn execute(
88 lhs_reader: Self::LhsGlobalReader,
89 rhs_reader: Self::RhsGlobalReader,
90 acc_reader: Self::AccGlobalReader,
91 writer: Self::GlobalWriter,
92 acc: &mut Self::Accumulators,
93 k_range: (u32, u32),
94 #[comptime] config: Self::Config,
95 );
96
    /// Creates the Lhs reader from a 2D-coordinate view over lines of `LhsG<MP>`.
97 fn init_lhs_global_reader(
99 lhs: View<Line<LhsG<MP>>, Coords2d>,
100 #[comptime] config: Self::Config,
101 ) -> Self::LhsGlobalReader;
102
    /// Creates the Rhs reader from a 2D-coordinate view over lines of `RhsG<MP>`.
103 fn init_rhs_global_reader(
105 rhs: View<Line<RhsG<MP>>, Coords2d>,
106 #[comptime] config: Self::Config,
107 ) -> Self::RhsGlobalReader;
108
    /// Creates the accumulator reader from an optional (`CubeOption`) view over
    /// lines of `AccG<MP>`.
109 fn init_acc_global_reader(
111 acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
112 #[comptime] config: Self::Config,
113 ) -> Self::AccGlobalReader;
114
    /// Creates the accumulators from the config alone.
115 fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
117
    /// Creates the writer from a `ReadWrite` 2D-coordinate view over lines of
    /// `AccG<MP>`.
118 fn init_global_writer(
120 out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
121 #[comptime] config: Self::Config,
122 ) -> Self::GlobalWriter;
123}
124
125pub trait GlobalConfig:
127 Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
128{
129 type StageConfig: StageConfig;
131
132 fn stage_config(&self) -> Self::StageConfig;
134
135 fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
136 self.stage_config().stage_memory_config(ident.into_stage())
137 }
138
139 fn global_memory_config(&self, ident: MatmulIdent) -> GlobalMemoryConfig {
140 GlobalMemoryConfig {
141 elements_in_tile_row: self.tiling_scheme().elements_in_tile_row(ident),
142 elements_in_tile_col: self.tiling_scheme().elements_in_tile_col(ident),
143 elements_in_stage_row: self.tiling_scheme().elements_in_stage_row(ident),
144 elements_in_stage_col: self.tiling_scheme().elements_in_stage_col(ident),
145 global_line_size: self.global_line_size(ident),
146 check_row_bounds: self.check_row_bounds(ident),
147 check_col_bounds: self.check_col_bounds(ident),
148 matrix_layout: self.matrix_layout(ident),
149 }
150 }
151
152 fn global_line_size(&self, ident: MatmulIdent) -> u32;
154
155 fn tiling_scheme(&self) -> TilingScheme {
157 self.stage_config().tiling_scheme()
158 }
159
160 fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout;
162
163 fn num_loading_planes(&self, ident: MatmulIdent) -> u32;
165
166 fn plane_role_config(&self) -> PlaneRoleConfig;
168
169 fn specialized_loading_sides(&self) -> SpecializedLoadingSides;
171
172 fn role_rule_config(&self) -> RoleRuleConfig {
174 self.plane_role_config().rule
175 }
176
177 fn plane_dim(&self) -> u32;
179
180 fn check_row_bounds(&self, ident: MatmulIdent) -> bool;
182
183 fn check_col_bounds(&self, ident: MatmulIdent) -> bool;
185
186 fn check_k_bounds(&self) -> bool;
188
189 fn precompute_job(&self) -> bool;
191
192 fn num_stages(&self, ident: MatmulIdent) -> u32;
194
195 fn reader_mode(&self) -> ReaderMode;
199
200 fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode;
202
203 fn quantized(&self) -> bool {
205 self.stage_config().quantized()
206 }
207
208 fn cube_dim(&self) -> CubeDim;
210}