1use std::any::TypeId;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded};
5use cubecl_std::{
6 CubeOption, CubeOptionArgs,
7 tensor::{View, launch::ViewArg, layout::Coords3d},
8};
9
10use crate::{
11 MatmulInputHandleRef,
12 components::{
13 self, MatmulIdent, MatmulLineSizes, MatmulProblem, MatmulSelection,
14 batch::BatchConfig,
15 global::{
16 GlobalConfig,
17 memory::{
18 BatchedGlobalLayout, BatchedGlobalLayoutLaunch, BatchedGlobalScaleLayout,
19 SimpleTmaGlobalLayout, SimpleTmaGlobalLayoutLaunch,
20 },
21 },
22 },
23};
24
/// Factory building the concrete runtime launch argument for a matmul input
/// type from raw tensor handles plus the resolved problem/selection/config.
pub trait ConcreteInputsFactory: LaunchArg {
    /// Build the runtime argument for the lhs/rhs input handles.
    ///
    /// `line_sizes` gives the vectorization width per operand, and `config`
    /// supplies the global-memory configuration used to build the layouts.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R>;
}
38
/// Factory building the concrete runtime launch argument for the matmul
/// output type from a raw tensor handle plus the resolved problem/config.
pub trait ConcreteOutputFactory: LaunchArg {
    /// Build the runtime argument for the output handle.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R>;
}
51
/// Type family describing how a matmul kernel receives its inputs and output,
/// plus the device-side state built from them at kernel start.
#[cube]
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Launch argument type carrying the lhs/rhs inputs (and optional accumulator).
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;
    /// Launch argument type carrying the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;
    /// Device-side state derived from the input/output arguments.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;

    /// Build the state from the launch arguments; called once at kernel start.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] config: G,
    ) -> Self::State<Lhs, Rhs, EO>;

    // The view accessors below are expand-only: calling them outside the cube
    // macro expansion hits `unexpanded!()`.

    /// Read-only 3D view over the lhs operand.
    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        unexpanded!()
    }
    /// Read-only 3D view over the rhs operand.
    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        unexpanded!()
    }
    /// Optional read-only 3D view over the accumulator input.
    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        unexpanded!()
    }
    /// Writable 3D view over the output.
    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        unexpanded!()
    }
}
91
/// Identifies which matmul input operand a handle/view refers to.
#[derive(Clone, Copy)]
pub enum TensorInputIdent {
    /// Left-hand-side operand.
    Lhs,
    /// Right-hand-side operand.
    Rhs,
}
98
/// Marker type: [`MatmulArgs`] implementation backed by plain tensor views
/// (see the `impl MatmulArgs for TensorArgs` below).
#[derive(Clone)]
pub struct TensorArgs;
104
/// Launch arguments for plain tensor-view matmul inputs.
#[derive(CubeLaunch, CubeType)]
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    /// 3D view over the lhs operand.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// 3D view over the rhs operand.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// Optional 3D view over an accumulator input.
    pub acc: CubeOption<View<Line<Acc>, Coords3d>>,
}
115
/// Writable 3D view over the matmul output tensor.
pub type TensorOutput<EO> = View<Line<EO>, Coords3d, ReadWrite>;
117
impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
    for TensorInputs<Lhs, Rhs, Acc>
{
    /// Build plain (non-TMA) input views. Quantized handles produce a data
    /// view paired with a scales view; the accumulator is always launched as
    /// `None` here.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
        // Build a batched 3D view argument over one input handle.
        let view = |handle: &'a MatmulInputHandleRef<'a, R>, ident, line_size| match handle {
            MatmulInputHandleRef::Normal(handle) => {
                let layout = BatchedGlobalLayoutLaunch::from_handle(
                    client,
                    handle,
                    problem,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                ViewArg::new::<BatchedGlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => {
                // Quantized inputs get separate layouts for the data tensor
                // and its scales; the scales view is unvectorized (line size 1).
                let (data_layout, scales_layout) = BatchedGlobalLayoutLaunch::from_quantized_handle(
                    client,
                    data,
                    scale,
                    shape,
                    problem,
                    **scheme,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                let data_view =
                    ViewArg::new::<BatchedGlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<BatchedGlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };

        TensorInputsLaunch::new(
            view(lhs, MatmulIdent::Lhs, line_sizes.lhs),
            view(rhs, MatmulIdent::Rhs, line_sizes.rhs),
            CubeOptionArgs::None,
        )
    }
}
173
174impl<EG: Numeric> ConcreteOutputFactory for View<Line<EG>, Coords3d, ReadWrite> {
175 fn create<'a, R: Runtime>(
176 client: &ComputeClient<R::Server>,
177 out: &'a TensorHandleRef<'a, R>,
178 _selection: &MatmulSelection,
179 problem: &MatmulProblem,
180 line_sizes: &MatmulLineSizes,
181 config: impl BatchConfig,
182 ) -> Self::RuntimeArg<'a, R> {
183 let config = config.global_config();
184 let layout = BatchedGlobalLayoutLaunch::from_handle(
185 client,
186 out,
187 problem,
188 line_sizes.out,
189 config.global_memory_config(MatmulIdent::Out).into(),
190 );
191 ViewArg::new::<BatchedGlobalLayout>(out.as_array_arg(line_sizes.out), layout)
192 }
193}
194
#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    // State is the tuple (lhs view, rhs view, optional acc view, out view).
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = (
        View<Line<Lhs>, Coords3d>,
        View<Line<Rhs>, Coords3d>,
        CubeOption<View<Line<EO>, Coords3d>>,
        View<Line<EO>, Coords3d, ReadWrite>,
    );

    // Plain views are copied directly from the launch arguments into the state.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (input.lhs, input.rhs, input.acc, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.1
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.2
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.3
    }
}
238
/// Marker type: [`MatmulArgs`] implementation whose inputs are loaded through
/// hardware tensor maps (see `TensorMapMeta` usage below).
#[derive(Clone)]
pub struct TensorMapArgs;
244
/// Launch arguments for tensor-map backed matmul inputs.
#[derive(CubeLaunch, CubeType)]
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    /// 3D view over the lhs operand.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// 3D view over the rhs operand.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// Optional 3D view over an accumulator input.
    pub acc: CubeOption<View<Line<EO>, Coords3d>>,
}
255
256impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
257 for TensorMapInputs<Lhs, Rhs, EO>
258{
259 fn create<'a, R: Runtime>(
260 _client: &ComputeClient<R::Server>,
261 lhs_handle: &'a MatmulInputHandleRef<'a, R>,
262 rhs_handle: &'a MatmulInputHandleRef<'a, R>,
263 selection: &MatmulSelection,
264 problem: &MatmulProblem,
265 line_sizes: &MatmulLineSizes,
266 _config: impl BatchConfig,
267 ) -> Self::RuntimeArg<'a, R> {
268 let lhs = lhs_handle.data();
269 let rhs = rhs_handle.data();
270
271 let tiling_scheme = selection.tiling_scheme;
272 let stage_m = tiling_scheme.elements_in_stage_m();
273 let stage_n = tiling_scheme.elements_in_stage_n();
274 let stage_k = tiling_scheme.elements_in_stage_k();
275 let stage_size_lhs = match problem.lhs_layout {
276 components::MatrixLayout::RowMajor => {
277 vec![1, stage_m, tiling_scheme.elements_in_tile_k()]
278 }
279 components::MatrixLayout::ColMajor => {
280 vec![1, stage_k, tiling_scheme.elements_in_tile_m()]
281 }
282 };
283 let stage_size_rhs = match problem.rhs_layout {
284 components::MatrixLayout::RowMajor => {
285 vec![1, stage_k, tiling_scheme.elements_in_tile_n()]
286 }
287 components::MatrixLayout::ColMajor => {
288 vec![1, stage_n, tiling_scheme.elements_in_tile_k()]
289 }
290 };
291
292 let lhs_elem_size = size_of::<Lhs>();
293 let rhs_elem_size = size_of::<Rhs>();
294
295 let lhs_rank = lhs.shape.len();
296 let mut lhs_shape = vec![
297 problem.lhs_batches[0],
298 lhs.shape[lhs_rank - 2],
299 lhs.shape[lhs_rank - 1],
300 ];
301 let mut lhs_strides = if lhs_rank > 2 {
302 lhs.strides[lhs_rank - 3..].to_vec()
303 } else {
304 vec![1, lhs.strides[lhs_rank - 2], lhs.strides[lhs_rank - 1]]
305 };
306
307 let rhs_rank = rhs.shape.len();
308 let mut rhs_shape = vec![
309 problem.rhs_batches[0],
310 rhs.shape[rhs_rank - 2],
311 rhs.shape[rhs_rank - 1],
312 ];
313 let mut rhs_strides = if rhs_rank > 2 {
314 rhs.strides[rhs_rank - 3..].to_vec()
315 } else {
316 vec![1, rhs.strides[rhs_rank - 2], rhs.strides[rhs_rank - 1]]
317 };
318
319 let mut lhs_transposed = false;
320 let mut rhs_transposed = false;
321
322 if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
325 lhs_shape.swap(lhs_rank - 1, lhs_rank - 2);
326 lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
327 lhs_transposed = true;
328 }
329 if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
330 rhs_shape.swap(rhs_rank - 1, rhs_rank - 2);
331 rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
332 rhs_transposed = true;
333 }
334
335 fn prefetch(bytes: usize) -> TensorMapPrefetch {
336 match bytes {
337 ..64 => TensorMapPrefetch::None,
338 64..128 => TensorMapPrefetch::B64,
339 128..256 => TensorMapPrefetch::B128,
340 256.. => TensorMapPrefetch::B256,
341 }
342 }
343
344 let prefetch_lhs = prefetch(stage_size_lhs[2] as usize * lhs_elem_size);
345 let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);
346
347 let lhs_elem = if TypeId::of::<Lhs>() == TypeId::of::<f32>() {
350 tf32::as_type_native_unchecked()
351 } else {
352 Lhs::as_type_native_unchecked()
353 };
354 let rhs_elem = if TypeId::of::<Rhs>() == TypeId::of::<f32>() {
355 tf32::as_type_native_unchecked()
356 } else {
357 Rhs::as_type_native_unchecked()
358 };
359
360 let meta_lhs = TensorMapMeta {
361 format: TensorMapFormat::Tiled {
362 tile_size: stage_size_lhs,
363 },
364 rank: 3,
365 shape: lhs_shape.clone(),
366 strides: lhs_strides,
367 elem_stride: vec![1, 1, 1],
368 interleave: TensorMapInterleave::None,
369 swizzle: TensorMapSwizzle::None,
370 prefetch: prefetch_lhs,
371 oob_fill: OobFill::Zero,
372 storage_ty: lhs_elem,
373 };
374
375 let meta_rhs = TensorMapMeta {
376 format: TensorMapFormat::Tiled {
377 tile_size: stage_size_rhs,
378 },
379 rank: 3,
380 shape: rhs_shape.clone(),
381 strides: rhs_strides,
382 elem_stride: vec![1, 1, 1],
383 interleave: TensorMapInterleave::None,
384 swizzle: TensorMapSwizzle::None,
385 prefetch: prefetch_rhs,
386 oob_fill: OobFill::Zero,
387 storage_ty: rhs_elem,
388 };
389
390 let lhs = TensorMapArg {
391 tensor: lhs.as_tensor_arg(line_sizes.lhs),
392 metadata: meta_lhs,
393 };
394 let rhs = TensorMapArg {
395 tensor: rhs.as_tensor_arg(line_sizes.rhs),
396 metadata: meta_rhs,
397 };
398
399 let view = |buffer, shape: &[usize], transposed| {
400 let batches = ScalarArg::new(shape[0] as u32);
401 let (rows, cols) = match transposed {
402 true => (
403 ScalarArg::new(shape[2] as u32),
404 ScalarArg::new(shape[1] as u32),
405 ),
406 false => (
407 ScalarArg::new(shape[1] as u32),
408 ScalarArg::new(shape[2] as u32),
409 ),
410 };
411 let shape = (batches, rows, cols);
412 let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
413 ViewArg::new_tensor_map::<SimpleTmaGlobalLayout>(buffer, layout)
414 };
415
416 TensorMapInputsLaunch::new(
417 view(lhs, &lhs_shape, lhs_transposed),
418 view(rhs, &rhs_shape, rhs_transposed),
419 CubeOptionArgs::None,
420 )
421 }
422}
423
#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    // State is the tuple (lhs view, rhs view, optional acc view, out view) —
    // identical to the TensorArgs state shape.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = (
        View<Line<Lhs>, Coords3d>,
        View<Line<Rhs>, Coords3d>,
        CubeOption<View<Line<EO>, Coords3d>>,
        View<Line<EO>, Coords3d, ReadWrite>,
    );

    // Views are copied directly from the launch arguments into the state.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (input.lhs, input.rhs, input.acc, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.1
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.2
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.3
    }
}