cubecl_matmul/components/global/
base.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::{AccG, error::MatmulSetupError, stage::TilingLayoutEnum};
5use crate::components::{
6    AvailableLineSizes, MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
7    global::{PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode},
8    stage::StageConfig,
9};
10use crate::components::{LhsG, MatmulElems, MatmulIdent, MatmulLineSizes, MatmulSelection, RhsG};
11use crate::components::{global::RoleRuleConfig, stage::StageMemoryConfig};
12use crate::components::{global::memory::GlobalMemoryConfig, stage::SwizzleMode};
13use cubecl_std::{
14    CubeOption,
15    tensor::{View, layout::Coords2d},
16};
17use std::{fmt::Debug, hash::Hash};
18
19use super::read::ReaderMode;
20
21/// A family of [matmuls](GlobalMatmul) working with any [precision](MatmulPrecision).
22pub trait GlobalMatmulFamily: Send + Sync + 'static {
23    /// The specific [GlobalMatmul] implementation associated with this family.
24    type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
25
26    /// The configuration type associated with this matmul family.
27    type Config: GlobalConfig;
28
29    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
30    ///
31    /// This function may return an error if the configuration cannot be supported on the current runtime.
32    fn setup<R: Runtime>(
33        client: &ComputeClient<R::Server>,
34        problem: &MatmulProblem,
35        selection: &MatmulSelection,
36        matmul_line_sizes: &MatmulLineSizes,
37        dtypes: &MatmulElems,
38    ) -> Result<Self::Config, MatmulSetupError>;
39
40    /// Filters out line sizes that are incompatible with this matmul family.
41    ///
42    /// By default, returns the input unchanged.
43    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
44        available_line_sizes
45    }
46}
47
48#[cube]
49/// Provides matrix multiplication operations at the global level.
50///
51/// At the global level,
52///  - Inputs are views over global memory, meaning access is given to
53///    only parts of the global memory inputs at once.
54///  - All planes within a Cube are used to solve the problem
55///  - Dimensions M and N are fixed to an integer, but K is arbitrary large.
56///    The matrix multiplication works only for size (M, _) ยท (_, N) = (M, N).
57///    M and N should match the underlying Stage matmul's M and N.
58///
59/// # Assumptions
60/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
61///
62/// # Safety
63///
64/// It is not assumed that the matmul's dimensions match its inputs dimensions perfectly.
65/// It is therefore important that Readers and Writers perform checks to avoid out-of-bounds
66/// before reading data.
67pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
68    type Config: GlobalConfig;
69
70    /// Global reader for matrix A (Lhs)
71    type LhsGlobalReader: CubeType;
72    /// Global reader for matrix B (Rhs)
73    type RhsGlobalReader: CubeType;
74    /// Global reader for matrix C (Accumulator/Bias)
75    type AccGlobalReader: CubeType;
76    /// Writer to store the output stage into global memory
77    type GlobalWriter: CubeType;
78
79    /// The accumulator type for the tile matmul
80    type Accumulators: CubeType;
81
82    /// Performs the matrix multiplication over data loaded by the
83    /// Lhs and Rhs readers, over the range given for K, and stores with
84    /// using the output writer.
85    ///
86    /// To compute the whole range of k values, use k_range=(0, K) where
87    /// K is the K dimension of Lhs and Rhs.
88    fn execute(
89        lhs_reader: Self::LhsGlobalReader,
90        rhs_reader: Self::RhsGlobalReader,
91        acc_reader: Self::AccGlobalReader,
92        writer: Self::GlobalWriter,
93        k_range: (u32, u32),
94        #[comptime] config: Self::Config,
95    );
96
97    /// Initialize the global reader for Lhs, starting at row m and column k
98    fn init_lhs_global_reader(
99        lhs: View<Line<LhsG<MP>>, Coords2d>,
100        #[comptime] config: Self::Config,
101    ) -> Self::LhsGlobalReader;
102
103    /// Initialize the global reader for Rhs, starting at row k and column n
104    fn init_rhs_global_reader(
105        rhs: View<Line<RhsG<MP>>, Coords2d>,
106        #[comptime] config: Self::Config,
107    ) -> Self::RhsGlobalReader;
108
109    /// Initialize the global reader for Rhs, starting at row k and column n
110    fn init_acc_global_reader(
111        acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
112        #[comptime] config: Self::Config,
113    ) -> Self::AccGlobalReader;
114
115    /// Initialize the accumulator without data
116    fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
117
118    /// Initialize the global writer at row m and column n
119    fn init_global_writer(
120        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
121        #[comptime] config: Self::Config,
122    ) -> Self::GlobalWriter;
123}
124
125/// Configuration for the [global matmul](GlobalMatmul) level.
126pub trait GlobalConfig:
127    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
128{
129    /// Underlying Stage matmul config
130    type StageConfig: StageConfig;
131
132    /// Convert itself to the underlying stage matmul config
133    fn stage_config(&self) -> Self::StageConfig;
134
135    fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
136        self.stage_config().stage_memory_config(ident.into_stage())
137    }
138
139    fn global_memory_config(&self, ident: MatmulIdent) -> GlobalMemoryConfig {
140        GlobalMemoryConfig::new(
141            self.tiling_scheme().elements_in_tile_row(ident),
142            self.tiling_scheme().elements_in_tile_col(ident),
143            self.tiling_scheme().elements_in_stage_row(ident),
144            self.tiling_scheme().elements_in_stage_col(ident),
145            self.global_line_size(ident),
146            self.check_row_bounds(ident),
147            self.check_col_bounds(ident),
148            self.matrix_layout(ident),
149            self.swizzle_mode(ident),
150        )
151    }
152
153    /// Returns the line size for the global memory corresponding to the given ident
154    fn global_line_size(&self, ident: MatmulIdent) -> u32;
155
156    /// Returns the [TilingScheme]
157    fn tiling_scheme(&self) -> TilingScheme {
158        self.stage_config().tiling_scheme()
159    }
160
161    /// Returns the [MatrixLayout] for the given ident
162    fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout;
163
164    /// Returns the [SwizzleMode] for the given ident
165    fn swizzle_mode(&self, ident: MatmulIdent) -> SwizzleMode;
166
167    /// Returns the [TilingLayoutEnum] for the loader of the given ident
168    fn tiling_layout(&self, ident: MatmulIdent) -> TilingLayoutEnum;
169
170    /// Returns the number of planes participating in loading `ident`
171    fn num_loading_planes(&self, ident: MatmulIdent) -> u32;
172
173    /// Indicates the specialization roles for the planes
174    fn plane_role_config(&self) -> PlaneRoleConfig;
175
176    /// Indicates plane roles are associated to loading which tensor input
177    fn specialized_loading_sides(&self) -> SpecializedLoadingSides;
178
179    /// How to identify the role of the plane depending on its index
180    fn role_rule_config(&self) -> RoleRuleConfig {
181        self.plane_role_config().rule
182    }
183
184    /// Returns the size of the plane dimension
185    fn plane_dim(&self) -> u32;
186
187    /// Whether to check if accessing a row would exceed bounds.
188    fn check_row_bounds(&self, ident: MatmulIdent) -> bool;
189
190    /// Whether to check if accessing a col would exceed bounds.
191    fn check_col_bounds(&self, ident: MatmulIdent) -> bool;
192
193    /// Whether to check if accessing a col for lhs or row for rhs would exceed bounds.
194    fn check_k_bounds(&self) -> bool;
195
196    /// Whether to put common computations for loading tasks once before loop
197    fn precompute_job(&self) -> bool;
198
199    /// The number of stages in stage memory
200    fn num_stages(&self, ident: MatmulIdent) -> u32;
201
202    /// Whether to check reader is balanced in comptime or runtime.
203    ///
204    /// Not supported by all loading strategies
205    fn reader_mode(&self) -> ReaderMode;
206
207    /// Whether event loading is constrained to be ordered
208    fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode;
209
210    /// Whether the matmul is quantized
211    fn quantized(&self) -> bool {
212        self.stage_config().quantized()
213    }
214
215    /// The [CubeDim] arising from the [TilingScheme]
216    fn cube_dim(&self) -> CubeDim;
217}