cubecl_linalg/matmul/components/global/
base.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::matmul::components::{
5    Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatrixLayout,
6    TilingDimensions,
7    config::MatmulConfig,
8    stage::{self, StageWriter},
9    tile,
10};
11use cubecl_std::{
12    CubeOption,
13    tensor::r#virtual::{ReadWrite, VirtualTensor},
14};
15
16use super::Quantization;
17
18/// A family of [matmuls](GlobalMatmul) working with any [precision](MatmulPrecision).
///
/// A family is also the [`MatmulConfigFactory`] of its matmuls: the factory's `Config` must
/// implement [`GlobalConfig`], and every [`Matmul`](Self::Matmul) of the family uses that
/// same config type.
19pub trait GlobalMatmulFamily:
20    MatmulConfigFactory<Config: GlobalConfig> + Send + Sync + 'static
21{
    /// The global matmul of this family for precision `MP`. Its `Config` is constrained to be
    /// the family's own `Config`.
22    type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
23}
24
25#[cube]
26/// Provides matrix multiplication operations at the global level.
27///
28/// At the global level,
29///  - Inputs are views over global memory, meaning access is given to
30///    only parts of the global memory inputs at once.
31///  - All planes within a Cube can collaborate to solve the problem.
32///  - Dimensions M and N are fixed to an integer, but K is arbitrarily large.
33///    The matrix multiplication works only for size (M, _) · (_, N) = (M, N).
34///    M and N should match the underlying Stage matmul's M and N.
35///
36/// # Assumptions
37/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
38///
39/// # Safety
40///
41/// It is not assumed that the matmul's dimensions match its inputs' dimensions perfectly.
42/// It is therefore important that Loaders and Unloaders perform checks to avoid out-of-bounds
43/// accesses before loading data.
44pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
    /// Comptime configuration of this global matmul.
45    type Config: GlobalConfig;
    /// Loader for the left-hand side, created by [`init_lhs_loader`](Self::init_lhs_loader).
46    type LhsLoader: CubeType;
    /// Loader for the right-hand side, created by [`init_rhs_loader`](Self::init_rhs_loader).
47    type RhsLoader: CubeType;
    /// Loader type for the accumulator.
    ///
    /// NOTE(review): no method of this trait takes or returns this type; presumably it is
    /// meant to implement [`AccumulatorLoader`] — confirm against the implementors.
48    type AccumulatorLoader: CubeType;
    /// Output unloader over elements of type `MP::EO`, created by
    /// [`init_unloader`](Self::init_unloader).
49    type Out: OutputLoader<MP::EO>;
    /// Accumulator created by [`init_accumulator`](Self::init_accumulator) and updated in
    /// place by [`execute`](Self::execute) and [`zero_accumulator`](Self::zero_accumulator).
50    type Accumulator: CubeType;
51
52    /// Performs the matrix multiplication over data loaded by the
53    /// LHS and RHS loaders, over the range given for K, and stores the
54    /// result using the output unloader.
55    ///
56    /// To compute the whole range of k values, use k_range=(0, K) where
57    /// K is the K dimension of LHS and RHS.
    ///
    /// The loaders and the unloader are taken by value; `acc` is borrowed mutably and
    /// `config` is a comptime argument.
58    fn execute(
59        lhs_loader: Self::LhsLoader,
60        rhs_loader: Self::RhsLoader,
61        unloader: Self::Out,
62        acc: &mut Self::Accumulator,
63        k_range: (u32, u32),
64        #[comptime] config: Self::Config,
65    );
66
67    /// Initialize the loader for Lhs, starting at row `m_offset` and column `k_offset`.
    ///
    /// NOTE(review): `nth_batch` / `batch_offset` presumably select the batch within `lhs`,
    /// and `quantization` presumably carries optional quantization data — the exact semantics
    /// are defined by the implementors; confirm there.
68    fn init_lhs_loader(
69        lhs: VirtualTensor<MP::EI>,
70        m_offset: u32,
71        k_offset: u32,
72        nth_batch: u32,
73        batch_offset: u32,
74        quantization: CubeOption<Quantization<MP>>,
75        #[comptime] config: Self::Config,
76    ) -> Self::LhsLoader;
77
78    /// Initialize the loader for Rhs, starting at row `k_offset` and column `n_offset`.
    ///
    /// Takes the same batch, quantization and config arguments as
    /// [`init_lhs_loader`](Self::init_lhs_loader).
