// cubecl_matmul/components/global/args.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded};
use cubecl_std::{
    CubeOption, CubeOptionArgs, CubeOptionExpand,
    tensor::{
        View,
        launch::ViewArg,
        layout::{Coords1d, Coords3d, VirtualLayout, VirtualLayoutLaunch},
    },
};

use crate::{
    MatmulInputHandleRef,
    components::{
        self, MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection,
        batch::BatchConfig,
        global::{
            GlobalConfig,
            memory::{
                BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig,
                GlobalLayoutLaunch, GlobalScaleLayout, NoopLayout, NoopLayoutLaunch,
                SimpleTmaGlobalLayout, SimpleTmaGlobalLayoutLaunch,
            },
        },
        stage::SwizzleMode,
    },
};

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
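///
/// A minimal launch-side sketch (hypothetical setup: `client`, the handles, `selection`,
/// `problem`, `line_sizes`, `config`, and `dtypes` all come from the surrounding launch
/// code, and `config` is assumed to be reusable across both calls). Its output counterpart
/// is [ConcreteOutputFactory].
///
/// ```ignore
/// let input = TensorInputs::<f32, f32, f32>::create(
///     &client, &lhs_handle, &rhs_handle, &selection, &problem, &line_sizes, config, &dtypes,
/// );
/// let output = TensorOutput::<f32>::create(
///     &client, &out_handle, &selection, &problem, &line_sizes, config, &dtypes,
/// );
/// ```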
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

#[cube]
/// Arguments for the matrix multiplication algorithm.
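///
/// A kernel-side sketch of the intended call flow (simplified; in practice the global
/// matmul kernels drive these calls, and using `CUBE_POS_Z` as the batch index is an
/// assumption for illustration):
///
/// ```ignore
/// let mut state = Args::init_state(&input, &mut output, config);
/// let lhs = Args::view_lhs(&state); // read-only (batch, row, col) view
/// let rhs = Args::view_rhs(&state);
/// let batch = Args::batch_lhs(&state, CUBE_POS_Z); // kernel batch -> source batch
/// let out = Args::view_out(&mut state); // read-write output view
/// ```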
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;

    /// Type used for the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;

    /// Inner state that is used to create [tensor inputs](TensorInput) and
    /// [tensor outputs](TensorOutput).
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;

    /// Init the state.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] config: G,
    ) -> Self::State<Lhs, Rhs, EO>;

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        unexpanded!()
    }
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        unexpanded!()
    }
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        unexpanded!()
    }
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        unexpanded!()
    }
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
}

#[derive(Clone, Copy)]
/// Identification of the [tensor input](TensorInput).
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorArgs;

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    /// The lhs tensor.
    lhs: View<Line<Lhs>, Coords3d>,
    /// Maps the kernel's batch index to the lhs source batch position.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The rhs tensor.
    rhs: View<Line<Rhs>, Coords3d>,
    /// Maps the kernel's batch index to the rhs source batch position.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The tensor for loading the accumulator, if present.
    acc: CubeOption<View<Line<Acc>, Coords3d>>,
    /// Batch mapping for the accumulator, if present.
    acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
    for TensorInputs<Lhs, Rhs, Acc>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let view = |handle: &'a MatmulInputHandleRef<'a, R>,
                    config: GlobalLayoutConfig,
                    line_size| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(handle, line_size, config);
                ViewArg::new::<GlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    client, data, scale, shape, problem, **scheme, line_size, config,
                );
                let data_view =
                    ViewArg::new::<GlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<GlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };
        let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(client, handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            MatmulInputHandleRef::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        let config = config.global_config();
        TensorInputsLaunch::new(
            view(
                lhs,
                config.lhs_reader_config().gmem_config.into(),
                line_sizes.lhs,
            ),
            batch_layout(lhs),
            view(
                rhs,
                config.rhs_reader_config().gmem_config.into(),
                line_sizes.rhs,
            ),
            batch_layout(rhs),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

#[derive(CubeType, CubeLaunch, Clone, Copy)]
/// Output representation for [MatmulArgs] implementations with a materialized output tensor.
pub struct TensorOutput<EG: Numeric> {
    /// The output tensor as a read-write view.
    view: View<Line<EG>, Coords3d, ReadWrite>,
    /// Maps the kernel's batch index to the output source batch position.
    batch: VirtualLayout<Coords1d, Coords1d>,
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
        let layout = GlobalLayoutLaunch::from_handle(
            out,
            line_sizes.out,
            config.writer_config().gmem_config.into(),
        );
        let batch = BatchLayoutLaunch::from_handle(client, out, problem);
        let view = ViewArg::new::<GlobalLayout>(out.as_array_arg(line_sizes.out), layout);
        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
    }
}

#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
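        // Without an accumulator batch layout, the batch index maps through unchanged.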
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs;

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    /// The lhs tensor.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// The rhs tensor.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// The accumulator.
    pub acc: CubeOption<View<Line<EO>, Coords3d>>,
    /// The accumulator batch layout.
    pub acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R::Server>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let config = config.global_config();

        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // Loaders use a dynamic layout that depends on the swizzle setting. With no
        // swizzle, TMA loads single-tile-wide columns so that contiguous tiles are
        // loaded. With swizzling, bank conflicts aren't an issue, so the tile size is
        // the full stage.
        let stage_size_lhs = match config.lhs_reader_config().smem_config.swizzle {
            SwizzleMode::None => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, tiling_scheme.tile_size.k]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.m]
                }
            },
            _ => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, stage_k]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, stage_m]
                }
            },
        };
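        // Example (hypothetical numbers): for a row-major lhs with stage_m = 64,
        // stage_k = 32, and tile_size.k = 16, SwizzleMode::None yields a TMA tile of
        // [1, 64, 16] (a single-tile-wide column), while any swizzle mode yields the
        // full stage, [1, 64, 32].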
        let stage_size_rhs = match config.rhs_reader_config().smem_config.swizzle {
            SwizzleMode::None => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.n]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, tiling_scheme.tile_size.k]
                }
            },
            _ => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, stage_n]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, stage_k]
                }
            },
        };

        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // TMA assumes the innermost stride is contiguous and doesn't even take it as a
        // parameter, so col-major tensors are mapped with transposed shape and strides.
        // The tensor metadata keeps the original layout.
        if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }
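        // E.g. (hypothetical): a col-major lhs with shape [b, m, k] = [2, 64, 32] and
        // strides [2048, 1, 64] becomes shape [2, 32, 64] with strides [2048, 64, 1],
        // making the innermost dimension contiguous as TMA expects.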

        // Insert the batch stride after the swap so it's easy to pick up the
        // non-contiguous stride for rank-2 tensors.
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        fn swizzle(mode: SwizzleMode) -> TensorMapSwizzle {
            match mode {
                SwizzleMode::None => TensorMapSwizzle::None,
                SwizzleMode::B32 => TensorMapSwizzle::B32,
                SwizzleMode::B64 => TensorMapSwizzle::B64,
                SwizzleMode::B128 => TensorMapSwizzle::B128,
            }
        }

        let swizzle_lhs = swizzle(config.lhs_reader_config().smem_config.swizzle);
        let swizzle_rhs = swizzle(config.rhs_reader_config().smem_config.swizzle);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads it
        // correctly. It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.rhs_stage
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_lhs,
            },
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_lhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_rhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
        };
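        // Wrap each tensor map in a 3D (batch, rows, cols) view; the simple TMA layout
        // records whether the underlying map was transposed.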
        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0] as u32);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}