1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded};
3use cubecl_std::{
4 CubeOption, CubeOptionArgs, CubeOptionExpand,
5 tensor::{
6 View,
7 launch::ViewArg,
8 layout::{Coords1d, Coords3d, VirtualLayout, VirtualLayoutLaunch},
9 },
10};
11
12use crate::{
13 MatmulInputHandleRef,
14 components::{
15 self, MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection,
16 batch::BatchConfig,
17 global::{
18 GlobalConfig,
19 memory::{
20 BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig,
21 GlobalLayoutLaunch, GlobalScaleLayout, NoopLayout, NoopLayoutLaunch,
22 SimpleTmaGlobalLayout, SimpleTmaGlobalLayoutLaunch,
23 },
24 },
25 stage::SwizzleMode,
26 },
27};
28
/// Factory building the concrete kernel-launch arguments for a matmul's
/// inputs from raw host-side handles.
pub trait ConcreteInputsFactory: LaunchArg {
    /// Builds the runtime launch argument from the `lhs`/`rhs` input handles
    /// and the resolved matmul problem/selection/config.
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
44
/// Factory building the concrete kernel-launch argument for a matmul's
/// output from a raw host-side tensor handle.
pub trait ConcreteOutputFactory: LaunchArg {
    /// Builds the runtime launch argument from the `out` tensor handle and
    /// the resolved matmul problem/selection/config.
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
59
60#[cube]
61pub trait MatmulArgs: Send + Sync + 'static + Clone {
63 type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;
65
66 type Output<EO: Numeric>: LaunchArg + LaunchArg + CubeType;
68
69 type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;
72
73 fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
75 input: &Self::Input<Lhs, Rhs, EO>,
76 output: &mut Self::Output<EO>,
77 #[comptime] config: G,
78 ) -> Self::State<Lhs, Rhs, EO>;
79
80 fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
81 _state: &Self::State<Lhs, Rhs, EO>,
82 ) -> View<Line<Lhs>, Coords3d> {
83 unexpanded!()
84 }
85 fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
86 _state: &Self::State<Lhs, Rhs, EO>,
87 _batch: u32,
88 ) -> u32 {
89 unexpanded!()
90 }
91 fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
92 _state: &Self::State<Lhs, Rhs, EO>,
93 ) -> View<Line<Rhs>, Coords3d> {
94 unexpanded!()
95 }
96 fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
97 _state: &Self::State<Lhs, Rhs, EO>,
98 _batch: u32,
99 ) -> u32 {
100 unexpanded!()
101 }
102 fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
103 _state: &Self::State<Lhs, Rhs, EO>,
104 ) -> CubeOption<View<Line<EO>, Coords3d>> {
105 unexpanded!()
106 }
107 fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
108 _state: &Self::State<Lhs, Rhs, EO>,
109 _batch: u32,
110 ) -> u32 {
111 unexpanded!()
112 }
113 fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
114 _state: &mut Self::State<Lhs, Rhs, EO>,
115 ) -> View<Line<EO>, Coords3d, ReadWrite> {
116 unexpanded!()
117 }
118 fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
119 _state: &Self::State<Lhs, Rhs, EO>,
120 _batch: u32,
121 ) -> u32 {
122 unexpanded!()
123 }
124}
125
/// Identifies which matmul input operand a tensor corresponds to.
#[derive(Clone, Copy)]
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}
132
/// [`MatmulArgs`] implementation backed by plain tensor views built on the
/// host (see the `impl MatmulArgs for TensorArgs` below).
#[derive(Clone)]
pub struct TensorArgs;
138
/// Launchable inputs for [`TensorArgs`]: one 3D view per operand plus a
/// batch-index layout per operand.
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    /// 3D view over the lhs tensor.
    lhs: View<Line<Lhs>, Coords3d>,
    /// Maps a logical batch index to the lhs source batch position.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// 3D view over the rhs tensor.
    rhs: View<Line<Rhs>, Coords3d>,
    /// Maps a logical batch index to the rhs source batch position.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// Optional 3D view over an accumulator tensor.
    acc: CubeOption<View<Line<Acc>, Coords3d>>,
    /// Optional batch-index layout for the accumulator; identity when `None`
    /// (see `batch_acc` in the impl below).
    acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}
