// cubecl_matmul/components/global/base.rs
use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl};

use crate::components::global::memory::GlobalMemoryConfig;
use crate::components::{AccG, error::MatmulSetupError};
use crate::components::{
    AvailableLineSizes, MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
    global::{PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode},
    stage::StageConfig,
};
use crate::components::{LhsG, MatmulIdent, MatmulLineSizes, MatmulSelection, RhsG};
use crate::components::{global::RoleRuleConfig, stage::StageMemoryConfig};
use cubecl_std::{
    CubeOption,
    tensor::{View, layout::Coords2d},
};
use std::{fmt::Debug, hash::Hash};

use super::read::ReaderMode;

21pub trait GlobalMatmulFamily: Send + Sync + 'static {
23    type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
25
26    type Config: GlobalConfig;
28
29    fn setup<MP: MatmulPrecision, R: Runtime>(
33        client: &ComputeClient<R::Server>,
34        problem: &MatmulProblem,
35        selection: &MatmulSelection,
36        matmul_line_sizes: &MatmulLineSizes,
37    ) -> Result<Self::Config, MatmulSetupError>;
38
39    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
43        available_line_sizes
44    }
45}
46
47#[cube]
48pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
67    type Config: GlobalConfig;
68
69    type LhsGlobalReader: CubeType;
71    type RhsGlobalReader: CubeType;
73    type AccGlobalReader: CubeType;
75    type GlobalWriter: CubeType;
77
78    type Accumulators: CubeType;
80
81    fn execute(
88        lhs_reader: Self::LhsGlobalReader,
89        rhs_reader: Self::RhsGlobalReader,
90        acc_reader: Self::AccGlobalReader,
91        writer: Self::GlobalWriter,
92        acc: &mut Self::Accumulators,
93        k_range: (u32, u32),
94        #[comptime] config: Self::Config,
95    );
96
97    fn init_lhs_global_reader(
99        lhs: View<Line<LhsG<MP>>, Coords2d>,
100        #[comptime] config: Self::Config,
101    ) -> Self::LhsGlobalReader;
102
103    fn init_rhs_global_reader(
105        rhs: View<Line<RhsG<MP>>, Coords2d>,
106        #[comptime] config: Self::Config,
107    ) -> Self::RhsGlobalReader;
108
109    fn init_acc_global_reader(
111        acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
112        #[comptime] config: Self::Config,
113    ) -> Self::AccGlobalReader;
114
115    fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
117
118    fn init_global_writer(
120        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
121        #[comptime] config: Self::Config,
122    ) -> Self::GlobalWriter;
123}
124
125pub trait GlobalConfig:
127    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
128{
129    type StageConfig: StageConfig;
131
132    fn stage_config(&self) -> Self::StageConfig;
134
135    fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
136        self.stage_config().stage_memory_config(ident.into_stage())
137    }
138
139    fn global_memory_config(&self, ident: MatmulIdent) -> GlobalMemoryConfig {
140        GlobalMemoryConfig {
141            elements_in_tile_row: self.tiling_scheme().elements_in_tile_row(ident),
142            elements_in_tile_col: self.tiling_scheme().elements_in_tile_col(ident),
143            elements_in_stage_row: self.tiling_scheme().elements_in_stage_row(ident),
144            elements_in_stage_col: self.tiling_scheme().elements_in_stage_col(ident),
145            global_line_size: self.global_line_size(ident),
146            check_row_bounds: self.check_row_bounds(ident),
147            check_col_bounds: self.check_col_bounds(ident),
148            matrix_layout: self.matrix_layout(ident),
149        }
150    }
151
152    fn global_line_size(&self, ident: MatmulIdent) -> u32;
154
155    fn tiling_scheme(&self) -> TilingScheme {
157        self.stage_config().tiling_scheme()
158    }
159
160    fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout;
162
163    fn num_loading_planes(&self, ident: MatmulIdent) -> u32;
165
166    fn plane_role_config(&self) -> PlaneRoleConfig;
168
169    fn specialized_loading_sides(&self) -> SpecializedLoadingSides;
171
172    fn role_rule_config(&self) -> RoleRuleConfig {
174        self.plane_role_config().rule
175    }
176
177    fn plane_dim(&self) -> u32;
179
180    fn check_row_bounds(&self, ident: MatmulIdent) -> bool;
182
183    fn check_col_bounds(&self, ident: MatmulIdent) -> bool;
185
186    fn check_k_bounds(&self) -> bool;
188
189    fn precompute_job(&self) -> bool;
191
192    fn num_stages(&self, ident: MatmulIdent) -> u32;
194
195    fn reader_mode(&self) -> ReaderMode;
199
200    fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode;
202
203    fn quantized(&self) -> bool {
205        self.stage_config().quantized()
206    }
207
208    fn cube_dim(&self) -> CubeDim;
210}