1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded};
3use cubecl_std::{
4 CubeOption, CubeOptionArgs, CubeOptionExpand,
5 tensor::{
6 View,
7 launch::ViewArg,
8 layout::{Coords1d, Coords3d, VirtualLayout, VirtualLayoutLaunch},
9 },
10};
11
12use crate::{
13 MatmulInputHandleRef,
14 components::{
15 self, MatmulElems, MatmulIdent, MatmulLineSizes, MatmulProblem, MatmulSelection,
16 batch::BatchConfig,
17 global::{
18 GlobalConfig,
19 memory::{
20 BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutLaunch,
21 GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
22 SimpleTmaGlobalLayoutLaunch,
23 },
24 },
25 stage::SwizzleMode,
26 },
27};
28
/// Factory building the concrete runtime launch arguments for the matmul
/// inputs (Lhs, Rhs) from raw tensor handles and the resolved matmul config.
pub trait ConcreteInputsFactory: LaunchArg {
    /// Create the runtime argument for this input representation.
    ///
    /// `line_sizes` gives the vectorization width per operand; `config`
    /// supplies the batch-level configuration the layouts are derived from.
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
44
/// Factory building the concrete runtime launch argument for the matmul
/// output from a raw tensor handle and the resolved matmul config.
pub trait ConcreteOutputFactory: LaunchArg {
    /// Create the runtime argument for this output representation.
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
59
60#[cube]
61pub trait MatmulArgs: Send + Sync + 'static + Clone {
63 type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;
65
66 type Output<EO: Numeric>: LaunchArg + LaunchArg + CubeType;
68
69 type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;
72
73 fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
75 input: &Self::Input<Lhs, Rhs, EO>,
76 output: &mut Self::Output<EO>,
77 #[comptime] config: G,
78 ) -> Self::State<Lhs, Rhs, EO>;
79
80 fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
81 _state: &Self::State<Lhs, Rhs, EO>,
82 ) -> View<Line<Lhs>, Coords3d> {
83 unexpanded!()
84 }
85 fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
86 _state: &Self::State<Lhs, Rhs, EO>,
87 _batch: u32,
88 ) -> u32 {
89 unexpanded!()
90 }
91 fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
92 _state: &Self::State<Lhs, Rhs, EO>,
93 ) -> View<Line<Rhs>, Coords3d> {
94 unexpanded!()
95 }
96 fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
97 _state: &Self::State<Lhs, Rhs, EO>,
98 _batch: u32,
99 ) -> u32 {
100 unexpanded!()
101 }
102 fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
103 _state: &Self::State<Lhs, Rhs, EO>,
104 ) -> CubeOption<View<Line<EO>, Coords3d>> {
105 unexpanded!()
106 }
107 fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
108 _state: &Self::State<Lhs, Rhs, EO>,
109 _batch: u32,
110 ) -> u32 {
111 unexpanded!()
112 }
113 fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
114 _state: &mut Self::State<Lhs, Rhs, EO>,
115 ) -> View<Line<EO>, Coords3d, ReadWrite> {
116 unexpanded!()
117 }
118 fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
119 _state: &Self::State<Lhs, Rhs, EO>,
120 _batch: u32,
121 ) -> u32 {
122 unexpanded!()
123 }
124}
125
#[derive(Clone, Copy)]
/// Identifies which matmul input a tensor handle refers to.
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}
132
#[derive(Clone)]
/// Marker type selecting the plain (non-TMA) tensor argument strategy; see
/// its `MatmulArgs` impl below.
pub struct TensorArgs;
138
#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Launchable inputs for the plain tensor strategy: one 3D view plus a batch
/// index remapping per operand, with an optional accumulator pair.
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    // 3D view over Lhs — presumably (batch, row, col), matching the TMA path; confirm.
    lhs: View<Line<Lhs>, Coords3d>,
    // Maps a flat batch index to the Lhs source batch position.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    // 3D view over Rhs.
    rhs: View<Line<Rhs>, Coords3d>,
    // Maps a flat batch index to the Rhs source batch position.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    // Optional accumulator view and its batch remapping.
    acc: CubeOption<View<Line<Acc>, Coords3d>>,
    acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}
