cubecl_matmul/components/global/
base.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3
4use crate::components::global::memory::GlobalMemoryConfig;
5use crate::components::{AccG, error::MatmulSetupError};
6use crate::components::{
7    AvailableLineSizes, MatmulPrecision, MatmulProblem, MatrixLayout, TilingScheme,
8    global::{PlaneRoleConfig, SpecializedLoadingSides, multi_stage::EventLoadingMode},
9    stage::StageConfig,
10};
11use crate::components::{LhsG, MatmulIdent, MatmulLineSizes, MatmulSelection, RhsG};
12use crate::components::{global::RoleRuleConfig, stage::StageMemoryConfig};
13use cubecl_std::{
14    CubeOption,
15    tensor::{View, layout::Coords2d},
16};
17use std::{fmt::Debug, hash::Hash};
18
19use super::read::ReaderMode;
20
21/// A family of [matmuls](GlobalMatmul) working with any [precision](MatmulPrecision).
22pub trait GlobalMatmulFamily: Send + Sync + 'static {
23    /// The specific [GlobalMatmul] implementation associated with this family.
    ///
    /// Its `Config` is constrained to be the family's own [`Self::Config`].
24    type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
25
26    /// The configuration type associated with this matmul family.
27    type Config: GlobalConfig;
28
29    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
30    ///
31    /// Returns a [MatmulSetupError] if the configuration cannot be supported on the current runtime.
32    fn setup<MP: MatmulPrecision, R: Runtime>(
33        client: &ComputeClient<R::Server>,
34        problem: &MatmulProblem,
35        selection: &MatmulSelection,
36        matmul_line_sizes: &MatmulLineSizes,
37    ) -> Result<Self::Config, MatmulSetupError>;
38
39    /// Filters out line sizes that are incompatible with this matmul family.
40    ///
41    /// The default implementation performs no filtering and returns the input unchanged.
42    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
43        available_line_sizes
44    }
45}
46
47#[cube]
48/// Provides matrix multiplication operations at the global level.
49///
50/// At the global level,
51///  - Inputs are views over global memory, meaning access is given to
52///    only parts of the global memory inputs at once.
53///  - All planes within a Cube are used to solve the problem
54///  - Dimensions M and N are fixed to an integer, but K can be arbitrarily large.
55///    The matrix multiplication works only for size (M, _) · (_, N) = (M, N).
56///    M and N should match the underlying Stage matmul's M and N.
57///
58/// # Assumptions
59/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
60///
61/// # Safety
62///
63/// It is not assumed that the matmul's dimensions match its inputs dimensions perfectly.
64/// It is therefore important that Readers and Writers perform checks to avoid out-of-bounds
65/// before reading data.
66pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
    /// The configuration type associated with this matmul.
