// cubecl_linalg/matmul/components/tile/accelerated.rs
use crate::matmul::components::config::MatmulConfig;
2use crate::matmul::components::tile::{TileConfig, TileMatmul, TileMatmulFamily};
3use crate::matmul::components::{
4 Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatmulProblem, MatmulSize,
5 MatrixLayout, as_cmma_layout,
6};
7use crate::matmul::kernels::MatmulAvailabilityError;
8use cubecl_core::ir::{Elem, FloatKind};
9use cubecl_core::{self as cubecl, Feature};
10use cubecl_core::{cmma, prelude::*};
11
12use super::{Tile, TileMatmulConfigInput};
13
/// Tile matmul implemented on top of the `cmma` (cooperative matrix multiply-accumulate) API.
///
/// This is a stateless marker type: it carries no data and only serves as the
/// implementor of [`TileMatmulFamily`], [`TileMatmul`] and [`MatmulConfigFactory`].
pub struct Accelerated;
15
16impl TileMatmulFamily for Accelerated {
17 type Matmul<MP: MatmulPrecision> = Accelerated;
18
19 fn tile_shape(config: &Self::Config) -> MatmulSize {
20 config.size
21 }
22
23 fn requires_tensor_cores() -> bool {
24 true
25 }
26}
27
28#[cube]
29impl<MP: MatmulPrecision> TileMatmul<MP> for Accelerated {
30 type Config = Config;
31 type Lhs = cmma::Matrix<MP::ES>;
32 type Rhs = cmma::Matrix<MP::ES>;
33 type Accumulator = cmma::Matrix<MP::EA>;
34
35 fn execute(
36 lhs: &Self::Lhs,
37 rhs: &Self::Rhs,
38 out: &mut Self::Accumulator,
39 #[comptime] _config: Config,
40 ) {
41 cmma::execute::<MP::ES, MP::ES, MP::EA, MP::EA>(lhs, rhs, out, out);
42 }
43
44 fn allocate_lhs(#[comptime] config: Config) -> Self::Lhs {
45 let size = config.size;
46 let layout = config.matrix_layout(Ident::Lhs);
47 unsafe {
48 cmma::Matrix::<MP::ES>::uninitialized(
49 cmma::MatrixIdent::A, size.m,
51 size.n,
52 size.k,
53 as_cmma_layout(layout),
54 )
55 }
56 }
57
58 fn allocate_rhs(#[comptime] config: Config) -> Self::Rhs {
59 let size = config.size;
60 let layout = config.matrix_layout(Ident::Rhs);
61 unsafe {
62 cmma::Matrix::<MP::ES>::uninitialized(
63 cmma::MatrixIdent::B,
64 size.m,
65 size.n,
66 size.k,
67 as_cmma_layout(layout),
68 )
69 }
70 }
71
72 fn fill_lhs(tile: &Tile<MP::ES>, lhs: &mut Self::Lhs, #[comptime] config: Config) {
73 let (slice, stride) = tile.as_unlined::<Config>(Ident::Lhs, config);
74 cmma::load(lhs, &slice, stride);
75 }
76
77 fn fill_rhs(tile: &Tile<MP::ES>, rhs: &mut Self::Rhs, #[comptime] config: Config) {
78 let (slice, stride) = tile.as_unlined::<Config>(Ident::Rhs, config);
79 cmma::load(rhs, &slice, stride);
80 }
81
82 fn fill_accumulator(
83 tile: &Tile<MP::EA>,
84 acc: &mut Self::Accumulator,
85 #[comptime] config: Config,
86 ) {
87 let layout = comptime!(as_cmma_layout(config.matrix_layout(Ident::Out)));
88 let (slice, stride) = tile.as_unlined::<Config>(Ident::Out, config);
89 cmma::load_with_layout(acc, &slice, stride, layout);
90 }
91
92 fn read_accumulator<C: Numeric>(
93 out: &Self::Accumulator,
94 slice: &mut SliceMut<Line<C>>,
95 #[comptime] config: Config,
96 ) {
97 let acc = cmma::cast::<MP::EA, C>(out);
98 cmma::store(slice, &acc, config.size.n, cmma::MatrixLayout::RowMajor);
99 }
100
101 fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator {
102 let size = config.size;
103 unsafe {
104 cmma::Matrix::<MP::EA>::uninitialized(
105 cmma::MatrixIdent::Accumulator,
106 size.m,
107 size.n,
108 size.k,
109 cmma::MatrixLayout::Undefined,
110 )
111 }
112 }
113
114 fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] _config: Self::Config) {
115 cmma::fill(acc, MP::EA::from_int(0));
116 }
117}
118
119impl MatmulConfigFactory for Accelerated {
120 type Input = TileMatmulConfigInput;
121 type Config = Config;
122
123 fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError> {
124 if config.plane_dim != 32 {
125 return Err(Box::new(
126 "Error: Expected plane dimension to be 32, but found {}. Please ensure that cube dimension x is set correctly.",
127 ));
128 }
129 Ok(())
130 }
131
132 fn check_availability<R: Runtime, MP: MatmulPrecision>(
133 client: &ComputeClient<R::Server, R::Channel>,
134 config: &Self::Config,
135 ) -> Result<(), MatmulAvailabilityError> {
136 if config.stage_dynamic_line_size
137 && !client
138 .properties()
139 .feature_enabled(Feature::DynamicLineSize)
140 {
141 return Err(MatmulAvailabilityError::DynamicLineSizeUnavailable);
142 }
143
144 let es = MP::ES::as_elem_native().expect("to be a native type");
145 let ea = MP::EA::as_elem_native().expect("to be a native type");
146
147 let es = match es {
148 Elem::Float(FloatKind::Flex32) => Elem::Float(FloatKind::F32),
149 _ => es,
150 };
151
152 let ea = match ea {
153 Elem::Float(FloatKind::Flex32) => Elem::Float(FloatKind::F32),
154 _ => ea,
155 };
156
157 let size = config.size;
158 if !client.properties().feature_enabled(Feature::Cmma {
159 a: es,
160 b: es,
161 c: ea,
162 m: size.m as u8,
163 k: size.k as u8,
164 n: size.n as u8,
165 }) {
166 return Err(MatmulAvailabilityError::CmmaInstructionUnavailable {
167 input: es,
168 output: ea,
169 shape: Some(MatmulSize {
170 m: size.m,
171 n: size.n,
172 k: size.k,
173 }),
174 });
175 }
176
177 if !(MP::ES::is_supported(client) && MP::EA::is_supported(client)) {
178 return Err(MatmulAvailabilityError::TypesUnavailable {
179 input: es,
180 output: ea,
181 });
182 }
183
184 Ok(())
185 }
186
187 fn make_config(
188 input: Self::Input,
189 problem: &MatmulProblem,
190 cube_dim: &CubeDim,
191 _cube_count: &CubeCount,
192 _quantized: bool,
193 ) -> Self::Config {
194 let (lhs_line_size, rhs_line_size, stage_line_size_update) =
195 if input.vectorization.stage_line_size == 0 {
196 (
197 problem.lhs_line_size as u32,
198 problem.rhs_line_size as u32,
199 false,
200 )
201 } else {
202 (
203 input.vectorization.stage_line_size as u32,
204 input.vectorization.stage_line_size as u32,
205 true,
206 )
207 };
208 Config::new(
209 input.size,
210 cube_dim.x,
211 problem.lhs_layout,
212 problem.rhs_layout,
213 stage_line_size_update,
214 lhs_line_size,
215 rhs_line_size,
216 problem.out_line_size as u32,
217 )
218 }
219}
220
/// Configuration of the [`Accelerated`] tile matmul.
#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct Config {
    /// Tile shape (`m`, `n`, `k`) passed to the cmma fragments.
    size: MatmulSize,
    /// Plane dimension; `check_config` requires it to be 32.
    plane_dim: u32,
    /// Layout of the lhs tile.
    lhs_layout: MatrixLayout,
    /// Layout of the rhs tile.
    rhs_layout: MatrixLayout,
    /// Whether the stage line size overrides the problem's line sizes; when set,
    /// `check_availability` requires `Feature::DynamicLineSize`.
    stage_dynamic_line_size: bool,
    /// Stage line size reported for `Ident::Lhs`.
    lhs_line_size: u32,
    /// Stage line size reported for `Ident::Rhs`.
    rhs_line_size: u32,
    /// Stage line size reported for `Ident::Out`.
    out_line_size: u32,
}
233
234impl TileConfig for Config {
235 fn plane_dim(&self) -> u32 {
236 self.plane_dim
237 }
238
239 fn matrix_layout(&self, ident: Ident) -> MatrixLayout {
240 match ident {
241 Ident::Lhs => self.lhs_layout,
242 Ident::Rhs => self.rhs_layout,
243 Ident::Out => MatrixLayout::RowMajor,
244 }
245 }
246
247 fn stage_line_size(&self, ident: Ident) -> u32 {
248 match ident {
249 Ident::Lhs => self.lhs_line_size,
250 Ident::Rhs => self.rhs_line_size,
251 Ident::Out => self.out_line_size,
252 }
253 }
254
255 fn tile_shape(&self) -> &MatmulSize {
256 &self.size
257 }
258}
259
// Empty impl: `Config` opts into `MatmulConfig` without providing any items here.
impl MatmulConfig for Config {}
261
262impl Config {
263 #[allow(clippy::too_many_arguments)]
264 pub fn new(
265 size: MatmulSize,
266 plane_dim: u32,
267 lhs_layout: MatrixLayout,
268 rhs_layout: MatrixLayout,
269 stage_dynamic_line_size: bool,
270 lhs_line_size: u32,
271 rhs_line_size: u32,
272 out_line_size: u32,
273 ) -> Self {
274 Self {
275 size,
276 plane_dim,
277 lhs_layout,
278 rhs_layout,
279 stage_dynamic_line_size,
280 lhs_line_size,
281 rhs_line_size,
282 out_line_size,
283 }
284 }
285}