// cubek_matmul/launch/args.rs

use std::marker::PhantomData;

use cubecl::prelude::*;
use cubecl::std::{
    CubeOption, CubeOptionArgs, CubeOptionExpand,
    tensor::{
        View,
        launch::ViewArg,
        layout::{Coords1d, VirtualLayout, VirtualLayoutLaunch},
    },
};
use cubecl::{server::TensorMapMeta, unexpanded};

use crate::components::{
    global::memory::{
        BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch,
        GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
        SimpleTmaGlobalLayoutLaunch,
    },
    stage::SwizzleMode,
};
use crate::definition::{
    self, Blueprint as _, MatmulElems, MatmulLineSizes, MatmulProblem, TilingBlueprint,
};
use crate::launch::handle::MatmulInputHandleRef;
use crate::routines::Routine;

/// Input argument
pub type InputArg<MA> =
    <MA as MatmulArgs>::Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>;

/// Output argument
pub type OutputArg<MA> = <MA as MatmulArgs>::Output<NumericExpand<2>>;

/// Config argument
pub type ConfigArg<MA> = <MA as MatmulArgs>::Config;

/// Input runtime argument
pub type InputRuntimeArg<'a, MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;

/// Config runtime argument
pub type ConfigRuntimeArg<'a, MA, R> = <ConfigArg<MA> as LaunchArg>::RuntimeArg<'a, R>;

/// Output runtime argument
pub type OutputRuntimeArg<'a, MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;

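/// Coordinates of a batched matmul view: `(batch, row, col)`.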
pub type BatchedCoords = (usize, u32, u32);

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory<A: Routine<()>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory<A: Routine<()>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

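/// Runtime configuration passed to the kernel at launch. The blanket impl below makes
/// any `LaunchArg + CubeType + Clone + Send + Sync` type qualify, so `()` serves as the
/// "no config" default used by [TensorArgs] and [TensorMapArgs].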
pub trait RuntimeConfig: LaunchArg + CubeType + Clone + Send + Sync {}
impl<T: LaunchArg + CubeType + Clone + Send + Sync> RuntimeConfig for T {}

#[cube]
/// Arguments for the matrix multiplication algorithm.
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;

    /// Type used for the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;

    /// Type used for runtime configuration.
    type Config: RuntimeConfig;

    /// Inner state that is used to create tensor inputs and
    /// tensor outputs.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;

    /// Init the state.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] lhs_layout_config: GlobalLayoutConfig,
        #[comptime] rhs_layout_config: GlobalLayoutConfig,
        #[comptime] out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO>;

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, BatchedCoords> {
        unexpanded!()
    }
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, BatchedCoords> {
        unexpanded!()
    }
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
        unexpanded!()
    }
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
        unexpanded!()
    }
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }

    fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        unexpanded!()
    }
}
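
// A rough sketch of the kernel-side call sequence (illustrative only: `input`,
// `output`, `config`, the layout configs, and `batch` are placeholders, and the
// real entrypoints live elsewhere in the crate):
//
//     let mut state = MA::init_state(&input, &mut output, config, lhs_cfg, rhs_cfg, out_cfg);
//     let lhs = MA::view_lhs(&state);         // read-only lhs view
//     let rhs = MA::view_rhs(&state);         // read-only rhs view
//     let out = MA::view_out(&mut state);     // read-write output view
//     let src = MA::batch_lhs(&state, batch); // remap a logical batch to its source batch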

#[derive(Clone, Copy)]
/// Identification of the tensor input.
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorArgs<Config: RuntimeConfig = ()> {
    _config: PhantomData<Config>,
}

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    /// The lhs tensor.
    lhs: View<Line<Lhs>, BatchedCoords>,
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The rhs tensor.
    rhs: View<Line<Rhs>, BatchedCoords>,
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The tensor for loading the accumulator, if present.
    acc: CubeOption<View<Line<Acc>, BatchedCoords>>,
    acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, A: Routine<()>> ConcreteInputsFactory<A>
    for TensorInputs<Lhs, Rhs, Acc>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let view = |handle: &'a MatmulInputHandleRef<'a, R>,
                    config: GlobalLayoutConfig,
                    line_size| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(handle, line_size, config);
                ViewArg::new::<GlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    client, data, scale, shape, problem, **scheme, line_size, config,
                );
                let data_view =
                    ViewArg::new::<GlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<GlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };
        let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(client, handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            MatmulInputHandleRef::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

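        // Concrete (non-fused) inputs never carry an accumulator, so `acc` and
        // `acc_batch` are always `None` here.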
        TensorInputsLaunch::new(
            view(lhs, blueprint.lhs_global_layout_config(), line_sizes.lhs),
            batch_layout(lhs),
            view(rhs, blueprint.rhs_global_layout_config(), line_sizes.rhs),
            batch_layout(rhs),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct TensorOutput<EG: Numeric> {
    view: View<Line<EG>, BatchedCoords, ReadWrite>,
    batch: VirtualLayout<Coords1d, Coords1d>,
}

impl<EG: Numeric, A: Routine<()>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let layout = GlobalLayoutLaunch::from_handle(
            out,
            line_sizes.out,
            blueprint.out_global_layout_config(),
        );
        let batch = BatchLayoutLaunch::from_handle(client, out, problem);
        let view = ViewArg::new::<GlobalLayout>(out.as_array_arg(line_sizes.out), layout);
        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
    }
}

#[cube]
impl<Config: RuntimeConfig> MatmulArgs for TensorArgs<Config> {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    type Config = Config;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
        #[comptime] _out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output, config)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, BatchedCoords> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, BatchedCoords> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.1.batch.to_source_pos(batch)
    }

    fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        state.2.clone()
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs<Config: RuntimeConfig = ()> {
    _config: PhantomData<Config>,
}

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    /// The lhs tensor.
    pub lhs: View<Line<Lhs>, BatchedCoords>,
    /// The rhs tensor.
    pub rhs: View<Line<Rhs>, BatchedCoords>,
    /// The accumulator.
    pub acc: CubeOption<View<Line<EO>, BatchedCoords>>,
    /// The accumulator batch layout.
    pub acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric, A: Routine<(), Blueprint = TilingBlueprint>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // Loaders use a dynamic layout based on the swizzle setting. With no swizzle,
        // contiguous tiles are loaded and TMA fetches single-tile-wide columns. With
        // swizzling, bank conflicts aren't an issue, so the tile size is the full stage.
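        // For example (illustrative numbers): no swizzle with a row-major lhs,
        // stage_m = 128, and tile_size.k = 16 gives a TMA tile of [1, 128, 16];
        // one batch, the full stage height, and a single tile width along k.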
        let stage_size_lhs = match blueprint.swizzle_modes.lhs {
            SwizzleMode::None => match problem.lhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_m, tiling_scheme.tile_size.k]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.m]
                }
            },
            _ => match problem.lhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_m, stage_k]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_k, stage_m]
                }
            },
        };
        let stage_size_rhs = match blueprint.swizzle_modes.rhs {
            SwizzleMode::None => match problem.rhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.n]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_n, tiling_scheme.tile_size.k]
                }
            },
            _ => match problem.rhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_k, stage_n]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_n, stage_k]
                }
            },
        };

        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // TMA requires the innermost stride to be contiguous (it isn't even passed), so
        // col-major inputs are mapped with transposed shape and strides. The tensor
        // metadata still has the normal layout.
        if matches!(problem.lhs_layout, definition::MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, definition::MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }

        // Insert the batch stride after the swap so the non-contiguous stride is easy
        // to pick out.
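        // e.g., contiguous 2D row-major strides [k, 1] become [k, k, 1]; with a single
        // batch, the duplicated leading stride is never advanced, it only pads the rank.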
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        fn swizzle(mode: SwizzleMode) -> TensorMapSwizzle {
            match mode {
                SwizzleMode::None => TensorMapSwizzle::None,
                SwizzleMode::B32 => TensorMapSwizzle::B32,
                SwizzleMode::B64 => TensorMapSwizzle::B64,
                SwizzleMode::B128 => TensorMapSwizzle::B128,
            }
        }

        let swizzle_lhs = swizzle(blueprint.swizzle_modes.lhs);
        let swizzle_rhs = swizzle(blueprint.swizzle_modes.rhs);

        // f32 gets remapped to tf32 for the tensor map, just to ensure CUDA loads it
        // correctly. It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.rhs_stage
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_lhs,
            }),
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_lhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_rhs,
            }),
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_rhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
            _kind: PhantomData,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
            _kind: PhantomData,
        };

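        // The tensor map above stores the physically contiguous (possibly swapped)
        // shape; the layout below re-exposes logical (batches, rows, cols) by undoing
        // the swap for transposed inputs.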
        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0]);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map_tiled::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

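// Note: `batch_lhs`/`batch_rhs` below are identity mappings; the rank-3 tensor map
// metadata already encodes the batch stride, so no extra remapping is needed.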
#[cube]
impl<Config: RuntimeConfig> MatmulArgs for TensorMapArgs<Config> {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Config = Config;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
        #[comptime] _out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output, config)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, BatchedCoords> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        batch
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, BatchedCoords> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        batch
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.1.batch.to_source_pos(batch)
    }

    fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        state.2.clone()
    }
}