79    fn init_rhs_loader(
80        rhs: VirtualTensor<MP::EI>,
81        k_offset: u32,
82        n_offset: u32,
83        nth_batch: u32,
84        batch_offset: u32,
85        quantization: CubeOption<Quantization<MP>>,
86        #[comptime] config: Self::Config,
87    ) -> Self::RhsLoader;
88
89    /// Initialize the unloader at row `m_offset` and column `n_offset` of the read-write
    /// output tensor `out`.
    ///
    /// Unlike the input loaders, this takes neither a quantization argument nor a config.
90    fn init_unloader(
91        out: VirtualTensor<MP::EO, ReadWrite>,
92        m_offset: u32,
93        n_offset: u32,
94        nth_batch: u32,
95        batch_offset: u32,
96    ) -> Self::Out;
97
98    /// Initialize the accumulator without data.
    ///
    /// NOTE(review): "without data" suggests the content is unspecified until
    /// [`zero_accumulator`](Self::zero_accumulator) is called — confirm.
99    fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
100
101    /// Fill the accumulator with zeros.
102    fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
103}
104
105#[cube]
106/// Input to the global matmul accumulator, responsible for filling the stage and providing a
107/// reader for it.
108pub trait AccumulatorLoader<MP: MatmulPrecision>: CubeType + 'static + Send + Sync {
    /// Fill the stage of this loader, using the comptime global config `config`.
    ///
    /// NOTE(review): presumably must be called before [`load`](Self::load) — confirm against
    /// the implementors.
109    fn fill_stage<G: GlobalConfig>(this: &mut Self, #[comptime] config: G);
110
111    /// Load accumulator for `nth_tile`. Should call either `zero_accumulator` or `fill_accumulator`
112    /// for the underlying tile.
    ///
    /// `acc` is the accumulator of the tile matmul `Tile`, updated in place; `config` is that
    /// tile matmul's comptime config.
113    fn load<Tile: tile::TileMatmul<MP>>(
114        this: &mut Self,
115        acc: &mut Tile::Accumulator,
116        nth_tile: u32,
117        #[comptime] config: Tile::Config,
118    );
119}
120
121#[cube]
122/// Output to the global matmul.
123///
124/// # Note
125///
126/// It is only a wrapper over the stage writer because there is no K for the output.
127/// It could be deleted in favor of having only the StageWriter.
128pub trait OutputLoader<EO: Numeric>: CubeType + 'static + Send + Sync {
    /// The stage writer this output wraps, writing elements of type `EO`.
129    type StageWriter: StageWriter<EO>;
130
    /// Consume the unloader and return the stage writer it wraps.
    ///
    /// `G` is only a type parameter: no config value is passed to this function.
131    fn as_stage_writer<G: GlobalConfig>(unloader: Self) -> Self::StageWriter;
132}
133
/// Validation of a global config for loading.
///
/// NOTE(review): presumably implemented by the loading strategies so that an unsupported
/// configuration is rejected before launch — confirm against the implementors.
134pub trait LoadingValidation {
    /// Check `config` for the input or output identified by `ident`.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidConfigError`] when the check fails.
135    fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError>;
136}
137
138/// Configuration for the [global matmul](GlobalMatmul) level.
///
/// Methods taking an `ident` accept any type convertible into [`Ident`].
139pub trait GlobalConfig: MatmulConfig {
140    /// Underlying Stage matmul config
141    type SmmConfig: stage::StageConfig;
142
143    /// Convert itself to the underlying stage matmul config
144    fn to_smm_config(&self) -> Self::SmmConfig;
145
146    /// Returns the line size for the global memory corresponding to the given ident
147    fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32;
148
149    /// Returns the [TilingDimensions] for the given ident
150    fn tiling_dimensions<I: Into<Ident>>(&self, ident: I) -> TilingDimensions;
151
152    /// Returns the [MatrixLayout] for the given ident
153    fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout;
154
155    /// Returns the number of planes in the cube
156    fn num_planes(&self) -> u32;
157
158    /// Returns the size of the plane dimension
159    fn plane_dim(&self) -> u32;
160
161    /// Whether to check if accessing a row would exceed bounds.
162    fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool;
163
164    /// Whether to check if accessing a col would exceed bounds.
165    fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool;
166
167    /// Whether to check if accessing a col for lhs or row for rhs would exceed bounds.
168    fn check_k_bounds(&self) -> bool;
169
    /// Whether the job should be precomputed.
    ///
    /// NOTE(review): nothing in this file shows what "job" refers to; presumably the loading
    /// job of the loaders — confirm against the implementors and document precisely.
170    fn precompute_job(&self) -> bool;
171}