//! Launch-argument definitions for the matmul kernels (`cubek_matmul/launch/args.rs`):
//! abstractions over how operand/output tensors are passed to a matmul kernel,
//! with concrete implementations for plain tensors and TMA tensor maps.
1use std::marker::PhantomData;
2
3use cubecl::std::tensor::{
4    View,
5    launch::ViewArg,
6    layout::{Coords1d, VirtualLayout, VirtualLayoutLaunch},
7};
8use cubecl::unexpanded;
9use cubecl::{
10    prelude::*,
11    zspace::{metadata::Metadata, shape, strides},
12};
13use cubek_std::launch::tma::{remap_storage_for_tma, tma_meta_tiled, transpose_inner_for_tma};
14use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
15
16use crate::components::global::memory::{
17    BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch,
18    GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
19    SimpleTmaGlobalLayoutLaunch,
20};
21use crate::{
22    definition::{Blueprint as _, MatmulElems, MatmulProblem, MatmulVectorSizes},
23    routines::Routine,
24};
25
// Type-level placeholders for the three matmul operand element types.
// NOTE(review): `define_scalar!`/`define_size!` are project macros not visible
// here — presumably they declare marker types resolved to concrete element
// types / vectorization widths at kernel specialization; confirm in cubek_std.
define_scalar!(pub Lhs);
define_scalar!(pub Rhs);
define_scalar!(pub Acc);

// Type-level placeholders for the per-operand vector (line) sizes.
define_size!(pub LhsSize);
define_size!(pub RhsSize);
define_size!(pub AccSize);
33
/// Input argument type of a [MatmulArgs] implementation, with the operand
/// element/vector placeholders (`Lhs`/`Rhs`/`Acc`) filled in.
pub type InputArg<MA> =
    <MA as MatmulArgs>::Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>;

/// Output argument type of a [MatmulArgs] implementation.
pub type OutputArg<MA> = <MA as MatmulArgs>::Output<Vector<Acc, AccSize>>;

/// Runtime-configuration argument type of a [MatmulArgs] implementation.
pub type ConfigArg<MA> = <MA as MatmulArgs>::Config;

