1use crate::components::global::args::TensorArgs;
2use crate::components::{
3 AccG, AccS,
4 batch::{BatchMatmulFamily, CubeCountInputArgs},
5};
6use crate::components::{
7 AvailableLineSizes, InputRuntimeArg, LhsG, LhsS, MatmulAvailabilityError, MatmulLineSizes,
8 MatmulPrecision, MatmulProblem, MatmulSelection, MatmulSetupError, MatmulSpec, MatrixLayout,
9 OutputRuntimeArg, RhsG, RhsS,
10};
11use crate::components::{global::args::TensorMapArgs, tile::TileMatmulFamily};
12use crate::kernels::layered::selector::launch_kernel_concrete;
13use crate::{MatmulInputHandle, MatmulInputHandleRef};
14use core::any::TypeId;
15use cubecl_core::{Runtime, client::ComputeClient, frontend::TensorHandleRef};
16use cubecl_core::{prelude::*, try_tensor_line_size_parallel};
17use cubecl_runtime::TypeUsage;
18use cubecl_std::tensor::{MatrixBatchLayout, TensorHandle, matrix_batch_layout};
19
20use super::Algorithm;
21
/// How the kernel selection for a matmul launch is determined.
///
/// Either the caller forces an exact, pre-built [`MatmulSelection`], or the
/// selection is inferred later from algorithm-specific arguments `S`.
#[derive(Debug, Clone)]
pub enum Selection<S> {
    /// Use this exact selection, bypassing inference.
    Forced(MatmulSelection),
    /// Infer the selection from the algorithm-specific arguments.
    Inferred(S),
}
29
30impl<S: Default + Clone> Selection<S> {
31 pub fn maybe_forced_default(s: &Option<MatmulSelection>) -> Self {
32 s.as_ref()
33 .map(|s| Self::Forced(s.clone()))
34 .unwrap_or_default()
35 }
36 pub fn maybe_forced_or(s: &Option<MatmulSelection>, args: &S) -> Self {
37 s.as_ref()
38 .map(|s| Self::Forced(s.clone()))
39 .unwrap_or_else(|| Self::Inferred(args.clone()))
40 }
41}
42
43impl<S: Default> Default for Selection<S> {
44 fn default() -> Self {
45 Self::Inferred(Default::default())
46 }
47}
48
49#[allow(clippy::result_large_err)]
54pub fn launch<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
55 client: &ComputeClient<R::Server>,
56 lhs: MatmulInputHandle<R, LhsG<MP>>,
57 rhs: MatmulInputHandle<R, RhsG<MP>>,
58 out: TensorHandle<R, AccG<MP>>,
59 selection: &Selection<A::SelectionArgs>,
60) -> Result<TensorHandle<R, AccG<MP>>, MatmulSetupError> {
61 let result = launch_ref::<R, MP, A>(
62 client,
63 &lhs.as_ref(),
64 &rhs.as_ref(),
65 &out.as_ref(),
66 selection,
67 );
68
69 match result {
70 Ok(_) => Ok(out),
71 Err(e) => Err(e),
72 }
73}
74
/// Launches a matmul kernel over borrowed tensor handles.
///
/// Inspects the batch/matrix layout of both inputs: mildly permuted inputs are
/// handled by flagging them as transposed, while highly permuted inputs are
/// first copied into contiguous memory. The (possibly rewritten) handles and
/// transposition flags are then forwarded to `launch_inner_ref`.
///
/// # Errors
///
/// Returns a [`MatmulSetupError`] when kernel setup or launch fails.
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
    client: &ComputeClient<R::Server>,
    lhs: &MatmulInputHandleRef<'_, R>,
    rhs: &MatmulInputHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
    selection: &Selection<A::SelectionArgs>,
) -> Result<(), MatmulSetupError> {
    // Maps a tensor's stride pattern to (make_contiguous, transposed):
    // - Contiguous: usable as-is.
    // - MildlyPermuted: usable in place, but the matmul must treat it as
    //   transposed (batch_swap is ignored here).
    // - HighlyPermuted: must be copied into contiguous memory first.
    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_batch_layout(tensor.strides) {
        MatrixBatchLayout::Contiguous => (false, false),
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap: _,
        } => (false, transposed),
        MatrixBatchLayout::HighlyPermuted => (true, false),
    };

    let (lhs_make_contiguous, lhs_transposed) = check_layout(lhs.data());
    let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs.data());

    // Deferred initialization: the owned contiguous copies must be declared
    // in this outer scope so the references taken below outlive the `if`.
    let lhs_owned;
    let rhs_owned;
    let lhs = if lhs_make_contiguous {
        lhs_owned = lhs.into_contiguous::<LhsG<MP>>(client);
        &lhs_owned.as_ref()
    } else {
        lhs
    };
    let rhs = if rhs_make_contiguous {
        rhs_owned = rhs.into_contiguous::<RhsG<MP>>(client);
        &rhs_owned.as_ref()
    } else {
        rhs
    };

    launch_inner_ref::<R, MP, A>(
        client,
        lhs,
        rhs,
        out,
        (lhs_transposed, rhs_transposed),
        selection,
    )
}
123
/// Builds the [`MatmulProblem`] and line sizes for already-normalized inputs,
/// then dispatches to the dtype-fixing launch stage.
///
/// Inputs are assumed to be either contiguous or only mildly permuted, with
/// their transposition state given by `transposed` (lhs, rhs).
///
/// # Errors
///
/// Returns a [`MatmulSetupError`] when the element types are unavailable for
/// conversion on this client, when no valid line size exists, or when the
/// downstream launch fails.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
fn launch_inner_ref<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
    client: &ComputeClient<R::Server>,
    lhs_handle: &MatmulInputHandleRef<'_, R>,
    rhs_handle: &MatmulInputHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
    transposed: (bool, bool),
    selection: &Selection<A::SelectionArgs>,
) -> Result<(), MatmulSetupError> {
    let lhs_shape = lhs_handle.shape();
    let rhs_shape = rhs_handle.shape();

    let rank = lhs_shape.len();
    let lhs_elem = LhsG::<MP>::as_type_native().expect("To be a native type");
    let rhs_elem = RhsG::<MP>::as_type_native().expect("To be a native type");
    let acc_elem = AccG::<MP>::as_type_native().expect("To be a native type");

    // All three element types must support conversion on this client,
    // otherwise the kernel cannot run with this precision combination.
    if !LhsG::<MP>::supported_uses(client).contains(TypeUsage::Conversion)
        || !RhsG::<MP>::supported_uses(client).contains(TypeUsage::Conversion)
        || !AccG::<MP>::supported_uses(client).contains(TypeUsage::Conversion)
    {
        return Err(MatmulSetupError::Unavailable(
            MatmulAvailabilityError::TypesUnavailable {
                lhs: lhs_elem,
                rhs: rhs_elem,
                output: acc_elem,
            },
        ));
    }

    // Problem dimensions from the trailing two axes: lhs is (m, k),
    // rhs contributes n from its last axis.
    let m = lhs_shape[rank - 2] as u32;
    let k = lhs_shape[rank - 1] as u32;
    let n = rhs_shape[rank - 1] as u32;

    let lhs_layout = match transposed.0 {
        true => MatrixLayout::ColMajor,
        false => MatrixLayout::RowMajor,
    };

    let rhs_layout = match transposed.1 {
        true => MatrixLayout::ColMajor,
        false => MatrixLayout::RowMajor,
    };

    let problem = MatmulProblem {
        m: m as usize,
        n: n as usize,
        k: k as usize,
        // Leading (non-matrix) axes are the batch dimensions.
        lhs_batches: lhs_shape[..lhs_shape.len() - 2].to_vec(),
        rhs_batches: rhs_shape[..rhs_shape.len() - 2].to_vec(),
        out_batches: out.shape[..out.shape.len() - 2].to_vec(),
        lhs_layout,
        rhs_layout,
    };

    let lhs = lhs_handle.data();
    let rhs = rhs_handle.data();

    // Start from all line sizes the runtime supports for these element sizes,
    // let the algorithm veto some, then keep only those compatible with each
    // tensor's shape/strides and pick the largest remaining.
    let line_sizes =
        AvailableLineSizes::from_type_sizes::<R>(lhs.elem_size, rhs.elem_size, out.elem_size);
    let line_sizes = A::filter_line_sizes(line_sizes);
    let mut line_sizes = line_sizes
        .filter_lhs_with_tensor(lhs.strides, lhs.shape, problem.lhs_layout)
        .filter_rhs_with_tensor(rhs.strides, rhs.shape, problem.rhs_layout)
        .filter_out_with_tensor(out.strides, out.shape)
        .pick_max()?;

    // Inputs carrying a scale (presumably quantized — confirm with
    // MatmulInputHandleRef) are restricted to a line size of 1.
    if lhs_handle.scale().is_some() {
        line_sizes.lhs = 1;
    }
    if rhs_handle.scale().is_some() {
        line_sizes.rhs = 1;
    }

    // A reported plane dim of 0 means the runtime doesn't know; fall back
    // to 32 as a sensible default.
    let fix_plane_dim = |plane_dim: u32| {
        if plane_dim == 0 { 32 } else { plane_dim }
    };

    let plane_dim = fix_plane_dim(A::select_plane_dim::<R>(client));

    launch_inner_ref_fix_dtype::<R, MP, A>(
        client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
    )
}
215
/// Dispatches the concrete kernel launch, rewriting `f32` global inputs to
/// `tf32` stage precision when the tile matmul requires an accelerator and
/// the client supports `tf32` conversion.
///
/// The four-way `match` covers every combination of (lhs is f32, rhs is f32);
/// the `TypeId` comparisons are resolved per monomorphization, so only the
/// matching arm's kernel spec is actually instantiated at runtime.
///
/// # Errors
///
/// Propagates any [`MatmulSetupError`] from `launch_kernel_concrete`.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
fn launch_inner_ref_fix_dtype<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
    client: &ComputeClient<R::Server>,
    lhs: &MatmulInputHandleRef<'_, R>,
    rhs: &MatmulInputHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
    problem: MatmulProblem,
    line_sizes: MatmulLineSizes,
    plane_dim: u32,
    selection: &Selection<A::SelectionArgs>,
) -> Result<(), MatmulSetupError> {
    if <A::TileMatmul as TileMatmulFamily>::requires_accelerator()
        && tf32::supported_uses(client).contains(TypeUsage::Conversion)
    {
        match (
            TypeId::of::<LhsG<MP>>() == TypeId::of::<f32>(),
            TypeId::of::<RhsG<MP>>() == TypeId::of::<f32>(),
        ) {
            // Both inputs f32: stage both sides as tf32.
            (true, true) => launch_kernel_concrete::<
                ((f32, f32, AccG<MP>, tf32, tf32, AccS<MP>), TensorArgs),
                R,
                A,
            >(
                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
            ),
            // Only lhs is f32: stage lhs as tf32, keep rhs precision.
            (true, false) => launch_kernel_concrete::<
                (
                    (f32, RhsG<MP>, AccG<MP>, tf32, RhsS<MP>, AccS<MP>),
                    TensorArgs,
                ),
                R,
                A,
            >(
                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
            ),
            // Only rhs is f32: stage rhs as tf32, keep lhs precision.
            (false, true) => launch_kernel_concrete::<
                (
                    (LhsG<MP>, f32, AccG<MP>, LhsS<MP>, tf32, AccS<MP>),
                    TensorArgs,
                ),
                R,
                A,
            >(
                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
            ),
            // Neither input is f32: launch with the precision spec as given.
            (false, false) => launch_kernel_concrete::<(MP, TensorArgs), R, A>(
                client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
            ),
        }
    } else {
        // No accelerator requirement or no tf32 support: no rewrite needed.
        launch_kernel_concrete::<(MP, TensorArgs), R, A>(
            client, lhs, rhs, out, problem, line_sizes, plane_dim, selection,
        )
    }
}
271
/// Launches a CMMA matmul using TMA-style tensor-map arguments, without the
/// layout normalization performed by [`launch_ref`] (hence `no_check`).
///
/// Input transposition is taken from the `transposed` flags rather than being
/// inferred from strides. Input line sizes are pinned to 1; only the output
/// line size is derived from the runtime's supported sizes. All leading batch
/// dimensions are flattened into a single batch axis per tensor.
///
/// # Errors
///
/// Returns a [`MatmulSetupError`] when no valid output line size exists, when
/// the hardware plane size is neither 32 nor 64, or when the concrete launch
/// fails.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn matmul_cmma_tma_ref_no_check<R: Runtime, MP: MatmulPrecision, A: Algorithm>(
    client: &ComputeClient<R::Server>,
    lhs_handle: &MatmulInputHandleRef<'_, R>,
    rhs_handle: &MatmulInputHandleRef<'_, R>,
    out: &TensorHandleRef<'_, R>,
    transposed: (bool, bool),
    selection: &Selection<A::SelectionArgs>,
) -> Result<(), MatmulSetupError> {
    let lhs = lhs_handle.data();
    let rhs = rhs_handle.data();

    let rank = lhs.strides.len();

    // Problem dimensions from the trailing two axes: lhs is (m, k),
    // rhs contributes n from its last axis.
    let m = lhs.shape[rank - 2] as u32;
    let k = lhs.shape[rank - 1] as u32;
    let n = rhs.shape[rank - 1] as u32;

    let lhs_layout = match transposed.0 {
        true => MatrixLayout::ColMajor,
        false => MatrixLayout::RowMajor,
    };
    let rhs_layout = match transposed.1 {
        true => MatrixLayout::ColMajor,
        false => MatrixLayout::RowMajor,
    };

    // Inputs use line size 1; the output line size is the largest supported
    // size compatible with the output's innermost (contiguous) dimension.
    let line_sizes = MatmulLineSizes {
        lhs: 1,
        rhs: 1,
        out: try_tensor_line_size_parallel(
            R::io_optimized_line_sizes_unchecked(out.elem_size),
            out.shape,
            out.strides,
            rank - 1,
        )?,
    };

    // Collapse all leading batch axes into one flattened batch per tensor.
    let batch_lhs: usize = lhs.shape[..lhs.shape.len() - 2].iter().product();
    let batch_rhs: usize = rhs.shape[..rhs.shape.len() - 2].iter().product();
    let batch_out: usize = out.shape[..out.shape.len() - 2].iter().product();

    let problem = MatmulProblem {
        m: m as usize,
        n: n as usize,
        k: k as usize,
        lhs_batches: [batch_lhs].to_vec(),
        rhs_batches: [batch_rhs].to_vec(),
        out_batches: [batch_out].to_vec(),
        lhs_layout,
        rhs_layout,
    };

    let plane_size = client.properties().hardware.plane_size_max;

    // Only plane sizes of 32 or 64 are supported by this path.
    let plane_dim = match plane_size {
        32 | 64 => plane_size,
        _ => {
            return Err(MatmulSetupError::Unavailable(
                MatmulAvailabilityError::PlaneDimUnsupported {
                    plane_dim: plane_size,
                },
            ));
        }
    };

    // Same f32 -> tf32 staging rewrite as launch_inner_ref_fix_dtype, but
    // with TensorMapArgs instead of TensorArgs.
    if tf32::supported_uses(client).contains(TypeUsage::Conversion) {
        match (
            TypeId::of::<LhsG<MP>>() == TypeId::of::<f32>(),
            TypeId::of::<RhsG<MP>>() == TypeId::of::<f32>(),
        ) {
            // Both inputs f32: stage both sides as tf32.
            (true, true) => launch_kernel_concrete::<
                ((f32, f32, AccG<MP>, tf32, tf32, AccS<MP>), TensorMapArgs),
                R,
                A,
            >(
                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
            ),
            // Only lhs is f32.
            (true, false) => launch_kernel_concrete::<
                (
                    (f32, RhsG<MP>, AccG<MP>, tf32, RhsS<MP>, AccS<MP>),
                    TensorMapArgs,
                ),
                R,
                A,
            >(
                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
            ),
            // Only rhs is f32.
            (false, true) => launch_kernel_concrete::<
                (
                    (LhsG<MP>, f32, AccG<MP>, LhsS<MP>, tf32, AccS<MP>),
                    TensorMapArgs,
                ),
                R,
                A,
            >(
                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
            ),
            // Neither input is f32: use the precision spec as given.
            (false, false) => launch_kernel_concrete::<(MP, TensorMapArgs), R, A>(
                client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
            ),
        }
    } else {
        launch_kernel_concrete::<(MP, TensorMapArgs), R, A>(
            client, lhs_handle, rhs_handle, out, problem, line_sizes, plane_dim, selection,
        )
    }
}
380
/// Launches the batch matmul kernel with a fully prepared configuration.
///
/// This is the final launch stage: cube dimensions, cube count, runtime
/// arguments, and the batch config have all been computed upstream and are
/// forwarded directly to the unchecked launch.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` return type keeps the
/// signature consistent with the other launch entry points.
#[allow(clippy::too_many_arguments, clippy::result_large_err)]
pub fn launch_with_config<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>(
    client: &ComputeClient<R::Server>,
    cube_dim: CubeDim,
    cube_count: CubeCount,
    input: InputRuntimeArg<'a, MS, R>,
    output: OutputRuntimeArg<'a, MS, R>,
    cube_count_input: CubeCountInputArgs<'a, R>,
    config: <A::BatchMatmul as BatchMatmulFamily>::Config,
) -> Result<(), MatmulSetupError> {
    // SAFETY: `launch_unchecked` skips validation; soundness relies on the
    // caller passing a `config`, `cube_dim`, and `cube_count` that are
    // mutually consistent with the runtime args — presumably guaranteed by
    // the upstream setup/selection stages (confirm against
    // BatchMatmulFamily::launch_unchecked's documented contract).
    unsafe {
        A::BatchMatmul::launch_unchecked::<MS, R>(
            client,
            cube_dim,
            cube_count,
            input,
            output,
            cube_count_input,
            config,
        );
    };

    Ok(())
}