cubecl_linalg/matmul/components/global/
base.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::matmul::components::{
5 Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatrixLayout,
6 TilingDimensions,
7 config::MatmulConfig,
8 stage::{self, StageWriter},
9 tile,
10};
11use cubecl_std::{
12 CubeOption,
13 tensor::r#virtual::{ReadWrite, VirtualTensor},
14};
15
16use super::Quantization;
17
18pub trait GlobalMatmulFamily:
20 MatmulConfigFactory<Config: GlobalConfig> + Send + Sync + 'static
21{
22 type Matmul<MP: MatmulPrecision>: GlobalMatmul<MP, Config = Self::Config>;
23}
24
25#[cube]
26pub trait GlobalMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
45 type Config: GlobalConfig;
46 type LhsLoader: CubeType;
47 type RhsLoader: CubeType;
48 type AccumulatorLoader: CubeType;
49 type Out: OutputLoader<MP::EO>;
50 type Accumulator: CubeType;
51
52 fn execute(
59 lhs_loader: Self::LhsLoader,
60 rhs_loader: Self::RhsLoader,
61 unloader: Self::Out,
62 acc: &mut Self::Accumulator,
63 k_range: (u32, u32),
64 #[comptime] config: Self::Config,
65 );
66
67 fn init_lhs_loader(
69 lhs: VirtualTensor<MP::EI>,
70 m_offset: u32,
71 k_offset: u32,
72 nth_batch: u32,
73 batch_offset: u32,
74 quantization: CubeOption<Quantization<MP>>,
75 #[comptime] config: Self::Config,
76 ) -> Self::LhsLoader;
77
78 fn init_rhs_loader(
80 rhs: VirtualTensor<MP::EI>,
81 k_offset: u32,
82 n_offset: u32,
83 nth_batch: u32,
84 batch_offset: u32,
85 quantization: CubeOption<Quantization<MP>>,
86 #[comptime] config: Self::Config,
87 ) -> Self::RhsLoader;
88
89 fn init_unloader(
91 out: VirtualTensor<MP::EO, ReadWrite>,
92 m_offset: u32,
93 n_offset: u32,
94 nth_batch: u32,
95 batch_offset: u32,
96 ) -> Self::Out;
97
98 fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
100
101 fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
103}
104
105#[cube]
106pub trait AccumulatorLoader<MP: MatmulPrecision>: CubeType + 'static + Send + Sync {
109 fn fill_stage<G: GlobalConfig>(this: &mut Self, #[comptime] config: G);
110
111 fn load<Tile: tile::TileMatmul<MP>>(
114 this: &mut Self,
115 acc: &mut Tile::Accumulator,
116 nth_tile: u32,
117 #[comptime] config: Tile::Config,
118 );
119}
120
121#[cube]
122pub trait OutputLoader<EO: Numeric>: CubeType + 'static + Send + Sync {
129 type StageWriter: StageWriter<EO>;
130
131 fn as_stage_writer<G: GlobalConfig>(unloader: Self) -> Self::StageWriter;
132}
133
134pub trait LoadingValidation {
135 fn check<C: GlobalConfig>(config: &C, ident: Ident) -> Result<(), InvalidConfigError>;
136}
137
138pub trait GlobalConfig: MatmulConfig {
140 type SmmConfig: stage::StageConfig;
142
143 fn to_smm_config(&self) -> Self::SmmConfig;
145
146 fn global_line_size<I: Into<Ident>>(&self, ident: I) -> u32;
148
149 fn tiling_dimensions<I: Into<Ident>>(&self, ident: I) -> TilingDimensions;
151
152 fn matrix_layout<I: Into<Ident>>(&self, ident: I) -> MatrixLayout;
154
155 fn num_planes(&self) -> u32;
157
158 fn plane_dim(&self) -> u32;
160
161 fn check_row_bounds<I: Into<Ident>>(&self, ident: I) -> bool;
163
164 fn check_col_bounds<I: Into<Ident>>(&self, ident: I) -> bool;
166
167 fn check_k_bounds(&self) -> bool;
169
170 fn precompute_job(&self) -> bool;
171}