152
impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
    for TensorInputs<Lhs, Rhs, Acc>
{
    /// Builds the launch arguments for plain (non-tensor-map) matmul inputs.
    ///
    /// Each input handle becomes a view: normal handles get a `GlobalLayout`;
    /// quantized handles get a data view plus a scales view combined into a
    /// quantized view. The accumulator is always launched as `None` here.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        // Builds the data view for one input handle.
        let view = |handle: &'a MatmulInputHandleRef<'a, R>,
                    config: GlobalLayoutConfig,
                    line_size| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(handle, line_size, config);
                ViewArg::new::<GlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                // Quantized inputs carry separate data and scale buffers,
                // each with its own layout; scales use line size 1.
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    client, data, scale, shape, problem, **scheme, line_size, config,
                );
                let data_view =
                    ViewArg::new::<GlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<GlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };
        // Builds the batch-index layout for one input handle. Quantized
        // handles use a no-op layout; normal handles derive it from the
        // handle and problem.
        let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(client, handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            MatmulInputHandleRef::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        let config = config.global_config();
        TensorInputsLaunch::new(
            view(
                lhs,
                config.lhs_reader_config().gmem_config.into(),
                line_sizes.lhs,
            ),
            batch_layout(lhs),
            view(
                rhs,
                config.rhs_reader_config().gmem_config.into(),
                line_sizes.rhs,
            ),
            batch_layout(rhs),
            // No accumulator input and no accumulator batch layout.
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}
219
/// Launchable matmul output: a writable 3D view plus a batch-index layout.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct TensorOutput<EG: Numeric> {
    /// Writable 3D view over the output tensor.
    view: View<Line<EG>, Coords3d, ReadWrite>,
    /// Maps a logical batch index to the output source batch position.
    batch: VirtualLayout<Coords1d, Coords1d>,
}
225
226impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
227 fn create<'a, R: Runtime>(
228 client: &ComputeClient<R::Server>,
229 out: &'a TensorHandleRef<'a, R>,
230 _selection: &MatmulSelection,
231 problem: &MatmulProblem,
232 line_sizes: &MatmulLineSizes,
233 config: impl BatchConfig,
234 _dtypes: &MatmulElems,
235 ) -> Self::RuntimeArg<'a, R> {
236 let config = config.global_config();
237 let layout = GlobalLayoutLaunch::from_handle(
238 out,
239 line_sizes.out,
240 config.writer_config().gmem_config.into(),
241 );
242 let batch = BatchLayoutLaunch::from_handle(client, out, problem);
243 let view = ViewArg::new::<GlobalLayout>(out.as_array_arg(line_sizes.out), layout);
244 TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
245 }
246}
247
#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    // State is simply the (inputs, output) pair; all views were built host-side.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    // Copies the launched arguments into the state; no device-side setup needed.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    // Remaps a batch index through the lhs batch layout.
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    // Remaps a batch index through the rhs batch layout.
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    // Accumulator batch remapping is optional; identity when absent.
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    // Remaps a batch index through the output batch layout.
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}
318
/// [`MatmulArgs`] implementation backed by tensor maps (see the
/// `TensorMapMeta`-based impl below).
#[derive(Clone)]
pub struct TensorMapArgs;
324
/// Launchable inputs for [`TensorMapArgs`]: tensor-map-backed views for
/// lhs/rhs and an optional accumulator view.
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    /// Tensor-map-backed 3D view over the lhs.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// Tensor-map-backed 3D view over the rhs.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// Optional 3D view over an accumulator tensor.
    pub acc: CubeOption<View<Line<EO>, Coords3d>>,
    /// Optional batch-index layout for the accumulator; identity when `None`.
    pub acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}
