1use std::any::TypeId;
2
3use cubecl::prelude::*;
4use cubecl_core::{self as cubecl, server::TensorMapMeta};
5use cubecl_std::{
6 ReinterpretSlice,
7 tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand},
8};
9
10use super::Quantization;
11use crate::matmul::components::{self, MatmulPrecision, MatmulProblem, MatmulSelection};
12
/// Factory that builds the launch-time runtime argument for a matmul's concrete
/// (non-fused) inputs from raw tensor handles.
pub trait ConcreteInputsFactory: LaunchArg {
    /// Build the runtime argument for `lhs`/`rhs`, using the kernel `selection`
    /// and the `problem` description (e.g. line sizes, layouts).
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R>;
}
23
/// Factory that builds the launch-time runtime argument for a matmul's concrete
/// (non-fused) output from a raw tensor handle.
pub trait ConcreteOutputFactory: LaunchArg {
    /// Build the runtime argument for `out`, using the kernel `selection`
    /// and the `problem` description (e.g. output line size).
    fn create<'a, R: Runtime>(
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R>;
}
33
/// Abstraction over how a matmul kernel accesses its Lhs/Rhs inputs and its output.
///
/// Implementations can back these accessors with plain tensors ([`TensorArgs`]) or
/// TMA tensor maps ([`TensorMapArgs`]); unsupported combinations are expected to
/// `unimplemented!` at expansion time.
#[cube]
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Launchable type holding both Lhs and Rhs inputs.
    type Input<EI: Numeric>: LaunchArg + CubeType;
    /// Launchable type holding the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;
    /// Runtime state shared by all reads/writes during kernel execution.
    type State<EI: Numeric, EO: Numeric>: CubeType;

    /// Initialize the state from the launched input/output arguments.
    fn init_state<EI: Numeric, EO: Numeric>(
        input: &Self::Input<EI>,
        output: &mut Self::Output<EO>,
    ) -> Self::State<EI, EO>;

    /// Read one line of the Lhs tensor at `coordinate`.
    fn read_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, coordinate: u32)
    -> Line<EI>;
    /// Read one line of the Rhs tensor at `coordinate`.
    fn read_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, coordinate: u32)
    -> Line<EI>;

    /// Read a contiguous window `[start, end)` of the Lhs tensor.
    fn read_window_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>>;

    /// Read a contiguous window `[start, end)` of the Rhs tensor.
    fn read_window_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>>;

    /// Reinterpret Lhs as a `TensorMap` (only valid for map-backed implementations).
    fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI>;

    /// Reinterpret Rhs as a `TensorMap` (only valid for map-backed implementations).
    fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI>;

    /// Write one line `value` into the output at `coordinate`.
    fn write_out<EI: Numeric, EO: Numeric>(
        state: &mut Self::State<EI, EO>,
        coordinate: u32,
        value: Line<EO>,
    );

    /// Rank (number of axes) of the Lhs tensor.
    fn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Rank (number of axes) of the Rhs tensor.
    fn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Rank (number of axes) of the output tensor.
    fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;

    /// Logical length of the Lhs tensor.
    fn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Logical length of the Rhs tensor.
    fn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Logical length of the output tensor.
    fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;

    /// Underlying buffer length of the Lhs tensor (may exceed the logical length).
    fn buffer_len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Underlying buffer length of the Rhs tensor.
    fn buffer_len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Underlying buffer length of the output tensor.
    fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;

    /// Shape of the Lhs tensor along `axis`.
    fn shape_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Shape of the Rhs tensor along `axis`.
    fn shape_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Shape of the output tensor along `axis`.
    fn shape_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;

    /// Stride of the Lhs tensor along `axis`.
    fn stride_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Stride of the Rhs tensor along `axis`.
    fn stride_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Stride of the output tensor along `axis`.
    fn stride_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;

    /// Retrieve the quantization scaling factors for Lhs/Rhs.
    fn quantization<MP: MatmulPrecision>(state: &Self::State<MP::EI, MP::EO>) -> Quantization<MP>;
}
125
/// Identifies which of the two matmul inputs a [`TensorInput`] refers to.
#[derive(Clone, Copy)]
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}
132
/// Read-only view over one matmul input (Lhs or Rhs), usable like a tensor pointer.
pub struct TensorInput<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    // Raw pointer into the shared args state; NOTE(review): assumed to stay valid for
    // the whole kernel expansion — confirm with how `MatmulArgs::init_state` is used.
    state: *const GA::State<EI, EO>,
    // Selects Lhs vs Rhs accessors at comptime.
    ident: TensorInputIdent,
}
140
// Marker impl: lets a `TensorInput` be used wherever a virtual tensor of `EI` is expected.
impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperations<EI>
    for TensorInput<EI, EO, MA>
{
}
145
// Marker impl: lets a `TensorOutput` be used wherever a virtual tensor of `EO` is expected.
impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperations<EO>
    for TensorOutput<EI, EO, MA>
{
}
150
151impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperationsExpand<EO>
152 for TensorOutputExpand<EI, EO, MA>
153{
154 fn __expand_read_method(
155 &self,
156 _scope: &mut Scope,
157 _index: ExpandElementTyped<u32>,
158 ) -> ExpandElementTyped<Line<EO>> {
159 panic!("Can't read output tensor");
160 }
161
162 fn __expand_read_window_method(
163 &self,
164 _context: &mut Scope,
165 _start: ExpandElementTyped<u32>,
166 _end: ExpandElementTyped<u32>,
167 ) -> ExpandElementTyped<Slice<Line<EO>>> {
168 panic!("Can't read output tensor");
169 }
170
171 fn __expand_write_method(
172 &self,
173 scope: &mut Scope,
174 index: ExpandElementTyped<u32>,
175 value: ExpandElementTyped<Line<EO>>,
176 ) {
177 TensorOutputExpand::__expand_write_method(self.clone(), scope, index, value)
178 }
179
180 fn __expand_shape_method(
181 &self,
182 scope: &mut Scope,
183 axis: ExpandElementTyped<u32>,
184 ) -> ExpandElementTyped<u32> {
185 TensorOutputExpand::__expand_shape_method(self.clone(), scope, axis)
186 }
187
188 fn __expand_stride_method(
189 &self,
190 scope: &mut Scope,
191 axis: ExpandElementTyped<u32>,
192 ) -> ExpandElementTyped<u32> {
193 TensorOutputExpand::__expand_stride_method(self.clone(), scope, axis)
194 }
195
196 fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
197 TensorOutputExpand::__expand_rank_method(self.clone(), scope)
198 }
199
200 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
201 TensorOutputExpand::__expand_len_method(self.clone(), scope)
202 }
203
204 fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
205 TensorOutputExpand::__expand_buffer_len_method(self.clone(), scope)
206 }
207
208 fn __expand_as_tensor_map_method(
209 &self,
210 _scope: &mut Scope,
211 ) -> ExpandElementTyped<TensorMap<EO>> {
212 unimplemented!("TensorOutputExpand can't be turned into a tensor map");
213 }
214}
215
216impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperationsExpand<EI>
217 for TensorInputExpand<EI, EO, MA>
218{
219 fn __expand_read_method(
220 &self,
221 scope: &mut Scope,
222 index: ExpandElementTyped<u32>,
223 ) -> ExpandElementTyped<Line<EI>> {
224 TensorInputExpand::__expand_read_method(self.clone(), scope, index)
225 }
226 fn __expand_read_window_method(
227 &self,
228 context: &mut Scope,
229 start: ExpandElementTyped<u32>,
230 end: ExpandElementTyped<u32>,
231 ) -> ExpandElementTyped<Slice<Line<EI>>> {
232 TensorInputExpand::__expand_read_window_method(self.clone(), context, start, end)
233 }
234
235 fn __expand_write_method(
236 &self,
237 _scope: &mut Scope,
238 _index: ExpandElementTyped<u32>,
239 _value: ExpandElementTyped<Line<EI>>,
240 ) {
241 panic!("Can't write to input tensor");
242 }
243
244 fn __expand_shape_method(
245 &self,
246 scope: &mut Scope,
247 axis: ExpandElementTyped<u32>,
248 ) -> ExpandElementTyped<u32> {
249 TensorInputExpand::__expand_shape_method(self.clone(), scope, axis)
250 }
251
252 fn __expand_stride_method(
253 &self,
254 scope: &mut Scope,
255 axis: ExpandElementTyped<u32>,
256 ) -> ExpandElementTyped<u32> {
257 TensorInputExpand::__expand_stride_method(self.clone(), scope, axis)
258 }
259
260 fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
261 TensorInputExpand::__expand_rank_method(self.clone(), scope)
262 }
263
264 fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
265 TensorInputExpand::__expand_len_method(self.clone(), scope)
266 }
267
268 fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
269 TensorInputExpand::__expand_buffer_len_method(self.clone(), scope)
270 }
271
272 fn __expand_as_tensor_map_method(
273 &self,
274 scope: &mut Scope,
275 ) -> ExpandElementTyped<TensorMap<EI>> {
276 TensorInputExpand::__expand_as_tensor_map_method(self.clone(), scope)
277 }
278}
279
/// Write-only view over the matmul output, usable like a tensor pointer.
pub struct TensorOutput<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    // Raw mutable pointer into the shared args state; NOTE(review): assumed valid for
    // the whole kernel expansion — confirm with how `MatmulArgs::init_state` is used.
    state: *mut GA::State<EI, EO>,
}
290
/// Expand (compile-time) counterpart of [`TensorInput`].
pub struct TensorInputExpand<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    // Expanded state handle used during kernel expansion.
    state: <GA::State<EI, EO> as CubeType>::ExpandType,
    // Selects Lhs vs Rhs accessors.
    ident: TensorInputIdent,
}
296
/// Expand (compile-time) counterpart of [`TensorOutput`].
pub struct TensorOutputExpand<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    // Expanded state handle used during kernel expansion.
    state: <GA::State<EI, EO> as CubeType>::ExpandType,
}
301
#[cube]
impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> TensorInput<EI, EO, MA> {
    /// Create a [`TensorInput`] over `state`, reading Lhs or Rhs depending on `ident`.
    pub fn new(
        state: &MA::State<EI, EO>,
        #[comptime] ident: TensorInputIdent,
    ) -> TensorInput<EI, EO, MA> {
        TensorInput::<EI, EO, MA> { state, ident }
    }

    /// Read a contiguous window `[start, end)` from the selected input.
    // The `ident` match is resolved at comptime, so only one branch is generated.
    pub fn read_window(&self, start: u32, end: u32) -> Slice<Line<EI>> {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::read_window_lhs(&(*self.state), start, end),
                TensorInputIdent::Rhs => MA::read_window_rhs(&(*self.state), start, end),
            }
        }
    }

    /// Read one line from the selected input at `coordinate`.
    pub fn read(&self, coordinate: u32) -> Line<EI> {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::read_lhs(&(*self.state), coordinate),
                TensorInputIdent::Rhs => MA::read_rhs(&(*self.state), coordinate),
            }
        }
    }

    /// Shape of the selected input along `axis`.
    pub fn shape(&self, axis: u32) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::shape_lhs(&(*self.state), axis),
                TensorInputIdent::Rhs => MA::shape_rhs(&(*self.state), axis),
            }
        }
    }

    /// Stride of the selected input along `axis`.
    pub fn stride(&self, axis: u32) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::stride_lhs(&(*self.state), axis),
                TensorInputIdent::Rhs => MA::stride_rhs(&(*self.state), axis),
            }
        }
    }

    /// Rank (number of axes) of the selected input.
    pub fn rank(&self) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::rank_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::rank_rhs(&(*self.state)),
            }
        }
    }

    /// Logical length of the selected input.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::len_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::len_rhs(&(*self.state)),
            }
        }
    }

    /// Underlying buffer length of the selected input.
    pub fn buffer_len(&self) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::buffer_len_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::buffer_len_rhs(&(*self.state)),
            }
        }
    }

    /// Reinterpret the selected input as a `TensorMap` (map-backed args only).
    pub fn as_tensor_map(&self) -> TensorMap<EI> {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::as_tensor_map_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::as_tensor_map_rhs(&(*self.state)),
            }
        }
    }
}
393
394#[cube]
395impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> TensorOutput<EI, EO, GA> {
396 pub fn new(state: &mut GA::State<EI, EO>) -> TensorOutput<EI, EO, GA> {
398 TensorOutput::<EI, EO, GA> { state }
399 }
400
401 pub fn write(&self, coordinate: u32, value: Line<EO>) {
403 unsafe { GA::write_out(&mut (*self.state), coordinate, value) }
404 }
405
406 pub fn shape(&self, axis: u32) -> u32 {
408 unsafe { GA::shape_out(&(*self.state), axis) }
409 }
410
411 pub fn stride(&self, dim: u32) -> u32 {
413 unsafe { GA::stride_out(&(*self.state), dim) }
414 }
415
416 pub fn rank(&self) -> u32 {
418 unsafe { GA::rank_out(&(*self.state)) }
419 }
420
421 #[allow(clippy::len_without_is_empty)]
423 pub fn len(&self) -> u32 {
424 unsafe { GA::len_out(&(*self.state)) }
425 }
426
427 pub fn buffer_len(&self) -> u32 {
429 unsafe { GA::len_out(&(*self.state)) }
430 }
431}
432
/// [`MatmulArgs`] implementation backed by plain global-memory tensors.
#[derive(Clone)]
pub struct TensorArgs;
438
/// Launchable pair of plain tensors used as matmul inputs.
#[derive(CubeLaunch, CubeType)]
pub struct TensorInputs<EG: Numeric> {
    /// Left-hand-side input tensor.
    pub lhs: Tensor<Line<EG>>,
    /// Right-hand-side input tensor.
    pub rhs: Tensor<Line<EG>>,
}
447
impl<EG: Numeric> ConcreteInputsFactory for TensorInputs<EG> {
    /// Build the launch argument from raw handles, vectorizing each input with the
    /// line size chosen for the problem.
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R> {
        TensorInputsLaunch::new(
            lhs.as_tensor_arg(problem.lhs_line_size),
            rhs.as_tensor_arg(problem.rhs_line_size),
        )
    }
}
461
impl<EG: Numeric> ConcreteOutputFactory for Tensor<Line<EG>> {
    /// Build the output launch argument, vectorized with the problem's output line size.
    fn create<'a, R: Runtime>(
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R> {
        out.as_tensor_arg(problem.out_line_size)
    }
}
471
#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = Tensor<Line<EO>>;
    type Input<EI: Numeric> = TensorInputs<EI>;
    // State is raw pointers to (lhs, rhs, out); pointers are taken in `init_state`
    // and dereferenced by every accessor below.
    type State<EI: Numeric, EO: Numeric> = (
        *const Tensor<Line<EI>>,
        *const Tensor<Line<EI>>,
        *mut Tensor<Line<EO>>,
    );

    fn init_state<EI: Numeric, EO: Numeric>(
        input: &Self::Input<EI>,
        output: &mut Self::Output<EO>,
    ) -> Self::State<EI, EO> {
        (&input.lhs, &input.rhs, output)
    }

    fn read_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        coordinate: u32,
    ) -> Line<EI> {
        unsafe { (*state.0)[coordinate] }
    }

    fn read_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        coordinate: u32,
    ) -> Line<EI> {
        unsafe { (*state.1)[coordinate] }
    }

    fn read_window_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unsafe { (*state.0).slice(start, end) }
    }

    /// Read the line of the rhs tensor using the state at the given coordinate.
    fn read_window_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unsafe { (*state.1).slice(start, end) }
    }

    // Plain tensors cannot be used through the TMA path.
    fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> TensorMap<EI> {
        comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
        #[allow(unreachable_code)]
        TensorMap::dummy()
    }

    fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> TensorMap<EI> {
        comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
        #[allow(unreachable_code)]
        TensorMap::dummy()
    }

    fn shape_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).shape(dim) }
    }

    fn shape_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).shape(dim) }
    }

    fn shape_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.2).shape(dim) }
    }

    fn stride_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).stride(dim) }
    }

    fn stride_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).stride(dim) }
    }

    fn stride_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.2).stride(dim) }
    }

    fn write_out<EI: Numeric, EO: Numeric>(
        state: &mut Self::State<EI, EO>,
        coordinate: u32,
        value: Line<EO>,
    ) {
        unsafe { (*state.2)[coordinate] = value }
    }

    fn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.0).rank() }
    }

    fn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.1).rank() }
    }

    fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).rank() }
    }

    fn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.0).len() }
    }

    fn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.1).len() }
    }

    fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).len() }
    }

    fn buffer_len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.0).buffer_len() }
    }

    fn buffer_len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.1).buffer_len() }
    }

    fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).buffer_len() }
    }

    // Reads the per-tensor scaling factor stored as an f32 at the end of each input
    // buffer. NOTE(review): `buffer_len * line_size / 4` only counts f32 elements
    // correctly if `EI` is a 1-byte type (e.g. i8 quantized values) — confirm with the
    // quantized-tensor memory layout used by callers.
    fn quantization<MP: MatmulPrecision>(state: &Self::State<MP::EI, MP::EO>) -> Quantization<MP> {
        let (lhs, rhs, _) = *state;
        unsafe {
            let buffer_len_lhs = Self::buffer_len_lhs(state);
            let line_size_lhs = (*lhs).line_size();
            // Number of f32 values when the EI buffer is reinterpreted as f32.
            let reinterpreted_len_lhs = buffer_len_lhs * line_size_lhs / 4;
            let scaling_lhs =
                ReinterpretSlice::<MP::EI, f32>::new((*lhs).to_slice(), line_size_lhs)
                    .read(reinterpreted_len_lhs - 1);

            let buffer_len_rhs = Self::buffer_len_rhs(state);
            let line_size_rhs = (*rhs).line_size();
            // Number of f32 values when the EI buffer is reinterpreted as f32.
            let reinterpreted_len_rhs = buffer_len_rhs * line_size_rhs / 4;
            let scaling_rhs =
                ReinterpretSlice::<MP::EI, f32>::new((*rhs).to_slice(), line_size_rhs)
                    .read(reinterpreted_len_rhs - 1);

            Quantization::<MP> {
                scaling_lhs: MP::ES::cast_from(scaling_lhs),
                scaling_rhs: MP::ES::cast_from(scaling_rhs),
            }
        }
    }
}
629
/// [`MatmulArgs`] implementation backed by TMA tensor maps for the inputs.
#[derive(Clone)]
pub struct TensorMapArgs;
635
/// Launchable pair of tensor maps used as matmul inputs.
#[derive(CubeLaunch, CubeType)]
pub struct TensorMapInputs<EG: Numeric> {
    /// Left-hand-side input tensor map.
    pub lhs: TensorMap<EG>,
    /// Right-hand-side input tensor map.
    pub rhs: TensorMap<EG>,
}
644
645impl<EG: Numeric> ConcreteInputsFactory for TensorMapInputs<EG> {
646 fn create<'a, R: Runtime>(
647 lhs: &'a TensorHandleRef<'a, R>,
648 rhs: &'a TensorHandleRef<'a, R>,
649 selection: &MatmulSelection,
650 problem: &MatmulProblem,
651 ) -> Self::RuntimeArg<'a, R> {
652 let stage_m = selection.tile_count.m * selection.tile_shape.m;
653 let stage_n = selection.tile_count.n * selection.tile_shape.n;
654 let stage_k = selection.tile_count.k * selection.tile_shape.k;
655 let stage_size_lhs = match problem.lhs_layout {
656 components::MatrixLayout::RowMajor => vec![1, stage_m, selection.tile_shape.k],
657 components::MatrixLayout::ColMajor => vec![1, stage_k, selection.tile_shape.m],
658 };
659 let stage_size_rhs = match problem.rhs_layout {
660 components::MatrixLayout::RowMajor => vec![1, stage_k, selection.tile_shape.n],
661 components::MatrixLayout::ColMajor => vec![1, stage_n, selection.tile_shape.k],
662 };
663
664 let elem_size = size_of::<EG>();
665
666 let lhs_rank = lhs.shape.len();
667 let mut lhs_shape = vec![
668 problem.batches.0[0],
669 lhs.shape[lhs_rank - 2],
670 lhs.shape[lhs_rank - 1],
671 ];
672 let mut lhs_strides = if lhs_rank > 2 {
673 lhs.strides[lhs_rank - 3..].to_vec()
674 } else {
675 vec![1, lhs.strides[lhs_rank - 2], lhs.strides[lhs_rank - 1]]
676 };
677
678 let rhs_rank = rhs.shape.len();
679 let mut rhs_shape = vec![
680 problem.batches.1[0],
681 rhs.shape[rhs_rank - 2],
682 rhs.shape[rhs_rank - 1],
683 ];
684 let mut rhs_strides = if rhs_rank > 2 {
685 rhs.strides[rhs_rank - 3..].to_vec()
686 } else {
687 vec![1, rhs.strides[rhs_rank - 2], rhs.strides[rhs_rank - 1]]
688 };
689
690 if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
693 lhs_shape.swap(lhs_rank - 1, lhs_rank - 2);
694 lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
695 }
696 if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
697 rhs_shape.swap(rhs_rank - 1, rhs_rank - 2);
698 rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
699 }
700
701 fn prefetch(bytes: usize) -> TensorMapPrefetch {
702 match bytes {
703 ..64 => TensorMapPrefetch::None,
704 64..128 => TensorMapPrefetch::B64,
705 128..256 => TensorMapPrefetch::B128,
706 256.. => TensorMapPrefetch::B256,
707 }
708 }
709
710 let prefetch_lhs = prefetch(stage_size_lhs[2] as usize * elem_size);
711 let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * elem_size);
712
713 let elem = if TypeId::of::<EG>() == TypeId::of::<f32>() {
716 tf32::as_elem_native_unchecked()
717 } else {
718 EG::as_elem_native_unchecked()
719 };
720
721 let meta_lhs = TensorMapMeta {
722 format: TensorMapFormat::Tiled {
723 tile_size: stage_size_lhs,
724 },
725 rank: 3,
726 shape: lhs_shape,
727 strides: lhs_strides,
728 elem_stride: vec![1, 1, 1],
729 interleave: TensorMapInterleave::None,
730 swizzle: TensorMapSwizzle::None,
731 prefetch: prefetch_lhs,
732 oob_fill: OobFill::Zero,
733 elem,
734 };
735
736 let meta_rhs = TensorMapMeta {
737 format: TensorMapFormat::Tiled {
738 tile_size: stage_size_rhs,
739 },
740 rank: 3,
741 shape: rhs_shape,
742 strides: rhs_strides,
743 elem_stride: vec![1, 1, 1],
744 interleave: TensorMapInterleave::None,
745 swizzle: TensorMapSwizzle::None,
746 prefetch: prefetch_rhs,
747 oob_fill: OobFill::Zero,
748 elem,
749 };
750
751 let lhs = TensorMapArg {
752 tensor: lhs.as_tensor_arg(problem.lhs_line_size),
753 metadata: meta_lhs,
754 };
755 let rhs = TensorMapArg {
756 tensor: rhs.as_tensor_arg(problem.rhs_line_size),
757 metadata: meta_rhs,
758 };
759
760 TensorMapInputsLaunch::new(lhs, rhs)
761 }
762}
763
#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<EI: Numeric> = TensorMapInputs<EI>;
    type Output<EO: Numeric> = Tensor<Line<EO>>;
    // State is raw pointers to (lhs map, rhs map, out tensor). Direct element access
    // and metadata queries on the maps are unsupported and `unimplemented!` below.
    type State<EI: Numeric, EO: Numeric> = (
        *const TensorMap<EI>,
        *const TensorMap<EI>,
        *mut Tensor<Line<EO>>,
    );

    fn init_state<EI: Numeric, EO: Numeric>(
        input: &Self::Input<EI>,
        output: &mut Self::Output<EO>,
    ) -> Self::State<EI, EO> {
        (&input.lhs, &input.rhs, output)
    }

    fn read_lhs<EI: Numeric, EO: Numeric>(
        _state: &Self::State<EI, EO>,
        _coordinate: u32,
    ) -> Line<EI> {
        unimplemented!("Can't directly read from TensorMap")
    }

    fn read_rhs<EI: Numeric, EO: Numeric>(
        _state: &Self::State<EI, EO>,
        _coordinate: u32,
    ) -> Line<EI> {
        unimplemented!("Can't directly read from TensorMap")
    }

    #[allow(unused)]
    fn read_window_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unimplemented!("Can't directly read from TensorMap")
    }

    /// Read the line of the rhs tensor using the state at the given coordinate.
    #[allow(unused)]
    fn read_window_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unimplemented!("Can't directly read from TensorMap")
    }

    fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI> {
        unsafe { *state.0 }
    }

    fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI> {
        unsafe { *state.1 }
    }

    fn shape_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).shape(dim) }
    }

    fn shape_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).shape(dim) }
    }

    fn shape_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { &*state.2 }.shape(dim)
    }

    fn stride_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { &*state.0 }.stride(dim)
    }

    fn stride_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { &*state.1 }.stride(dim)
    }

    fn stride_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { &*state.2 }.stride(dim)
    }

    fn write_out<EI: Numeric, EO: Numeric>(
        state: &mut Self::State<EI, EO>,
        coordinate: u32,
        value: Line<EO>,
    ) {
        unsafe { (*state.2)[coordinate] = value }
    }

    fn rank_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn rank_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).rank() }
    }

    fn len_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn len_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).len() }
    }

    fn buffer_len_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn buffer_len_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).buffer_len() }
    }

    fn quantization<MP: MatmulPrecision>(_state: &Self::State<MP::EI, MP::EO>) -> Quantization<MP> {
        todo!("Quantized TMA not yet supported")
    }
}
894
// Manual CubeType/Clone/Init/CubeDebug plumbing for `TensorInput`, kept in a private
// module to avoid cluttering the main namespace. Derives are not used because the
// expand type wraps the associated state's expand type.
mod __input {
    use super::*;

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeType for TensorInput<EI, EO, GA> {
        type ExpandType = TensorInputExpand<EI, EO, GA>;
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorInputExpand<EI, EO, GA> {
        fn clone(&self) -> Self {
            Self {
                state: self.state.clone(),
                ident: self.ident,
            }
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Init for TensorInputExpand<EI, EO, GA> {
        fn init(mut self, scope: &mut Scope) -> Self {
            self.state = self.state.init(scope);
            self
        }
    }
    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeDebug for TensorInputExpand<EI, EO, GA> {
        fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
            self.state.set_debug_name(scope, name);
        }
    }
    // `TensorInput` only holds a raw pointer and a `Copy` enum, so it is trivially Copy.
    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorInput<EI, EO, GA> {
        fn clone(&self) -> Self {
            *self
        }
    }
    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Copy for TensorInput<EI, EO, GA> {}
}
929
// Manual CubeType/Clone/Init/CubeDebug plumbing for `TensorOutput`, mirroring the
// `__input` module above.
mod __output {
    use super::*;

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeType for TensorOutput<EI, EO, GA> {
        type ExpandType = TensorOutputExpand<EI, EO, GA>;
    }

    // `TensorOutput` only holds a raw pointer, so it is trivially Copy.
    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorOutput<EI, EO, GA> {
        fn clone(&self) -> Self {
            *self
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorOutputExpand<EI, EO, GA> {
        fn clone(&self) -> Self {
            Self {
                state: self.state.clone(),
            }
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Init for TensorOutputExpand<EI, EO, GA> {
        fn init(mut self, scope: &mut Scope) -> Self {
            self.state = self.state.init(scope);
            self
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeDebug for TensorOutputExpand<EI, EO, GA> {
        fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
            self.state.set_debug_name(scope, name);
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Copy for TensorOutput<EI, EO, GA> {}
}