152
impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
    for TensorInputs<Lhs, Rhs, Acc>
{
    /// Build launch arguments for plain tensor inputs: each operand becomes a
    /// global-memory view, with a separate scales view for quantized handles.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
        // Build the 3D view for one operand, using the global memory config of
        // the given ident (Lhs/Rhs) to describe its layout.
        let view = |handle: &'a MatmulInputHandleRef<'a, R>, ident, line_size| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(
                    handle,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                ViewArg::new::<GlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                // Quantized handles carry data + scales; both layouts come
                // from one call so they stay consistent with the scheme.
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    client,
                    data,
                    scale,
                    shape,
                    problem,
                    **scheme,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                let data_view =
                    ViewArg::new::<GlobalLayout>(data.as_array_arg(line_size), data_layout);
                // Scales are always read with line size 1.
                let scales_view =
                    ViewArg::new::<GlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };
        // Batch remapping: normal handles derive it from the handle/problem;
        // quantized handles use a no-op layout (identity mapping here).
        let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(client, handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            MatmulInputHandleRef::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        TensorInputsLaunch::new(
            view(lhs, MatmulIdent::Lhs, line_sizes.lhs),
            batch_layout(lhs),
            view(rhs, MatmulIdent::Rhs, line_sizes.rhs),
            batch_layout(rhs),
            // No accumulator input in this factory.
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}
220
#[derive(CubeType, CubeLaunch, Clone, Copy)]
/// Launchable output: a read-write 3D view plus its batch index remapping.
pub struct TensorOutput<EG: Numeric> {
    // Read-write 3D view over the output tensor.
    view: View<Line<EG>, Coords3d, ReadWrite>,
    // Maps a flat batch index to the output source batch position.
    batch: VirtualLayout<Coords1d, Coords1d>,
}
226
227impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
228 fn create<'a, R: Runtime>(
229 client: &ComputeClient<R::Server>,
230 out: &'a TensorHandleRef<'a, R>,
231 _selection: &MatmulSelection,
232 problem: &MatmulProblem,
233 line_sizes: &MatmulLineSizes,
234 config: impl BatchConfig,
235 _dtypes: &MatmulElems,
236 ) -> Self::RuntimeArg<'a, R> {
237 let config = config.global_config();
238 let layout = GlobalLayoutLaunch::from_handle(
239 out,
240 line_sizes.out,
241 config.global_memory_config(MatmulIdent::Out).into(),
242 );
243 let batch = BatchLayoutLaunch::from_handle(client, out, problem);
244 let view = ViewArg::new::<GlobalLayout>(out.as_array_arg(line_sizes.out), layout);
245 TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
246 }
247}
248
#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    // State is simply the (inputs, output) pair, copied by value.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    /// State is a plain copy of the launch arguments; the config is unused.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    /// Remap the flat batch index through the Lhs batch layout.
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    /// Remap the flat batch index through the Rhs batch layout.
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    /// Remap through the accumulator batch layout when present; identity otherwise.
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    /// Remap the flat batch index through the output batch layout.
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}
319
#[derive(Clone)]
/// Marker type selecting the TMA (tensor-map) argument strategy; see its
/// `MatmulArgs` impl below.
pub struct TensorMapArgs;
325
#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Launchable inputs for the TMA strategy. No per-operand batch layouts for
/// Lhs/Rhs here — the TMA path uses the batch index directly (see
/// `batch_lhs`/`batch_rhs` in the impl below).
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    // 3D (batch, row, col) view over Lhs, backed by a tensor map.
    pub lhs: View<Line<Lhs>, Coords3d>,
    // 3D view over Rhs, backed by a tensor map.
    pub rhs: View<Line<Rhs>, Coords3d>,
    // Optional accumulator view and its batch remapping.
    pub acc: CubeOption<View<Line<EO>, Coords3d>>,
    pub acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}
338
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Build TMA (tensor-map) launch arguments for both inputs.
    ///
    /// Shapes and strides are normalized to rank 3 (batch, rows, cols);
    /// col-major operands are expressed by swapping the last two dims/strides
    /// and flagging the resulting view as transposed.
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R::Server>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let config = config.global_config();

        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let stage_k = tiling_scheme.elements_in_stage_k();

        // Tensor-map box (tile) size for Lhs. Without swizzling the box spans
        // one tile along the inner (contiguous) dim; with swizzling it spans
        // the full stage. Leading 1 is the batch dim of the rank-3 map.
        let stage_size_lhs = match config.swizzle_mode(MatmulIdent::Lhs) {
            SwizzleMode::None => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, tiling_scheme.elements_in_tile_k()]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, tiling_scheme.elements_in_tile_m()]
                }
            },
            _ => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, stage_k]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, stage_m]
                }
            },
        };
        // Same construction for Rhs.
        let stage_size_rhs = match config.swizzle_mode(MatmulIdent::Rhs) {
            SwizzleMode::None => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, tiling_scheme.elements_in_tile_n()]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, tiling_scheme.elements_in_tile_k()]
                }
            },
            _ => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, stage_n]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, stage_k]
                }
            },
        };

        // Collapse all batch dims into one: shape = [prod(batches), rows, cols].
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        // Keep the last three strides for rank >= 3 (assumes the collapsed
        // batch dims are contiguous — TODO confirm against callers); rank-2
        // tensors keep both strides and get a batch stride inserted below.
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        // Ranks of the normalized stride vectors (2 or 3 at this point).
        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // Col-major operands: swap rows/cols (shape is always length 3;
        // strides are swapped relative to their own length) and mark the
        // view as transposed.
        if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }

        // Rank-2 tensors: prepend a batch stride so the map is rank 3. The
        // duplicated leading stride makes the single batch step span the
        // whole matrix.
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        // Translate the component-level swizzle mode into the runtime's
        // tensor-map swizzle enum.
        fn swizzle(mode: SwizzleMode) -> TensorMapSwizzle {
            match mode {
                SwizzleMode::None => TensorMapSwizzle::None,
                SwizzleMode::B32 => TensorMapSwizzle::B32,
                SwizzleMode::B64 => TensorMapSwizzle::B64,
                SwizzleMode::B128 => TensorMapSwizzle::B128,
            }
        }

        let swizzle_lhs = swizzle(config.swizzle_mode(MatmulIdent::Lhs));
        let swizzle_rhs = swizzle(config.swizzle_mode(MatmulIdent::Rhs));

        // f32 stage types are encoded as tf32 in the tensor map — presumably
        // because the tiled TMA format has no plain-f32 entry; confirm
        // against the runtime's tensor-map type table.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.rhs_stage
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_lhs,
            },
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_lhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_rhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
        };

        // Wrap a tensor-map buffer in a simple TMA global layout. The layout
        // receives the logical (batches, rows, cols) — un-swapped when the
        // operand was marked transposed above.
        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0] as u32);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            // No accumulator input in this factory.
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}
544
#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    // State is simply the (inputs, output) pair, copied by value.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    /// State is a plain copy of the launch arguments; the config is unused.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    /// TMA inputs use the flat batch index directly (no remapping).
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    /// TMA inputs use the flat batch index directly (no remapping).
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    /// Remap through the accumulator batch layout when present; identity otherwise.
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    /// Remap the flat batch index through the output batch layout.
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}