67    type Config: GlobalConfig;
68
69    /// Global reader for matrix A (Lhs)
70    type LhsGlobalReader: CubeType;
71    /// Global reader for matrix B (Rhs)
72    type RhsGlobalReader: CubeType;
73    /// Global reader for matrix C (Accumulator/Bias)
74    type AccGlobalReader: CubeType;
75    /// Writer to store the output stage into global memory
76    type GlobalWriter: CubeType;
77
78    /// The accumulator type for the tile matmul
79    type Accumulators: CubeType;
80
81    /// Performs the matrix multiplication over data loaded by the
82    /// Lhs and Rhs readers, over the range given for K, and stores the
83    /// result using the output writer.
84    ///
85    /// To compute the whole range of k values, use k_range=(0, K) where
86    /// K is the K dimension of Lhs and Rhs.
87    fn execute(
88        lhs_reader: Self::LhsGlobalReader,
89        rhs_reader: Self::RhsGlobalReader,
90        acc_reader: Self::AccGlobalReader,
91        writer: Self::GlobalWriter,
92        acc: &mut Self::Accumulators,
93        k_range: (u32, u32),
94        #[comptime] config: Self::Config,
95    );
96
97    /// Initialize the global reader for Lhs, starting at row m and column k
98    fn init_lhs_global_reader(
99        lhs: View<Line<LhsG<MP>>, Coords2d>,
100        #[comptime] config: Self::Config,
101    ) -> Self::LhsGlobalReader;
102
103    /// Initialize the global reader for Rhs, starting at row k and column n
104    fn init_rhs_global_reader(
105        rhs: View<Line<RhsG<MP>>, Coords2d>,
106        #[comptime] config: Self::Config,
107    ) -> Self::RhsGlobalReader;
108
109    /// Initialize the global reader for the accumulator (Acc) from an optional view
110    fn init_acc_global_reader(
111        acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
112        #[comptime] config: Self::Config,
113    ) -> Self::AccGlobalReader;
114
115    /// Initialize the accumulator without data
116    fn init_accumulators(#[comptime] config: Self::Config) -> Self::Accumulators;
117
118    /// Initialize the global writer at row m and column n
119    fn init_global_writer(
120        out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
121        #[comptime] config: Self::Config,
122    ) -> Self::GlobalWriter;
123}
124
125/// Configuration for the [global matmul](GlobalMatmul) level.
126pub trait GlobalConfig:
127    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
128{
129    /// Underlying Stage matmul config
130    type StageConfig: StageConfig;
131
132    /// Convert itself to the underlying stage matmul config
133    fn stage_config(&self) -> Self::StageConfig;
134
    /// Returns the [StageMemoryConfig] for the given ident, as provided by the
    /// underlying stage config.
135    fn stage_memory_config(&self, ident: MatmulIdent) -> StageMemoryConfig {
136        self.stage_config().stage_memory_config(ident.into_stage())
137    }
138
    /// Builds the [GlobalMemoryConfig] for the given ident from the tiling scheme's
    /// tile and stage dimensions, the global line size, the row/col bounds checks
    /// and the matrix layout.
139    fn global_memory_config(&self, ident: MatmulIdent) -> GlobalMemoryConfig {
140        GlobalMemoryConfig {
141            elements_in_tile_row: self.tiling_scheme().elements_in_tile_row(ident),
142            elements_in_tile_col: self.tiling_scheme().elements_in_tile_col(ident),
143            elements_in_stage_row: self.tiling_scheme().elements_in_stage_row(ident),
144            elements_in_stage_col: self.tiling_scheme().elements_in_stage_col(ident),
145            global_line_size: self.global_line_size(ident),
146            check_row_bounds: self.check_row_bounds(ident),
147            check_col_bounds: self.check_col_bounds(ident),
148            matrix_layout: self.matrix_layout(ident),
149        }
150    }
151
152    /// Returns the line size for the global memory corresponding to the given ident
153    fn global_line_size(&self, ident: MatmulIdent) -> u32;
154
155    /// Returns the [TilingScheme] of the underlying stage config
156    fn tiling_scheme(&self) -> TilingScheme {
157        self.stage_config().tiling_scheme()
158    }
159
160    /// Returns the [MatrixLayout] for the given ident
161    fn matrix_layout(&self, ident: MatmulIdent) -> MatrixLayout;
162
163    /// Returns the number of planes participating in loading `ident`
164    fn num_loading_planes(&self, ident: MatmulIdent) -> u32;
165
166    /// Indicates the specialization roles for the planes
167    fn plane_role_config(&self) -> PlaneRoleConfig;
168
169    /// Indicates which tensor input each plane role is associated with loading
170    fn specialized_loading_sides(&self) -> SpecializedLoadingSides;
171
172    /// How to identify the role of the plane depending on its index
173    fn role_rule_config(&self) -> RoleRuleConfig {
174        self.plane_role_config().rule
175    }
176
177    /// Returns the size of the plane dimension
178    fn plane_dim(&self) -> u32;
179
180    /// Whether to check if accessing a row would exceed bounds.
181    fn check_row_bounds(&self, ident: MatmulIdent) -> bool;
182
183    /// Whether to check if accessing a col would exceed bounds.
184    fn check_col_bounds(&self, ident: MatmulIdent) -> bool;
185
186    /// Whether to check if accessing a col for lhs or row for rhs would exceed bounds.
187    fn check_k_bounds(&self) -> bool;
188
189    /// Whether to perform computations common to loading tasks once, before the loop
190    fn precompute_job(&self) -> bool;
191
192    /// The number of stages in stage memory
193    fn num_stages(&self, ident: MatmulIdent) -> u32;
194
195    /// Whether to check that the reader is balanced at comptime or at runtime.
196    ///
197    /// Not supported by all loading strategies
198    fn reader_mode(&self) -> ReaderMode;
199
200    /// Whether event loading is constrained to be ordered
201    fn event_loading_mode(&self, ident: MatmulIdent) -> EventLoadingMode;
202
203    /// Whether the matmul is quantized, as reported by the underlying stage config
204    fn quantized(&self) -> bool {
205        self.stage_config().quantized()
206    }
207
208    /// The [CubeDim] arising from the [TilingScheme]
209    fn cube_dim(&self) -> CubeDim;
210}