337
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds tensor-map launch arguments for the lhs/rhs inputs.
    ///
    /// Constructs the tensor-map metadata (box size, 3D shape/strides,
    /// swizzle, storage type) for each operand; the accumulator is always
    /// launched as `None` here.
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R::Server>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let config = config.global_config();

        // Stage extents in elements, from the selected tiling scheme.
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // Tensor-map box (tile) size for lhs, as [batch, outer, inner].
        // Without swizzling the box spans one tile along the inner dimension;
        // with swizzling it spans the full stage. NOTE(review): presumably
        // driven by swizzle alignment requirements — confirm against the
        // tensor-map backend.
        let stage_size_lhs = match config.lhs_reader_config().smem_config.swizzle {
            SwizzleMode::None => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, tiling_scheme.tile_size.k]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.m]
                }
            },
            _ => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, stage_k]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, stage_m]
                }
            },
        };
        // Same box-size rule for rhs.
        let stage_size_rhs = match config.rhs_reader_config().smem_config.swizzle {
            SwizzleMode::None => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.n]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, tiling_scheme.tile_size.k]
                }
            },
            _ => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, stage_n]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, stage_k]
                }
            },
        };

        // Collapse all batch dims into one: shape becomes
        // [prod(batches), penultimate dim, last dim].
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        // Keep at most the trailing 3 strides; rank-2 tensors get a batch
        // stride synthesized further below.
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![lhs.strides[0], lhs.strides[1]]
        };

        // Same shape/stride handling for rhs.
        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        // Re-bind the ranks to the (possibly truncated) stride vectors;
        // the original tensor ranks are no longer needed from here on.
        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // Present col-major operands in row-major order by swapping the two
        // inner dims/strides; record the transposition for the layout below.
        if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }

        // Rank-2 inputs: prepend a batch stride so strides are always 3D.
        // NOTE(review): reuses the outer matrix stride as the batch stride —
        // assumes a single (or broadcast) batch; confirm against callers.
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        // Translate the shared-memory swizzle mode into the tensor-map
        // swizzle enum (one-to-one mapping).
        fn swizzle(mode: SwizzleMode) -> TensorMapSwizzle {
            match mode {
                SwizzleMode::None => TensorMapSwizzle::None,
                SwizzleMode::B32 => TensorMapSwizzle::B32,
                SwizzleMode::B64 => TensorMapSwizzle::B64,
                SwizzleMode::B128 => TensorMapSwizzle::B128,
            }
        }

        let swizzle_lhs = swizzle(config.lhs_reader_config().smem_config.swizzle);
        let swizzle_rhs = swizzle(config.rhs_reader_config().smem_config.swizzle);

        // f32 stage types are declared as tf32 in the tensor map.
        // NOTE(review): presumably required by the tensor-map hardware for
        // f32 matmul stages — confirm against the backend.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.rhs_stage
        };

        // Tensor-map metadata for lhs: tiled format, dense elements, zero
        // fill for out-of-bounds reads.
        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_lhs,
            },
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_lhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        // Same metadata for rhs.
        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_rhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
        };

        // Wraps a tensor-map buffer in a SimpleTmaGlobalLayout view. The
        // shape passed to the layout is the logical (batches, rows, cols):
        // for transposed operands the dims swapped above are swapped back,
        // and `transposed` tells the layout the storage order.
        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0] as u32);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            // No accumulator input and no accumulator batch layout.
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}
543
#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    // State is simply the (inputs, output) pair; tensor maps were built host-side.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    // Copies the launched arguments into the state; no device-side setup needed.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    // No batch remapping for tensor-map inputs: identity, unlike TensorArgs.
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    // Identity batch mapping, as for lhs.
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    // Accumulator batch remapping is optional; identity when absent.
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    // Remaps a batch index through the output batch layout.
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}