
cubek_matmul/launch/args.rs

use std::marker::PhantomData;

use cubecl::std::tensor::{
    View,
    launch::ViewArg,
    layout::{Coords1d, VirtualLayout, VirtualLayoutLaunch},
};
use cubecl::{
    prelude::*,
    zspace::{metadata::Metadata, shape, strides},
};
use cubecl::{server::TensorMapMeta, unexpanded};
use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};

use crate::components::global::memory::{
    BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch,
    GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
    SimpleTmaGlobalLayoutLaunch,
};
use crate::{
    definition::{Blueprint as _, MatmulElems, MatmulProblem, MatmulVectorSizes},
    routines::Routine,
};

define_scalar!(pub Lhs);
define_scalar!(pub Rhs);
define_scalar!(pub Acc);

define_size!(pub LhsSize);
define_size!(pub RhsSize);
define_size!(pub AccSize);

/// Input argument
pub type InputArg<MA> =
    <MA as MatmulArgs>::Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>;

/// Output argument
pub type OutputArg<MA> = <MA as MatmulArgs>::Output<Vector<Acc, AccSize>>;

/// Config argument
pub type ConfigArg<MA> = <MA as MatmulArgs>::Config;

/// Input runtime argument
pub type InputRuntimeArg<MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Config runtime argument
pub type ConfigRuntimeArg<MA, R> = <ConfigArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Output runtime argument
pub type OutputRuntimeArg<MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Coordinates into a batched matmul view: `(batch, row, col)`.
pub type BatchedCoords = (usize, u32, u32);

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
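///
/// # Example (sketch)
///
/// A hypothetical launcher could build the input runtime argument like this; the bindings,
/// blueprint, and problem description are assumed to come from the caller:
///
/// ```ignore
/// let input_arg = <InputArg<MA> as ConcreteInputsFactory<A>>::create::<R>(
///     lhs_binding,
///     rhs_binding,
///     &blueprint,
///     &problem,
///     &vector_sizes,
///     &dtypes,
/// );
/// ```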
pub trait ConcreteInputsFactory<A: Routine<()>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}

/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
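///
/// # Example (sketch)
///
/// Mirrors [ConcreteInputsFactory]; the output binding here is hypothetical:
///
/// ```ignore
/// let output_arg = <OutputArg<MA> as ConcreteOutputFactory<A>>::create::<R>(
///     out_binding,
///     &blueprint,
///     &problem,
///     &vector_sizes,
///     &dtypes,
/// );
/// ```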
pub trait ConcreteOutputFactory<A: Routine<()>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}

pub trait RuntimeConfig: LaunchArg + CubeType + Clone + Send + Sync {}
impl<T: LaunchArg + CubeType + Clone + Send + Sync> RuntimeConfig for T {}

#[cube]
/// Arguments for the matrix multiplication algorithm.
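///
/// # Example (sketch)
///
/// A kernel body typically initializes the state once, then pulls views and batch
/// offsets from it (the surrounding kernel arguments and config names are assumptions):
///
/// ```ignore
/// let mut state = MA::init_state(&input, &mut output, config, lhs_cfg, rhs_cfg, out_cfg);
/// let lhs = MA::view_lhs::<Lhs, Rhs, EO>(&state);
/// let rhs = MA::view_rhs::<Lhs, Rhs, EO>(&state);
/// let out = MA::view_out::<Lhs, Rhs, EO>(&mut state);
/// ```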
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: LaunchArg + CubeType;

    /// Type used for the output.
    type Output<EO: CubePrimitive>: LaunchArg + CubeType;

    /// Type used for runtime configuration.
    type Config: RuntimeConfig;

    /// Inner state that is used to create tensor inputs and
    /// tensor outputs.
    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: CubeType;

    /// Initialize the state.
    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] lhs_layout_config: GlobalLayoutConfig,
        #[comptime] rhs_layout_config: GlobalLayoutConfig,
        #[comptime] out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO>;

    /// Global view of the lhs tensor.
    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Lhs, BatchedCoords> {
        unexpanded!()
    }
    /// Map a flattened batch index to its position in the lhs tensor.
    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Global view of the rhs tensor.
    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Rhs, BatchedCoords> {
        unexpanded!()
    }
    /// Map a flattened batch index to its position in the rhs tensor.
    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Global view of the accumulator, if present.
    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> ComptimeOption<View<EO, BatchedCoords>> {
        unexpanded!()
    }
    /// Map a flattened batch index to its position in the accumulator, if present.
    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Mutable global view of the output tensor.
    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<EO, BatchedCoords, ReadWrite> {
        unexpanded!()
    }
    /// Map a flattened batch index to its position in the output tensor.
    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }

    /// Retrieve the runtime configuration.
    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        unexpanded!()
    }
}

#[derive(Clone, Copy)]
/// Identification of the tensor input.
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
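///
/// # Example (sketch)
///
/// Selecting plain tensor arguments with no extra runtime config (the alias names are
/// hypothetical):
///
/// ```ignore
/// type Args = TensorArgs<()>;
/// type Inputs = InputArg<Args>; // = TensorInputs<Vector<Lhs, LhsSize>, ...>
/// ```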
pub struct TensorArgs<Config: RuntimeConfig = ()> {
    _config: PhantomData<Config>,
}

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive> {
    /// Batch layout of the lhs tensor.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The lhs tensor view.
    lhs: View<Lhs, BatchedCoords>,
    /// Batch layout of the rhs tensor.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The rhs tensor view.
    rhs: View<Rhs, BatchedCoords>,
    /// Batch layout of the accumulator, if present.
    acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
    /// The view for loading the accumulator, if present.
    acc: ComptimeOption<View<Acc, BatchedCoords>>,
}

impl<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive, A: Routine<()>>
    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, Acc>
{
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        // Builds the global view for an input: a plain tensor view for normal bindings, or a
        // quantized view pairing the data with its scales.
        let view = |handle: InputBinding<R>, config: GlobalLayoutConfig, vector_size| match handle {
            InputBinding::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(&handle, vector_size, config);
                ViewArg::new_tensor::<GlobalLayout>(handle.into_tensor_arg(), layout)
            }
            InputBinding::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    &data,
                    &scale,
                    &shape,
                    problem,
                    scheme,
                    vector_size,
                    config,
                );
                let data_view =
                    ViewArg::new_tensor::<GlobalLayout>(data.into_tensor_arg(), data_layout);
                let scales_view = ViewArg::new_tensor::<GlobalScaleLayout>(
                    scale.into_tensor_arg(),
                    scales_layout,
                );
                ViewArg::new_quantized(data_view, scales_view, scheme)
            }
        };
        // Builds the batch layout; quantized bindings use an identity (noop) mapping.
        let batch_layout = |handle: &InputBinding<R>| match handle {
            InputBinding::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            InputBinding::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        TensorInputsLaunch::new(
            batch_layout(&lhs),
            view(lhs, blueprint.lhs_global_layout_config(), vector_sizes.lhs),
            batch_layout(&rhs),
            view(rhs, blueprint.rhs_global_layout_config(), vector_sizes.rhs),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        )
    }
}

#[derive(CubeType, CubeLaunch, Clone, Copy)]
/// Output representation for [TensorArgs] and [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorOutput<EG: CubePrimitive> {
    /// The output tensor view.
    view: View<EG, BatchedCoords, ReadWrite>,
    /// Batch layout of the output tensor.
    batch: VirtualLayout<Coords1d, Coords1d>,
}

impl<EG: CubePrimitive, A: Routine<()>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        let layout = GlobalLayoutLaunch::from_handle(
            &out,
            vector_sizes.out,
            blueprint.out_global_layout_config(),
        );
        let batch = BatchLayoutLaunch::from_handle(&out, problem);
        let view = ViewArg::new_tensor::<GlobalLayout>(out.into_tensor_arg(), layout);
        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
    }
}

#[cube]
impl<Config: RuntimeConfig> MatmulArgs for TensorArgs<Config> {
    type Output<EO: CubePrimitive> = TensorOutput<EO>;
    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
        TensorInputs<Lhs, Rhs, EO>;
    type Config = Config;
    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);

    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
        #[comptime] _out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output, config)
    }

    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Lhs, BatchedCoords> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Rhs, BatchedCoords> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> ComptimeOption<View<EO, BatchedCoords>> {
        state.0.acc
    }

    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        #[comptime]
        match state.0.acc_batch {
            ComptimeOption::Some(layout) => layout.to_source_pos(batch),
            ComptimeOption::None => batch,
        }
    }

    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<EO, BatchedCoords, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.1.batch.to_source_pos(batch)
    }

    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        state.2.clone()
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs<Config: RuntimeConfig = ()> {
    _config: PhantomData<Config>,
}

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> {
    /// The lhs tensor.
    pub lhs: View<Lhs, BatchedCoords>,
    /// The rhs tensor.
    pub rhs: View<Rhs, BatchedCoords>,
    /// The accumulator, if present.
    pub acc: ComptimeOption<View<EO, BatchedCoords>>,
    /// The accumulator batch layout, if present.
    pub acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<()>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<R: Runtime>(
        lhs_handle: InputBinding<R>,
        rhs_handle: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        _vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        let lhs = lhs_handle.into_data();
        let rhs = rhs_handle.into_data();

        let tiling_scheme = blueprint.tiling_scheme();
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // Loaders use a dynamic layout based on the swizzle setting. Without swizzling,
        // contiguous tiles are loaded and TMA loads single-tile-wide columns. With
        // swizzling, bank conflicts aren't an issue, so the tile size is the full stage.
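        // For example (hypothetical numbers): with stage_m = 128, stage_k = 32, and
        // tile_size.k = 16, a row-major lhs gets a [1, 128, 16] TMA tile without swizzling,
        // but the full-stage [1, 128, 32] tile when a swizzle mode is set.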
        let stage_size_lhs = match blueprint.swizzle_modes().lhs {
            SwizzleMode::None => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, tiling_scheme.tile_size.k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.m as usize]
                }
            },
            _ => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, stage_k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, stage_m as usize]
                }
            },
        };
        let stage_size_rhs = match blueprint.swizzle_modes().rhs {
            SwizzleMode::None => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, tiling_scheme.tile_size.k as usize]
                }
            },
            _ => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, stage_n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, stage_k as usize]
                }
            },
        };

        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = shape![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].into()
        } else {
            strides![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = shape![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].into()
        } else {
            strides![rhs.strides[0], rhs.strides[1]]
        };
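        // E.g. (hypothetical values): a rank-4 lhs of shape [2, 3, 64, 32] with strides
        // [6144, 2048, 32, 1] flattens to shape [6, 64, 32] with strides [2048, 32, 1],
        // collapsing all batch dimensions into one.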

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // TMA assumes the innermost dimension is contiguous and doesn't even take its stride,
        // so col-major tensors are mapped with transposed shape and strides. The tensor
        // metadata still has the normal layout.
        if matches!(problem.lhs_layout, MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }
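        // E.g. (hypothetical values): a col-major 2D lhs with shape [m, k] = [64, 32] and
        // strides [1, 64] becomes shape [batch, 32, 64] with strides [64, 1], making the
        // innermost dimension contiguous as TMA expects.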

        // Insert batch stride after swap so we can easily get the non-contiguous stride
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }
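        // E.g. (hypothetical values): rank-2 strides [64, 1] become [64, 64, 1], reusing the
        // outermost stride as the batch stride.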

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked().storage_type() {
            tf32::as_type_native_unchecked().storage_type()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked().storage_type() {
            tf32::as_type_native_unchecked().storage_type()
        } else {
            dtypes.rhs_stage
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_lhs,
            }),
            metadata: Metadata::new(lhs_shape.clone(), lhs_strides),
            elem_stride: strides![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: blueprint.swizzle_modes().lhs.into(),
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_rhs,
            }),
            metadata: Metadata::new(rhs_shape.clone(), rhs_strides),
            elem_stride: strides![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: blueprint.swizzle_modes().rhs.into(),
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.into_tensor_arg(),
            metadata: meta_lhs,
            _kind: PhantomData,
        };
        let rhs = TensorMapArg {
            tensor: rhs.into_tensor_arg(),
            metadata: meta_rhs,
            _kind: PhantomData,
        };

        // Builds a TMA view; for transposed (col-major) inputs, swap back to the logical
        // (rows, cols) order.
        let view = |buffer, shape: &[usize], transposed| {
            let batches = shape[0];
            let (rows, cols) = match transposed {
                true => (shape[2] as u32, shape[1] as u32),
                false => (shape[1] as u32, shape[2] as u32),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map_tiled::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        )
    }
}

#[cube]
impl<Config: RuntimeConfig> MatmulArgs for TensorMapArgs<Config> {
    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
        TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: CubePrimitive> = TensorOutput<EO>;
    type Config = Config;
    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);

    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
        #[comptime] _out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output, config)
    }

    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Lhs, BatchedCoords> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        batch
    }

    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Rhs, BatchedCoords> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        batch
    }

    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> ComptimeOption<View<EO, BatchedCoords>> {
        state.0.acc
    }

    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        #[comptime]
        match state.0.acc_batch {
            ComptimeOption::Some(layout) => layout.to_source_pos(batch),
            ComptimeOption::None => batch,
        }
    }

    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<EO, BatchedCoords, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.1.batch.to_source_pos(batch)
    }

    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        state.2.clone()
    }
}