/// Host-side runtime argument corresponding to [InputArg].
pub type InputRuntimeArg<MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Host-side runtime argument corresponding to [ConfigArg].
pub type ConfigRuntimeArg<MA, R> = <ConfigArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Host-side runtime argument corresponding to [OutputArg].
pub type OutputRuntimeArg<MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Coordinates used by matmul views: `(batch, rows, cols)`.
pub type BatchedCoords = (usize, u32, u32);
54
55/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
56/// output (not fused).
/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory<A: Routine<()>>: LaunchArg {
    /// Build the launch-time input argument from raw device bindings.
    ///
    /// * `lhs`/`rhs` - device bindings of the two operand tensors.
    /// * `blueprint` - routine blueprint providing layout/tiling configuration.
    /// * `problem` - the matmul problem description (shapes, layouts, batches).
    /// * `vector_sizes` - vectorization widths per operand.
    /// * `dtypes` - element types per operand.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
68
69/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
70/// output (not fused).
/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory<A: Routine<()>>: LaunchArg {
    /// Build the launch-time output argument from the raw output binding.
    ///
    /// Parameters mirror [ConcreteInputsFactory::create], with `out` being the
    /// device binding of the output tensor.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
81
/// Marker trait for types usable as the runtime configuration of a [MatmulArgs]
/// implementation. Blanket-implemented for every type meeting the bounds.
pub trait RuntimeConfig: LaunchArg + CubeType + Clone + Send + Sync {}
impl<T: LaunchArg + CubeType + Clone + Send + Sync> RuntimeConfig for T {}
84
#[cube]
/// Arguments for the matrix multiplication algorithm.
///
/// Abstracts how the lhs/rhs/accumulator/output tensors are exposed to the
/// kernel, so the same matmul components can run over plain tensors
/// ([TensorArgs]), TMA tensor maps ([TensorMapArgs]), or other (e.g. fused)
/// representations.
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: LaunchArg + CubeType;

    /// Type used for the output.
    type Output<EO: CubePrimitive>: LaunchArg + CubeType;

    /// Type used for runtime configuration.
    type Config: RuntimeConfig;

    /// Inner state that is used to create tensor inputs and
    /// tensor outputs.
    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: CubeType;

    /// Init the state from the launch arguments and per-operand layout configs.
    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] lhs_layout_config: GlobalLayoutConfig,
        #[comptime] rhs_layout_config: GlobalLayoutConfig,
        #[comptime] out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO>;

    /// Read-only view over the lhs tensor.
    ///
    /// The `unexpanded!()` default bodies below are placeholders replaced by
    /// the `#[cube]` expansion; they must never execute un-expanded.
    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Lhs, BatchedCoords> {
        unexpanded!()
    }
    /// Map a logical lhs batch index to its position in the source tensor.
    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Read-only view over the rhs tensor.
    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Rhs, BatchedCoords> {
        unexpanded!()
    }
    /// Map a logical rhs batch index to its position in the source tensor.
    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Optional read-only view over the accumulator tensor, when one is loaded.
    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> ComptimeOption<View<EO, BatchedCoords>> {
        unexpanded!()
    }
    /// Map a logical accumulator batch index to its position in the source tensor.
    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Writable view over the output tensor.
    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<EO, BatchedCoords, ReadWrite> {
        unexpanded!()
    }
    /// Map a logical output batch index to its position in the output tensor.
    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }

    /// The runtime configuration carried by the state.
    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        unexpanded!()
    }
}
162
#[derive(Clone, Copy)]
/// Identification of the tensor input.
pub enum TensorInputIdent {
    /// The left-hand side operand.
    Lhs,
    /// The right-hand side operand.
    Rhs,
}
169
#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorArgs<Config: RuntimeConfig = ()> {
    // Zero-sized: this struct only carries the `Config` type parameter.
    _config: PhantomData<Config>,
}
177
#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive> {
    /// Batch remapping layout for the lhs tensor.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The lhs tensor.
    lhs: View<Lhs, BatchedCoords>,
    /// Batch remapping layout for the rhs tensor.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The rhs tensor.
    rhs: View<Rhs, BatchedCoords>,
    /// Batch remapping layout for the accumulator, if present.
    acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
    /// The tensor for loading the accumulator, if present.
    acc: ComptimeOption<View<Acc, BatchedCoords>>,
}
191
192impl<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive, A: Routine<()>>
193    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, Acc>
194{
195    fn create<R: Runtime>(
196        lhs: InputBinding<R>,
197        rhs: InputBinding<R>,
198        blueprint: &A::Blueprint,
199        problem: &MatmulProblem,
200        vector_sizes: &MatmulVectorSizes,
201        _dtypes: &MatmulElems,
202    ) -> Self::RuntimeArg<R> {
203        let view = |handle: InputBinding<R>, config: GlobalLayoutConfig, vector_size| match handle {
204            InputBinding::Normal(handle, _dtype) => {
205                let layout = GlobalLayoutLaunch::from_handle(&handle, vector_size, config);
206                ViewArg::new_tensor::<GlobalLayout>(handle.into_tensor_arg(), layout)
207            }
208            InputBinding::Quantized {
209                data,
210                scale,
211                shape,
212                scheme,
213                ..
214            } => {
215                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
216                    &data,
217                    &scale,
218                    &shape,
219                    problem,
220                    scheme,
221                    vector_size,
222                    config,
223                );
224                let data_view =
225                    ViewArg::new_tensor::<GlobalLayout>(data.into_tensor_arg(), data_layout);
226                let scales_view = ViewArg::new_tensor::<GlobalScaleLayout>(
227                    scale.into_tensor_arg(),
228                    scales_layout,
229                );
230                ViewArg::new_quantized(data_view, scales_view, scheme)
231            }
232        };
233        let batch_layout = |handle: &InputBinding<R>| match handle {
234            InputBinding::Normal(handle, _dtype) => {
235                let layout = BatchLayoutLaunch::from_handle(handle, problem);
236                VirtualLayoutLaunch::new::<BatchLayout>(layout)
237            }
238            InputBinding::Quantized { .. } => {
239                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
240            }
241        };
242
243        TensorInputsLaunch::new(
244            batch_layout(&lhs),
245            view(lhs, blueprint.lhs_global_layout_config(), vector_sizes.lhs),
246            batch_layout(&rhs),
247            view(rhs, blueprint.rhs_global_layout_config(), vector_sizes.rhs),
248            ComptimeOptionArgs::None,
249            ComptimeOptionArgs::None,
250        )
251    }
252}
253
#[derive(CubeType, CubeLaunch, Clone, Copy)]
/// Output representation used by [TensorArgs] and [TensorMapArgs]: a writable
/// view over the output tensor plus its batch remapping layout.
pub struct TensorOutput<EG: CubePrimitive> {
    /// Writable view over the output tensor.
    view: View<EG, BatchedCoords, ReadWrite>,
    /// Batch remapping layout for the output tensor.
    batch: VirtualLayout<Coords1d, Coords1d>,
}
259
260impl<EG: CubePrimitive, A: Routine<()>> ConcreteOutputFactory<A> for TensorOutput<EG> {
261    fn create<R: Runtime>(
262        out: TensorBinding<R>,
263        blueprint: &A::Blueprint,
264        problem: &MatmulProblem,
265        vector_sizes: &MatmulVectorSizes,
266        _dtypes: &MatmulElems,
267    ) -> Self::RuntimeArg<R> {
268        let layout = GlobalLayoutLaunch::from_handle(
269            &out,
270            vector_sizes.out,
271            blueprint.out_global_layout_config(),
272        );
273        let batch = BatchLayoutLaunch::from_handle(&out, problem);
274        let view = ViewArg::new_tensor::<GlobalLayout>(out.into_tensor_arg(), layout);
275        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
276    }
277}
278
279#[cube]
280impl<Config: RuntimeConfig> MatmulArgs for TensorArgs<Config> {
281    type Output<EO: CubePrimitive> = TensorOutput<EO>;
282    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
283        TensorInputs<Lhs, Rhs, EO>;
284    type Config = Config;
285    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
286        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);
287
288    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
289        input: &Self::Input<Lhs, Rhs, EO>,
290        output: &mut Self::Output<EO>,
291        config: Self::Config,
292        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
293        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
294        #[comptime] _out_layout_config: GlobalLayoutConfig,
295    ) -> Self::State<Lhs, Rhs, EO> {
296        (*input, *output, config)
297    }
298
299    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
300        state: &Self::State<Lhs, Rhs, EO>,
301    ) -> View<Lhs, BatchedCoords> {
302        state.0.lhs
303    }
304
305    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
306        state: &Self::State<Lhs, Rhs, EO>,
307        batch: usize,
308    ) -> usize {
309        state.0.lhs_batch.to_source_pos(batch)
310    }
311
312    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
313        state: &Self::State<Lhs, Rhs, EO>,
314    ) -> View<Rhs, BatchedCoords> {
315        state.0.rhs
316    }
317
318    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
319        state: &Self::State<Lhs, Rhs, EO>,
320        batch: usize,
321    ) -> usize {
322        state.0.rhs_batch.to_source_pos(batch)
323    }
324
325    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
326        state: &Self::State<Lhs, Rhs, EO>,
327    ) -> ComptimeOption<View<EO, BatchedCoords>> {
328        state.0.acc
329    }
330
331    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
332        state: &Self::State<Lhs, Rhs, EO>,
333        batch: usize,
334    ) -> usize {
335        #[comptime]
336        #[comptime]
337        match state.0.acc_batch {
338            ComptimeOption::Some(layout) => layout.to_source_pos(batch),
339            ComptimeOption::None => batch,
340        }
341    }
342
343    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
344        state: &mut Self::State<Lhs, Rhs, EO>,
345    ) -> View<EO, BatchedCoords, ReadWrite> {
346        state.1.view
347    }
348
349    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
350        state: &Self::State<Lhs, Rhs, EO>,
351        batch: usize,
352    ) -> usize {
353        state.1.batch.to_source_pos(batch)
354    }
355
356    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
357        state: &Self::State<Lhs, Rhs, EO>,
358    ) -> Self::Config {
359        state.2.clone()
360    }
361}
362
#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs<Config: RuntimeConfig = ()> {
    // Zero-sized: this struct only carries the `Config` type parameter.
    _config: PhantomData<Config>,
}
370
#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> {
    /// The lhs tensor.
    pub lhs: View<Lhs, BatchedCoords>,
    /// The rhs tensor.
    pub rhs: View<Rhs, BatchedCoords>,
    /// The accumulator view, when the matmul loads an existing accumulator.
    pub acc: ComptimeOption<View<EO, BatchedCoords>>,
    /// Batch remapping layout for the accumulator, when present.
    pub acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
}
383
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<()>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds TMA (tensor-map) launch arguments for both operands.
    ///
    /// Steps: pick the per-operand stage box size from the swizzle mode and
    /// matrix layout, flatten each operand to `(batch, rows, cols)`, transpose
    /// the inner dims where TMA requires it, then encode tiled tensor-map
    /// metadata per operand. No accumulator views are produced here.
    fn create<R: Runtime>(
        lhs_handle: InputBinding<R>,
        rhs_handle: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        _vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        // NOTE(review): `into_data` presumably extracts the plain tensor from
        // the binding, i.e. the TMA path assumes non-quantized inputs — confirm
        // against `InputBinding`.
        let lhs = lhs_handle.into_data();
        let rhs = rhs_handle.into_data();

        let tiling_scheme = blueprint.tiling_scheme();
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // Loaders use dynamic layout based on swizzle setting. For no swizzle, contiguous tiles are
        // loaded and TMA loads single tile wide columns.
        // For swizzled, bank conflicts aren't an issue so the tile size is the full stage.
        let stage_size_lhs = match blueprint.swizzle_modes().lhs {
            SwizzleMode::None => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, tiling_scheme.tile_size.k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.m as usize]
                }
            },
            _ => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, stage_k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, stage_m as usize]
                }
            },
        };
        let stage_size_rhs = match blueprint.swizzle_modes().rhs {
            SwizzleMode::None => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, tiling_scheme.tile_size.k as usize]
                }
            },
            _ => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, stage_n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, stage_k as usize]
                }
            },
        };

        // Collapse all leading batch dims into a single one, so each operand is
        // shaped `[product(batches), rows, cols]`.
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = shape![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        // Rank-2 tensors keep only the two inner strides here; a synthetic
        // batch stride is inserted further below (after the transpose).
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].into()
        } else {
            strides![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = shape![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].into()
        } else {
            strides![rhs.strides[0], rhs.strides[1]]
        };

        // Swap the inner two dims in place when the matrix layout requires it
        // for TMA; the returned flag records whether a swap happened.
        let lhs_transposed =
            transpose_inner_for_tma(&mut lhs_shape, &mut lhs_strides, problem.lhs_layout);
        let rhs_transposed =
            transpose_inner_for_tma(&mut rhs_shape, &mut rhs_strides, problem.rhs_layout);

        // Insert batch stride after swap so we can easily get the non-contiguous stride
        if lhs_strides.len() == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_strides.len() == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        // Encode the tiled tensor-map metadata for each operand; the storage
        // type may be remapped for TMA compatibility.
        let meta_lhs = tma_meta_tiled(
            Metadata::new(lhs_shape.clone(), lhs_strides),
            stage_size_lhs,
            remap_storage_for_tma(dtypes.lhs_stage),
            blueprint.swizzle_modes().lhs.into(),
        );

        let meta_rhs = tma_meta_tiled(
            Metadata::new(rhs_shape.clone(), rhs_strides),
            stage_size_rhs,
            remap_storage_for_tma(dtypes.rhs_stage),
            blueprint.swizzle_modes().rhs.into(),
        );

        let lhs = TensorMapArg {
            tensor: lhs.into_tensor_arg(),
            metadata: meta_lhs,
            _kind: PhantomData,
        };
        let rhs = TensorMapArg {
            tensor: rhs.into_tensor_arg(),
            metadata: meta_rhs,
            _kind: PhantomData,
        };

        // Wrap a tensor map into a view; rows/cols are read back from the
        // flattened shape, un-swapping them when a transpose occurred.
        let view = |buffer, shape: &[usize], transposed| {
            let batches = shape[0];
            let (rows, cols) = match transposed {
                true => (shape[2] as u32, shape[1] as u32),
                false => (shape[1] as u32, shape[2] as u32),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map_tiled::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        )
    }
}
526
527#[cube]
528impl<Config: RuntimeConfig> MatmulArgs for TensorMapArgs<Config> {
529    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
530        TensorMapInputs<Lhs, Rhs, EO>;
531    type Output<EO: CubePrimitive> = TensorOutput<EO>;
532    type Config = Config;
533    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
534        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);
535
536    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
537        input: &Self::Input<Lhs, Rhs, EO>,
538        output: &mut Self::Output<EO>,
539        config: Self::Config,
540        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
541        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
542        #[comptime] _out_layout_config: GlobalLayoutConfig,
543    ) -> Self::State<Lhs, Rhs, EO> {
544        (*input, *output, config)
545    }
546
547    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
548        state: &Self::State<Lhs, Rhs, EO>,
549    ) -> View<Lhs, BatchedCoords> {
550        state.0.lhs
551    }
552
553    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
554        _state: &Self::State<Lhs, Rhs, EO>,
555        batch: usize,
556    ) -> usize {
557        batch
558    }
559
560    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
561        state: &Self::State<Lhs, Rhs, EO>,
562    ) -> View<Rhs, BatchedCoords> {
563        state.0.rhs
564    }
565
566    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
567        _state: &Self::State<Lhs, Rhs, EO>,
568        batch: usize,
569    ) -> usize {
570        batch
571    }
572
573    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
574        state: &Self::State<Lhs, Rhs, EO>,
575    ) -> ComptimeOption<View<EO, BatchedCoords>> {
576        state.0.acc
577    }
578
579    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
580        state: &Self::State<Lhs, Rhs, EO>,
581        batch: usize,
582    ) -> usize {
583        #[comptime]
584        #[comptime]
585        match state.0.acc_batch {
586            ComptimeOption::Some(layout) => layout.to_source_pos(batch),
587            ComptimeOption::None => batch,
588        }
589    }
590
591    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
592        state: &mut Self::State<Lhs, Rhs, EO>,
593    ) -> View<EO, BatchedCoords, ReadWrite> {
594        state.1.view
595    }
596
597    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
598        state: &Self::State<Lhs, Rhs, EO>,
599        batch: usize,
600    ) -> usize {
601        state.1.batch.to_source_pos(batch)
602    }
603
604    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
605        state: &Self::State<Lhs, Rhs, EO>,
606    ) -> Self::Config {
607        state.2.clone()
608    }
609}