cubecl_matmul/components/global/args.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, server::TensorMapMeta, unexpanded};
use cubecl_std::{
    CubeOption, CubeOptionArgs, CubeOptionExpand,
    tensor::{
        View,
        launch::ViewArg,
        layout::{Coords1d, Coords3d, VirtualLayout, VirtualLayoutLaunch},
    },
};

use crate::{
    MatmulInputHandleRef,
    components::{
        self, MatmulElems, MatmulIdent, MatmulLineSizes, MatmulProblem, MatmulSelection,
        batch::BatchConfig,
        global::{
            GlobalConfig,
            memory::{
                BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutLaunch,
                GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
                SimpleTmaGlobalLayoutLaunch,
            },
        },
        stage::SwizzleMode,
    },
};

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

/// Create the output runtime argument for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}
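
// A minimal host-side sketch of how the two factories fit together when launching a
// non-fused matmul (illustrative only; `client`, `lhs`, `rhs`, `out`, `selection`,
// `problem`, `line_sizes`, `batch_config`, and `dtypes` are assumed to be resolved by
// the caller, and the config type is assumed to be `Copy`):
//
//     let input = <TensorInputs<f32, f32, f32> as ConcreteInputsFactory>::create(
//         &client, &lhs, &rhs, &selection, &problem, &line_sizes, batch_config, &dtypes,
//     );
//     let output = <TensorOutput<f32> as ConcreteOutputFactory>::create(
//         &client, &out, &selection, &problem, &line_sizes, batch_config, &dtypes,
//     );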

#[cube]
/// Arguments for the matrix multiplication algorithm.
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Type used for the input.
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;

    /// Type used for the output.
    type Output<EO: Numeric>: LaunchArg + CubeType;

    /// Inner state used to access the input and output views.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;

    /// Initializes the state.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] config: G,
    ) -> Self::State<Lhs, Rhs, EO>;

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        unexpanded!()
    }
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        unexpanded!()
    }
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        unexpanded!()
    }
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        unexpanded!()
    }
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: u32,
    ) -> u32 {
        unexpanded!()
    }
}
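
// A minimal kernel-side sketch of how an implementation of this trait is driven
// (illustrative only; `Args: MatmulArgs`, the element types, `G: GlobalConfig`, and
// `batch` are assumed to be in scope in a `#[cube]` context):
//
//     let mut state = Args::init_state::<Lhs, Rhs, EO, G>(&input, &mut output, config);
//     let lhs = Args::view_lhs::<Lhs, Rhs, EO>(&state);
//     let rhs = Args::view_rhs::<Lhs, Rhs, EO>(&state);
//     let out = Args::view_out::<Lhs, Rhs, EO>(&mut state);
//     // Batch indices are remapped per tensor (e.g. for broadcast batches):
//     let lhs_batch = Args::batch_lhs::<Lhs, Rhs, EO>(&state, batch);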

#[derive(Clone, Copy)]
/// Identifies which input tensor is referenced: Lhs or Rhs.
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensors.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorArgs;

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorArgs] implementing [MatmulArgs].
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    /// The lhs tensor.
    lhs: View<Line<Lhs>, Coords3d>,
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The rhs tensor.
    rhs: View<Line<Rhs>, Coords3d>,
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// The tensor for loading the accumulator, if present.
    acc: CubeOption<View<Line<Acc>, Coords3d>>,
    acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> ConcreteInputsFactory
    for TensorInputs<Lhs, Rhs, Acc>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
        let view = |handle: &'a MatmulInputHandleRef<'a, R>, ident, line_size| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(
                    handle,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                ViewArg::new::<GlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    client,
                    data,
                    scale,
                    shape,
                    problem,
                    **scheme,
                    line_size,
                    config.global_memory_config(ident).into(),
                );
                let data_view =
                    ViewArg::new::<GlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<GlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };
        let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(client, handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            MatmulInputHandleRef::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        TensorInputsLaunch::new(
            view(lhs, MatmulIdent::Lhs, line_sizes.lhs),
            batch_layout(lhs),
            view(rhs, MatmulIdent::Rhs, line_sizes.rhs),
            batch_layout(rhs),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct TensorOutput<EG: Numeric> {
    view: View<Line<EG>, Coords3d, ReadWrite>,
    batch: VirtualLayout<Coords1d, Coords1d>,
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let config = config.global_config();
        let layout = GlobalLayoutLaunch::from_handle(
            out,
            line_sizes.out,
            config.global_memory_config(MatmulIdent::Out).into(),
        );
        let batch = BatchLayoutLaunch::from_handle(client, out, problem);
        let view = ViewArg::new::<GlobalLayout>(out.as_array_arg(line_sizes.out), layout);
        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
    }
}

#[cube]
impl MatmulArgs for TensorArgs {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}

#[derive(Clone)]
/// Type implementing [MatmulArgs] where all inputs and the output are materialized tensor maps.
///
/// Other types might implement [MatmulArgs] for fused matrix multiplication kernels.
pub struct TensorMapArgs;

#[derive(CubeLaunch, CubeType, Clone, Copy)]
/// Input representation for [TensorMapArgs] implementing [MatmulArgs].
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    /// The lhs tensor.
    pub lhs: View<Line<Lhs>, Coords3d>,
    /// The rhs tensor.
    pub rhs: View<Line<Rhs>, Coords3d>,
    /// The accumulator.
    pub acc: CubeOption<View<Line<EO>, Coords3d>>,
    /// The accumulator batch layout.
    pub acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R::Server>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        config: impl BatchConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let config = config.global_config();

        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let stage_k = tiling_scheme.elements_in_stage_k();

        // Loaders use a dynamic layout based on the swizzle setting. With no swizzle,
        // contiguous tiles are loaded and TMA loads columns one tile wide. With swizzling,
        // bank conflicts aren't an issue, so the tile size is the full stage.
        let stage_size_lhs = match config.swizzle_mode(MatmulIdent::Lhs) {
            SwizzleMode::None => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, tiling_scheme.elements_in_tile_k()]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, tiling_scheme.elements_in_tile_m()]
                }
            },
            _ => match problem.lhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_m, stage_k]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_k, stage_m]
                }
            },
        };
        let stage_size_rhs = match config.swizzle_mode(MatmulIdent::Rhs) {
            SwizzleMode::None => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, tiling_scheme.elements_in_tile_n()]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, tiling_scheme.elements_in_tile_k()]
                }
            },
            _ => match problem.rhs_layout {
                components::MatrixLayout::RowMajor => {
                    vec![1, stage_k, stage_n]
                }
                components::MatrixLayout::ColMajor => {
                    vec![1, stage_n, stage_k]
                }
            },
        };
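
        // Illustrative example (numbers assumed, not taken from any particular config):
        // with a 128x128x32 stage, elements_in_tile_k() == 16, a row-major lhs, and no
        // swizzle, `stage_size_lhs` is [1, 128, 16], so TMA copies one tile-wide column
        // per load; with swizzling enabled it becomes the full stage, [1, 128, 32].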

        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // TMA assumes the innermost dimension is contiguous and doesn't take its stride
        // at all, so column-major tensors must be mapped with a transposed shape and
        // strides. The tensor metadata still has the normal layout.
        if matches!(problem.lhs_layout, components::MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, components::MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }

        // Insert the batch stride after the swap so we can easily get the non-contiguous
        // stride.
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }
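
        // Worked example (values assumed for illustration): a rank-2 col-major lhs with
        // shape [64, 32] and strides [1, 64] starts out as shape [1, 64, 32] with strides
        // [1, 64]; the col-major swap yields shape [1, 32, 64] and strides [64, 1], and
        // inserting the batch stride gives strides [64, 64, 1] with `lhs_transposed = true`.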

        fn swizzle(mode: SwizzleMode) -> TensorMapSwizzle {
            match mode {
                SwizzleMode::None => TensorMapSwizzle::None,
                SwizzleMode::B32 => TensorMapSwizzle::B32,
                SwizzleMode::B64 => TensorMapSwizzle::B64,
                SwizzleMode::B128 => TensorMapSwizzle::B128,
            }
        }

        let swizzle_lhs = swizzle(config.swizzle_mode(MatmulIdent::Lhs));
        let swizzle_rhs = swizzle(config.swizzle_mode(MatmulIdent::Rhs));

        // f32 is remapped to tf32 for the tensor map to ensure CUDA loads it correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.rhs_stage
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_lhs,
            },
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_lhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_rhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
        };

        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0] as u32);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

#[cube]
impl MatmulArgs for TensorMapArgs {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric, G: GlobalConfig>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        #[comptime] _config: G,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, Coords3d> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, Coords3d> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        batch
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, Coords3d>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, Coords3d, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: u32,
    ) -> u32 {
        state.1.batch.to_source_pos(batch)
    }
}