cubecl_linalg/matmul/components/global/args.rs

use std::any::TypeId;

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, server::TensorMapMeta};
use cubecl_std::{
    ReinterpretSlice,
    tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand},
};

use super::Quantization;
use crate::matmul::components::{self, MatmulPrecision, MatmulProblem, MatmulSelection};

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R>;
}
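
// Launch-side sketch (hypothetical; handle and variable names are illustrative, not part
// of this API): a concrete launcher builds its runtime arguments through these factories,
// using the implementations for `TensorInputs` and `Tensor<Line<EG>>` defined below:
//
//     let inputs = <TensorInputs<f32> as ConcreteInputsFactory>::create(
//         &lhs_handle, &rhs_handle, &selection, &problem,
//     );
//     let output = <Tensor<Line<f32>> as ConcreteOutputFactory>::create(
//         &out_handle, &selection, &problem,
//     );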

#[cube]
/// Arguments for the matrix multiplication algorithm.
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<EI: Numeric>: LaunchArg + CubeType;
    /// Type used for the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;
    /// Inner state that is used to create [tensor inputs](TensorInput) and
    /// [tensor outputs](TensorOutput).
    type State<EI: Numeric, EO: Numeric>: CubeType;

    /// Init the state.
    fn init_state<EI: Numeric, EO: Numeric>(
        input: &Self::Input<EI>,
        output: &mut Self::Output<EO>,
    ) -> Self::State<EI, EO>;
    /// Read the line of the lhs tensor using the state at the given coordinate.
    fn read_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, coordinate: u32)
    -> Line<EI>;
    /// Read the line of the rhs tensor using the state at the given coordinate.
    fn read_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, coordinate: u32)
    -> Line<EI>;

    /// Read a window of the lhs tensor between `start` and `end` using the state.
    fn read_window_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>>;

    /// Read a window of the rhs tensor between `start` and `end` using the state.
    fn read_window_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>>;

    /// Reinterpret lhs as a tensor map.
    fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI>;

    /// Reinterpret rhs as a tensor map.
    fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI>;

    /// Write the line to the output at the given coordinate using the state.
    fn write_out<EI: Numeric, EO: Numeric>(
        state: &mut Self::State<EI, EO>,
        coordinate: u32,
        value: Line<EO>,
    );

    /// Get the rank of the lhs tensor using the state.
    fn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Get the rank of the rhs tensor using the state.
    fn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Get the rank of the out tensor using the state.
    fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;

    /// Get the length of the lhs tensor using the state.
    fn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Get the length of the rhs tensor using the state.
    fn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Get the length of the out tensor using the state.
    fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;

    /// Get the buffer length of the lhs tensor using the state.
    fn buffer_len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Get the buffer length of the rhs tensor using the state.
    fn buffer_len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
    /// Get the buffer length of the out tensor using the state.
    fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;

    /// Get the shape of the lhs tensor using the state.
    fn shape_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Get the shape of the rhs tensor using the state.
    fn shape_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Get the shape of the out tensor using the state.
    fn shape_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;

    /// Get the stride of the lhs tensor using the state.
    fn stride_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Get the stride of the rhs tensor using the state.
    fn stride_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;
    /// Get the stride of the out tensor using the state.
    fn stride_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, axis: u32) -> u32;

    /// It is the responsibility of the caller to ensure it is safe to call this function.
    /// That is, it must only be called when the matmul is indeed quantized. Otherwise, it
    /// will most likely result in out-of-bounds memory accesses.
    fn quantization<MP: MatmulPrecision>(state: &Self::State<MP::EI, MP::EO>) -> Quantization<MP>;
}
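
// Kernel-side sketch (hypothetical; the indices and element types are illustrative):
// inside a `#[cube]` kernel, an implementation `MA: MatmulArgs` is typically driven by
// initializing the state once, then wrapping it in [TensorInput]/[TensorOutput]:
//
//     let mut state = MA::init_state(inputs, output);
//     let lhs = TensorInput::<EI, EO, MA>::new(&state, TensorInputIdent::Lhs);
//     let rhs = TensorInput::<EI, EO, MA>::new(&state, TensorInputIdent::Rhs);
//     let mut out = TensorOutput::<EI, EO, MA>::new(&mut state);
//     let line = lhs.read(0) + rhs.read(0);
//     out.write(0, Line::cast_from(line));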

#[derive(Clone, Copy)]
/// Identification of the [tensor input](TensorInput).
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

/// Tensor input representation.
///
/// You can use the tensor input as if it were a pointer to the actual tensor.
pub struct TensorInput<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    state: *const GA::State<EI, EO>,
    ident: TensorInputIdent,
}

impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperations<EI>
    for TensorInput<EI, EO, MA>
{
}

impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperations<EO>
    for TensorOutput<EI, EO, MA>
{
}

impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperationsExpand<EO>
    for TensorOutputExpand<EI, EO, MA>
{
    fn __expand_read_method(
        &self,
        _scope: &mut Scope,
        _index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<EO>> {
        panic!("Can't read output tensor");
    }

    fn __expand_read_window_method(
        &self,
        _context: &mut Scope,
        _start: ExpandElementTyped<u32>,
        _end: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Slice<Line<EO>>> {
        panic!("Can't read output tensor");
    }

    fn __expand_write_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
        value: ExpandElementTyped<Line<EO>>,
    ) {
        TensorOutputExpand::__expand_write_method(self.clone(), scope, index, value)
    }

    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        TensorOutputExpand::__expand_shape_method(self.clone(), scope, axis)
    }

    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        TensorOutputExpand::__expand_stride_method(self.clone(), scope, axis)
    }

    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        TensorOutputExpand::__expand_rank_method(self.clone(), scope)
    }

    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        TensorOutputExpand::__expand_len_method(self.clone(), scope)
    }

    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        TensorOutputExpand::__expand_buffer_len_method(self.clone(), scope)
    }

    fn __expand_as_tensor_map_method(
        &self,
        _scope: &mut Scope,
    ) -> ExpandElementTyped<TensorMap<EO>> {
        unimplemented!("TensorOutputExpand can't be turned into a tensor map");
    }
}

impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> VirtualTensorOperationsExpand<EI>
    for TensorInputExpand<EI, EO, MA>
{
    fn __expand_read_method(
        &self,
        scope: &mut Scope,
        index: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Line<EI>> {
        TensorInputExpand::__expand_read_method(self.clone(), scope, index)
    }

    fn __expand_read_window_method(
        &self,
        context: &mut Scope,
        start: ExpandElementTyped<u32>,
        end: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<Slice<Line<EI>>> {
        TensorInputExpand::__expand_read_window_method(self.clone(), context, start, end)
    }

    fn __expand_write_method(
        &self,
        _scope: &mut Scope,
        _index: ExpandElementTyped<u32>,
        _value: ExpandElementTyped<Line<EI>>,
    ) {
        panic!("Can't write to input tensor");
    }

    fn __expand_shape_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        TensorInputExpand::__expand_shape_method(self.clone(), scope, axis)
    }

    fn __expand_stride_method(
        &self,
        scope: &mut Scope,
        axis: ExpandElementTyped<u32>,
    ) -> ExpandElementTyped<u32> {
        TensorInputExpand::__expand_stride_method(self.clone(), scope, axis)
    }

    fn __expand_rank_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        TensorInputExpand::__expand_rank_method(self.clone(), scope)
    }

    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        TensorInputExpand::__expand_len_method(self.clone(), scope)
    }

    fn __expand_buffer_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
        TensorInputExpand::__expand_buffer_len_method(self.clone(), scope)
    }

    fn __expand_as_tensor_map_method(
        &self,
        scope: &mut Scope,
    ) -> ExpandElementTyped<TensorMap<EI>> {
        TensorInputExpand::__expand_as_tensor_map_method(self.clone(), scope)
    }
}

/// Tensor output representation.
///
/// You can use the tensor output as if it were a pointer to the actual tensor.
///
/// # Warning
///
/// There is no mutability guarantee.
pub struct TensorOutput<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    state: *mut GA::State<EI, EO>,
}

/// Expand type for [tensor input](TensorInput).
pub struct TensorInputExpand<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    state: <GA::State<EI, EO> as CubeType>::ExpandType,
    ident: TensorInputIdent,
}

/// Expand type for [tensor output](TensorOutput).
pub struct TensorOutputExpand<EI: Numeric, EO: Numeric, GA: MatmulArgs> {
    state: <GA::State<EI, EO> as CubeType>::ExpandType,
}

#[cube]
impl<EI: Numeric, EO: Numeric, MA: MatmulArgs> TensorInput<EI, EO, MA> {
    /// Create a [tensor input](TensorInput) from the state and the [ident](TensorInputIdent).
    pub fn new(
        state: &MA::State<EI, EO>,
        #[comptime] ident: TensorInputIdent,
    ) -> TensorInput<EI, EO, MA> {
        TensorInput::<EI, EO, MA> { state, ident }
    }

    /// Read a window of the tensor between `start` and `end`.
    pub fn read_window(&self, start: u32, end: u32) -> Slice<Line<EI>> {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::read_window_lhs(&(*self.state), start, end),
                TensorInputIdent::Rhs => MA::read_window_rhs(&(*self.state), start, end),
            }
        }
    }

    /// Read the tensor at the given coordinate.
    pub fn read(&self, coordinate: u32) -> Line<EI> {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::read_lhs(&(*self.state), coordinate),
                TensorInputIdent::Rhs => MA::read_rhs(&(*self.state), coordinate),
            }
        }
    }

    /// Get the shape of the tensor at the given axis.
    pub fn shape(&self, axis: u32) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::shape_lhs(&(*self.state), axis),
                TensorInputIdent::Rhs => MA::shape_rhs(&(*self.state), axis),
            }
        }
    }

    /// Get the stride of the tensor at the given axis.
    pub fn stride(&self, axis: u32) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::stride_lhs(&(*self.state), axis),
                TensorInputIdent::Rhs => MA::stride_rhs(&(*self.state), axis),
            }
        }
    }

    /// Get the rank of the tensor.
    pub fn rank(&self) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::rank_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::rank_rhs(&(*self.state)),
            }
        }
    }

    /// Get the length of the tensor.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::len_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::len_rhs(&(*self.state)),
            }
        }
    }

    /// Get the buffer length of the tensor.
    pub fn buffer_len(&self) -> u32 {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::buffer_len_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::buffer_len_rhs(&(*self.state)),
            }
        }
    }

    /// Reinterpret the tensor as a tensor map.
    pub fn as_tensor_map(&self) -> TensorMap<EI> {
        unsafe {
            match comptime![&self.ident] {
                TensorInputIdent::Lhs => MA::as_tensor_map_lhs(&(*self.state)),
                TensorInputIdent::Rhs => MA::as_tensor_map_rhs(&(*self.state)),
            }
        }
    }
}

#[cube]
impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> TensorOutput<EI, EO, GA> {
    /// Create a [tensor output](TensorOutput) from the state.
    pub fn new(state: &mut GA::State<EI, EO>) -> TensorOutput<EI, EO, GA> {
        TensorOutput::<EI, EO, GA> { state }
    }

    /// Write the value to the tensor at the given coordinate.
    pub fn write(&self, coordinate: u32, value: Line<EO>) {
        unsafe { GA::write_out(&mut (*self.state), coordinate, value) }
    }

    /// Get the shape of the tensor at the given axis.
    pub fn shape(&self, axis: u32) -> u32 {
        unsafe { GA::shape_out(&(*self.state), axis) }
    }

    /// Get the stride of the tensor at the given axis.
    pub fn stride(&self, dim: u32) -> u32 {
        unsafe { GA::stride_out(&(*self.state), dim) }
    }

    /// Get the rank of the tensor.
    pub fn rank(&self) -> u32 {
        unsafe { GA::rank_out(&(*self.state)) }
    }

    /// Get the length of the tensor.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> u32 {
        unsafe { GA::len_out(&(*self.state)) }
    }

    /// Get the buffer length of the tensor.
    pub fn buffer_len(&self) -> u32 {
        unsafe { GA::buffer_len_out(&(*self.state)) }
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorArgs;

#[derive(CubeLaunch, CubeType)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<EG: Numeric> {
    /// The lhs tensor.
    pub lhs: Tensor<Line<EG>>,
    /// The rhs tensor.
    pub rhs: Tensor<Line<EG>>,
}

impl<EG: Numeric> ConcreteInputsFactory for TensorInputs<EG> {
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R> {
        TensorInputsLaunch::new(
            lhs.as_tensor_arg(problem.lhs_line_size),
            rhs.as_tensor_arg(problem.rhs_line_size),
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for Tensor<Line<EG>> {
    fn create<'a, R: Runtime>(
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R> {
        out.as_tensor_arg(problem.out_line_size)
    }
}

#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = Tensor<Line<EO>>;
    type Input<EI: Numeric> = TensorInputs<EI>;
    type State<EI: Numeric, EO: Numeric> = (
        *const Tensor<Line<EI>>,
        *const Tensor<Line<EI>>,
        *mut Tensor<Line<EO>>,
    );

    fn init_state<EI: Numeric, EO: Numeric>(
        input: &Self::Input<EI>,
        output: &mut Self::Output<EO>,
    ) -> Self::State<EI, EO> {
        (&input.lhs, &input.rhs, output)
    }

    fn read_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        coordinate: u32,
    ) -> Line<EI> {
        unsafe { (*state.0)[coordinate] }
    }

    fn read_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        coordinate: u32,
    ) -> Line<EI> {
        unsafe { (*state.1)[coordinate] }
    }

    fn read_window_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unsafe { (*state.0).slice(start, end) }
    }

    fn read_window_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unsafe { (*state.1).slice(start, end) }
    }

    fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> TensorMap<EI> {
        comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
        #[allow(unreachable_code)]
        TensorMap::dummy()
    }

    fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> TensorMap<EI> {
        comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
        #[allow(unreachable_code)]
        TensorMap::dummy()
    }

    fn shape_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).shape(dim) }
    }

    fn shape_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).shape(dim) }
    }

    fn shape_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.2).shape(dim) }
    }

    fn stride_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).stride(dim) }
    }

    fn stride_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).stride(dim) }
    }

    fn stride_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.2).stride(dim) }
    }

    fn write_out<EI: Numeric, EO: Numeric>(
        state: &mut Self::State<EI, EO>,
        coordinate: u32,
        value: Line<EO>,
    ) {
        unsafe { (*state.2)[coordinate] = value }
    }

    fn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.0).rank() }
    }

    fn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.1).rank() }
    }

    fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).rank() }
    }

    fn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.0).len() }
    }

    fn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.1).len() }
    }

    fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).len() }
    }

    fn buffer_len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.0).buffer_len() }
    }

    fn buffer_len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.1).buffer_len() }
    }

    fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).buffer_len() }
    }

    fn quantization<MP: MatmulPrecision>(state: &Self::State<MP::EI, MP::EO>) -> Quantization<MP> {
        // TODO Currently, this assumes that the scaling is always the last value in the buffer.
        //      Also, in burn the scaling is presently fixed to f32, hence the extra conversions.

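        // Buffer layout sketch for the assumption above (illustrative, not a guarantee of
        // this API): the quantized i8 values are packed 4 per 32-bit slot, and the f32
        // scale occupies the last reinterpreted slot:
        //
        //     [ i8 i8 i8 i8 | i8 i8 i8 i8 | ... | scale: f32 ]
        //
        //     reinterpreted_len = buffer_len * line_size / 4  // 4 i8 per 32-bit slot
        //     scale = reinterpret::<f32>(buffer)[reinterpreted_len - 1]
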
        let (lhs, rhs, _) = *state;
        unsafe {
            let buffer_len_lhs = Self::buffer_len_lhs(state);
            let line_size_lhs = (*lhs).line_size();
            let reinterpreted_len_lhs = buffer_len_lhs * line_size_lhs / 4; // TODO Change this when we stop using u32 to pack 4 i8 in burn.

            let scaling_lhs =
                ReinterpretSlice::<MP::EI, f32>::new((*lhs).to_slice(), line_size_lhs)
                    .read(reinterpreted_len_lhs - 1);

            let buffer_len_rhs = Self::buffer_len_rhs(state);
            let line_size_rhs = (*rhs).line_size();
            let reinterpreted_len_rhs = buffer_len_rhs * line_size_rhs / 4; // TODO See above comment.

            let scaling_rhs =
                ReinterpretSlice::<MP::EI, f32>::new((*rhs).to_slice(), line_size_rhs)
                    .read(reinterpreted_len_rhs - 1);

            Quantization::<MP> {
                scaling_lhs: MP::ES::cast_from(scaling_lhs),
                scaling_rhs: MP::ES::cast_from(scaling_rhs),
            }
        }
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs;

#[derive(CubeLaunch, CubeType)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<EG: Numeric> {
    /// The lhs tensor.
    pub lhs: TensorMap<EG>,
    /// The rhs tensor.
    pub rhs: TensorMap<EG>,
}

impl<EG: Numeric> ConcreteInputsFactory for TensorMapInputs<EG> {
    fn create<'a, R: Runtime>(
        lhs: &'a TensorHandleRef<'a, R>,
        rhs: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
    ) -> Self::RuntimeArg<'a, R> {
        let stage_m = selection.tile_count.m * selection.tile_shape.m;
        let stage_n = selection.tile_count.n * selection.tile_shape.n;
        let stage_k = selection.tile_count.k * selection.tile_shape.k;
        let stage_size_lhs = match problem.lhs_layout {
            components::MatrixLayout::RowMajor => vec![1, stage_m, selection.tile_shape.k],
            components::MatrixLayout::ColMajor => vec![1, stage_k, selection.tile_shape.m],
        };
        let stage_size_rhs = match problem.rhs_layout {
            components::MatrixLayout::RowMajor => vec![1, stage_k, selection.tile_shape.n],
            components::MatrixLayout::ColMajor => vec![1, stage_n, selection.tile_shape.k],
        };

        let elem_size = size_of::<EG>();

        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.batches.0[0],
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![1, lhs.strides[lhs_rank - 2], lhs.strides[lhs_rank - 1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.batches.1[0],
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![1, rhs.strides[rhs_rank - 2], rhs.strides[rhs_rank - 1]]
        };

        // TMA requires the last dimension to be contiguous and doesn't take its stride, so
        // column-major tensors must be described with transposed shape and strides. The
        // tensor metadata still has the normal layout.
        if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
            lhs_shape.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
        }
        if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
            rhs_shape.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
        }
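
        // Illustrative example (hypothetical shapes): a rank-3 ColMajor lhs described as
        // shape [b, m, k] with strides [m * k, 1, m] is handed to TMA as shape [b, k, m]
        // with strides [m * k, m, 1], so the innermost dimension is contiguous again.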

        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }
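
        // For instance (assuming f16, so elem_size == 2): a tile row of 64 elements is
        // 128 bytes, which falls in the 128..256 arm and selects TensorMapPrefetch::B128.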

        let prefetch_lhs = prefetch(stage_size_lhs[2] as usize * elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * elem_size);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
        // It shouldn't matter, but it's better to be safe.
        let elem = if TypeId::of::<EG>() == TypeId::of::<f32>() {
            tf32::as_elem_native_unchecked()
        } else {
            EG::as_elem_native_unchecked()
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_lhs,
            },
            rank: 3,
            shape: lhs_shape,
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            prefetch: prefetch_lhs,
            oob_fill: OobFill::Zero,
            elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rank: 3,
            shape: rhs_shape,
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: TensorMapSwizzle::None,
            prefetch: prefetch_rhs,
            oob_fill: OobFill::Zero,
            elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(problem.lhs_line_size),
            metadata: meta_lhs,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(problem.rhs_line_size),
            metadata: meta_rhs,
        };

        TensorMapInputsLaunch::new(lhs, rhs)
    }
}

#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<EI: Numeric> = TensorMapInputs<EI>;
    type Output<EO: Numeric> = Tensor<Line<EO>>;
    type State<EI: Numeric, EO: Numeric> = (
        *const TensorMap<EI>,
        *const TensorMap<EI>,
        *mut Tensor<Line<EO>>,
    );

    fn init_state<EI: Numeric, EO: Numeric>(
        input: &Self::Input<EI>,
        output: &mut Self::Output<EO>,
    ) -> Self::State<EI, EO> {
        (&input.lhs, &input.rhs, output)
    }

    fn read_lhs<EI: Numeric, EO: Numeric>(
        _state: &Self::State<EI, EO>,
        _coordinate: u32,
    ) -> Line<EI> {
        unimplemented!("Can't directly read from TensorMap")
    }

    fn read_rhs<EI: Numeric, EO: Numeric>(
        _state: &Self::State<EI, EO>,
        _coordinate: u32,
    ) -> Line<EI> {
        unimplemented!("Can't directly read from TensorMap")
    }

    #[allow(unused)]
    fn read_window_lhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unimplemented!("Can't directly read from TensorMap")
    }

    #[allow(unused)]
    fn read_window_rhs<EI: Numeric, EO: Numeric>(
        state: &Self::State<EI, EO>,
        start: u32,
        end: u32,
    ) -> Slice<Line<EI>> {
        unimplemented!("Can't directly read from TensorMap")
    }

    fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI> {
        unsafe { *state.0 }
    }

    fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> TensorMap<EI> {
        unsafe { *state.1 }
    }

    fn shape_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).shape(dim) }
    }

    fn shape_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).shape(dim) }
    }

    fn shape_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.2).shape(dim) }
    }

    fn stride_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.0).stride(dim) }
    }

    fn stride_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.1).stride(dim) }
    }

    fn stride_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>, dim: u32) -> u32 {
        unsafe { (*state.2).stride(dim) }
    }

    fn write_out<EI: Numeric, EO: Numeric>(
        state: &mut Self::State<EI, EO>,
        coordinate: u32,
        value: Line<EO>,
    ) {
        unsafe { (*state.2)[coordinate] = value }
    }

    fn rank_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn rank_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).rank() }
    }

    fn len_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn len_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).len() }
    }

    fn buffer_len_lhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn buffer_len_rhs<EI: Numeric, EO: Numeric>(_state: &Self::State<EI, EO>) -> u32 {
        unimplemented!("Can't read metadata from TensorMap")
    }

    fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32 {
        unsafe { (*state.2).buffer_len() }
    }

    fn quantization<MP: MatmulPrecision>(_state: &Self::State<MP::EI, MP::EO>) -> Quantization<MP> {
        todo!("Quantized TMA not yet supported")
    }
}

mod __input {
    use super::*;

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeType for TensorInput<EI, EO, GA> {
        type ExpandType = TensorInputExpand<EI, EO, GA>;
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorInputExpand<EI, EO, GA> {
        fn clone(&self) -> Self {
            Self {
                state: self.state.clone(),
                ident: self.ident,
            }
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Init for TensorInputExpand<EI, EO, GA> {
        fn init(mut self, scope: &mut Scope) -> Self {
            self.state = self.state.init(scope);
            self
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeDebug for TensorInputExpand<EI, EO, GA> {
        fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
            self.state.set_debug_name(scope, name);
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorInput<EI, EO, GA> {
        fn clone(&self) -> Self {
            *self
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Copy for TensorInput<EI, EO, GA> {}
}

mod __output {
    use super::*;

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeType for TensorOutput<EI, EO, GA> {
        type ExpandType = TensorOutputExpand<EI, EO, GA>;
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorOutput<EI, EO, GA> {
        fn clone(&self) -> Self {
            *self
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Clone for TensorOutputExpand<EI, EO, GA> {
        fn clone(&self) -> Self {
            Self {
                state: self.state.clone(),
            }
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Init for TensorOutputExpand<EI, EO, GA> {
        fn init(mut self, scope: &mut Scope) -> Self {
            self.state = self.state.init(scope);
            self
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> CubeDebug for TensorOutputExpand<EI, EO, GA> {
        fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
            self.state.set_debug_name(scope, name);
        }
    }

    impl<EI: Numeric, EO: Numeric, GA: MatmulArgs> Copy for TensorOutput<EI, EO, GA> {}
}