use std::marker::PhantomData;

use cubecl::prelude::*;
use cubecl::std::{
    CubeOption, CubeOptionArgs, CubeOptionExpand,
    tensor::{
        View,
        launch::ViewArg,
        layout::{Coords1d, VirtualLayout, VirtualLayoutLaunch},
    },
};
use cubecl::{server::TensorMapMeta, unexpanded};

use crate::components::{
    global::memory::{
        BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch,
        GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
        SimpleTmaGlobalLayoutLaunch,
    },
    stage::SwizzleMode,
};
use crate::definition::{
    self, Blueprint as _, MatmulElems, MatmulLineSizes, MatmulProblem, TilingBlueprint,
};
use crate::launch::handle::MatmulInputHandleRef;
use crate::routines::Routine;

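/// Input argument type of a [`MatmulArgs`] implementation, with the element types erased
/// behind `NumericExpand` placeholders.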
pub type InputArg<MA> =
    <MA as MatmulArgs>::Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>;

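/// Output argument type of a [`MatmulArgs`] implementation.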
pub type OutputArg<MA> = <MA as MatmulArgs>::Output<NumericExpand<2>>;

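/// Runtime configuration argument type of a [`MatmulArgs`] implementation.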
pub type ConfigArg<MA> = <MA as MatmulArgs>::Config;

pub type InputRuntimeArg<'a, MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;

pub type ConfigRuntimeArg<'a, MA, R> = <ConfigArg<MA> as LaunchArg>::RuntimeArg<'a, R>;

pub type OutputRuntimeArg<'a, MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<'a, R>;

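/// Coordinates into a batched matrix view: `(batch, row, col)`.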
pub type BatchedCoords = (usize, u32, u32);

/// Builds the concrete runtime argument for the matmul inputs from the Lhs and Rhs tensor
/// handles, the selected blueprint and the problem description.
pub trait ConcreteInputsFactory<A: Routine<()>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

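/// Builds the concrete runtime argument for the matmul output from the output tensor handle,
/// the selected blueprint and the problem description.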
pub trait ConcreteOutputFactory<A: Routine<()>>: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

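/// Marker trait for types that can be passed to the kernel as a runtime configuration.
/// Blanket-implemented for every compatible type.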
pub trait RuntimeConfig: LaunchArg + CubeType + Clone + Send + Sync {}
impl<T: LaunchArg + CubeType + Clone + Send + Sync> RuntimeConfig for T {}

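/// Abstraction over the argument types of a matmul kernel. An implementation decides how the
/// Lhs, Rhs, accumulator and output tensors are passed to the kernel (e.g. as plain tensor
/// views or as TMA tensor maps) and how they are exposed as global views at runtime.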
#[cube]
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Input argument passed to the kernel, covering Lhs, Rhs and the optional accumulator.
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: LaunchArg + CubeType;

    /// Output argument passed to the kernel.
    type Output<EO: Numeric>: LaunchArg + CubeType;

    /// Runtime configuration passed to the kernel.
    type Config: RuntimeConfig;

    /// Device-side state built from the input, output and config by `init_state`.
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric>: CubeType;

    /// Initializes the device-side state from the launch arguments and the comptime layout
    /// configuration of each tensor.
    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] lhs_layout_config: GlobalLayoutConfig,
        #[comptime] rhs_layout_config: GlobalLayoutConfig,
        #[comptime] out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO>;

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, BatchedCoords> {
        unexpanded!()
    }
    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, BatchedCoords> {
        unexpanded!()
    }
    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
        unexpanded!()
    }
    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
        unexpanded!()
    }
    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }

    fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        unexpanded!()
    }
}

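/// Identifies one of the two matmul input tensors.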
#[derive(Clone, Copy)]
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}

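/// [`MatmulArgs`] implementation that passes the inputs and output as plain tensor views.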
#[derive(Clone)]
pub struct TensorArgs<Config: RuntimeConfig = ()> {
    _config: PhantomData<Config>,
}

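/// Tensor views for the Lhs, Rhs and optional accumulator inputs, paired with layouts that
/// map the logical batch index to each tensor's batch position.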
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorInputs<Lhs: Numeric, Rhs: Numeric, Acc: Numeric> {
    lhs: View<Line<Lhs>, BatchedCoords>,
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    rhs: View<Line<Rhs>, BatchedCoords>,
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    acc: CubeOption<View<Line<Acc>, BatchedCoords>>,
    acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, Acc: Numeric, A: Routine<()>> ConcreteInputsFactory<A>
    for TensorInputs<Lhs, Rhs, Acc>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
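        // `view` builds the global view for one input handle: plain handles get a
        // `GlobalLayout`, while quantized handles also wrap a scales view. `batch_layout`
        // maps the logical batch index onto the handle's batch dimension (quantized inputs
        // use a no-op layout).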
        let view = |handle: &'a MatmulInputHandleRef<'a, R>,
                    config: GlobalLayoutConfig,
                    line_size| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(handle, line_size, config);
                ViewArg::new::<GlobalLayout>(handle.as_array_arg(line_size), layout)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    client, data, scale, shape, problem, **scheme, line_size, config,
                );
                let data_view =
                    ViewArg::new::<GlobalLayout>(data.as_array_arg(line_size), data_layout);
                let scales_view =
                    ViewArg::new::<GlobalScaleLayout>(scale.as_array_arg(1), scales_layout);
                ViewArg::new_quantized(data_view, scales_view, **scheme)
            }
        };
        let batch_layout = |handle: &'a MatmulInputHandleRef<'a, R>| match handle {
            MatmulInputHandleRef::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(client, handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            MatmulInputHandleRef::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        TensorInputsLaunch::new(
            view(lhs, blueprint.lhs_global_layout_config(), line_sizes.lhs),
            batch_layout(lhs),
            view(rhs, blueprint.rhs_global_layout_config(), line_sizes.rhs),
            batch_layout(rhs),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

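/// Writable global view over the output tensor, paired with its batch layout.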
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct TensorOutput<EG: Numeric> {
    view: View<Line<EG>, BatchedCoords, ReadWrite>,
    batch: VirtualLayout<Coords1d, Coords1d>,
}

impl<EG: Numeric, A: Routine<()>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let layout = GlobalLayoutLaunch::from_handle(
            out,
            line_sizes.out,
            blueprint.out_global_layout_config(),
        );
        let batch = BatchLayoutLaunch::from_handle(client, out, problem);
        let view = ViewArg::new::<GlobalLayout>(out.as_array_arg(line_sizes.out), layout);
        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
    }
}

#[cube]
impl<Config: RuntimeConfig> MatmulArgs for TensorArgs<Config> {
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorInputs<Lhs, Rhs, EO>;
    type Config = Config;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
        #[comptime] _out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output, config)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, BatchedCoords> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.0.lhs_batch.to_source_pos(batch)
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, BatchedCoords> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.0.rhs_batch.to_source_pos(batch)
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.1.batch.to_source_pos(batch)
    }

    fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        state.2.clone()
    }
}

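/// [`MatmulArgs`] implementation that passes Lhs and Rhs as TMA tensor maps; the output is
/// still written through a regular tensor view.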
#[derive(Clone)]
pub struct TensorMapArgs<Config: RuntimeConfig = ()> {
    _config: PhantomData<Config>,
}

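/// TMA-backed views for the Lhs and Rhs inputs, plus an optional accumulator view and its
/// batch layout.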
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorMapInputs<Lhs: Numeric, Rhs: Numeric, EO: Numeric> {
    pub lhs: View<Line<Lhs>, BatchedCoords>,
    pub rhs: View<Line<Rhs>, BatchedCoords>,
    pub acc: CubeOption<View<Line<EO>, BatchedCoords>>,
    pub acc_batch: CubeOption<VirtualLayout<Coords1d, Coords1d>>,
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric, A: Routine<(), Blueprint = TilingBlueprint>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        _client: &ComputeClient<R>,
        lhs_handle: &'a MatmulInputHandleRef<'a, R>,
        rhs_handle: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let lhs = lhs_handle.data();
        let rhs = rhs_handle.data();

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

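        // Tile ("box") sizes for the TMA descriptors, as [batch, rows, cols]: without
        // swizzling the box spans the full stage along the outer dimension but only one
        // tile along the contiguous dimension; with swizzling it covers the whole stage.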
        let stage_size_lhs = match blueprint.swizzle_modes.lhs {
            SwizzleMode::None => match problem.lhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_m, tiling_scheme.tile_size.k]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.m]
                }
            },
            _ => match problem.lhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_m, stage_k]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_k, stage_m]
                }
            },
        };
        let stage_size_rhs = match blueprint.swizzle_modes.rhs {
            SwizzleMode::None => match problem.rhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_k, tiling_scheme.tile_size.n]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_n, tiling_scheme.tile_size.k]
                }
            },
            _ => match problem.rhs_layout {
                definition::MatrixLayout::RowMajor => {
                    vec![1, stage_k, stage_n]
                }
                definition::MatrixLayout::ColMajor => {
                    vec![1, stage_n, stage_k]
                }
            },
        };

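        // Collapse each handle to rank 3: a flattened batch dimension followed by the row
        // and column dimensions.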
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = vec![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].to_vec()
        } else {
            vec![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = vec![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].to_vec()
        } else {
            vec![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

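        // Col-major inputs are described by swapping the last two dims of shape and strides
        // and flagging the view as transposed.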
        if matches!(problem.lhs_layout, definition::MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, definition::MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }

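        // Rank-2 handles have no batch dimension; duplicate the leading stride so the
        // rank-3 descriptor still has three stride entries.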
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        fn swizzle(mode: SwizzleMode) -> TensorMapSwizzle {
            match mode {
                SwizzleMode::None => TensorMapSwizzle::None,
                SwizzleMode::B32 => TensorMapSwizzle::B32,
                SwizzleMode::B64 => TensorMapSwizzle::B64,
                SwizzleMode::B128 => TensorMapSwizzle::B128,
            }
        }

        let swizzle_lhs = swizzle(blueprint.swizzle_modes.lhs);
        let swizzle_rhs = swizzle(blueprint.swizzle_modes.rhs);

        // f32 stage elements are declared as tf32 in the tensor map metadata.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.rhs_stage
        };

        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_lhs,
            }),
            rank: 3,
            shape: lhs_shape.clone(),
            strides: lhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_lhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_rhs,
            }),
            rank: 3,
            shape: rhs_shape.clone(),
            strides: rhs_strides,
            elem_stride: vec![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: swizzle_rhs,
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.as_tensor_arg(line_sizes.lhs),
            metadata: meta_lhs,
            _kind: PhantomData,
        };
        let rhs = TensorMapArg {
            tensor: rhs.as_tensor_arg(line_sizes.rhs),
            metadata: meta_rhs,
            _kind: PhantomData,
        };

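        // Wrap each tensor map in a view using the simple TMA layout, which only needs the
        // rank-3 shape and whether the matrix is transposed.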
        let view = |buffer, shape: &[usize], transposed| {
            let batches = ScalarArg::new(shape[0]);
            let (rows, cols) = match transposed {
                true => (
                    ScalarArg::new(shape[2] as u32),
                    ScalarArg::new(shape[1] as u32),
                ),
                false => (
                    ScalarArg::new(shape[1] as u32),
                    ScalarArg::new(shape[2] as u32),
                ),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map_tiled::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        )
    }
}

#[cube]
impl<Config: RuntimeConfig> MatmulArgs for TensorMapArgs<Config> {
    type Input<Lhs: Numeric, Rhs: Numeric, EO: Numeric> = TensorMapInputs<Lhs, Rhs, EO>;
    type Output<EO: Numeric> = TensorOutput<EO>;
    type Config = Config;
    type State<Lhs: Numeric, Rhs: Numeric, EO: Numeric> =
        (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);

    fn init_state<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] _lhs_layout_config: GlobalLayoutConfig,
        #[comptime] _rhs_layout_config: GlobalLayoutConfig,
        #[comptime] _out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO> {
        (*input, *output, config)
    }

    fn view_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Lhs>, BatchedCoords> {
        state.0.lhs
    }

    fn batch_lhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        batch
    }

    fn view_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<Rhs>, BatchedCoords> {
        state.0.rhs
    }

    fn batch_rhs<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        _state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        batch
    }

    fn view_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
        state.0.acc
    }

    fn batch_acc<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        match state.0.acc_batch {
            CubeOption::Some(layout) => layout.to_source_pos(batch),
            CubeOption::None => batch,
        }
    }

    fn view_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
        state.1.view
    }

    fn batch_out<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
        batch: usize,
    ) -> usize {
        state.1.batch.to_source_pos(batch)
    }

    fn runtime_config<Lhs: Numeric, Rhs: Numeric, EO: Numeric>(
        state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        state.2.clone()
    }
}