1use std::marker::PhantomData;
2
3use cubecl::std::tensor::{
4 View,
5 launch::ViewArg,
6 layout::{Coords1d, VirtualLayout, VirtualLayoutLaunch},
7};
8use cubecl::{
9 prelude::*,
10 zspace::{metadata::Metadata, shape, strides},
11};
12use cubecl::{server::TensorMapMeta, unexpanded};
13use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
14
15use crate::components::global::memory::{
16 BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch,
17 GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
18 SimpleTmaGlobalLayoutLaunch,
19};
20use crate::{
21 definition::{Blueprint as _, MatmulElems, MatmulProblem, MatmulVectorSizes},
22 routines::Routine,
23};
24
// Marker types naming the three matmul operands (left, right, accumulator).
// `define_scalar!` presumably expands to a zero-sized scalar tag used to
// parameterize `Vector<_, _>` below — TODO confirm against the macro def.
define_scalar!(pub Lhs);
define_scalar!(pub Rhs);
define_scalar!(pub Acc);

// Matching vectorization-width (line size) markers for each operand.
define_size!(pub LhsSize);
define_size!(pub RhsSize);
define_size!(pub AccSize);
32
/// Launch-side input argument type of a [`MatmulArgs`] implementation,
/// instantiated with the operand marker scalars/sizes defined above.
pub type InputArg<MA> =
    <MA as MatmulArgs>::Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>;

/// Launch-side output argument type, carrying accumulator-typed elements.
pub type OutputArg<MA> = <MA as MatmulArgs>::Output<Vector<Acc, AccSize>>;

/// Extra runtime configuration type of a [`MatmulArgs`] implementation.
pub type ConfigArg<MA> = <MA as MatmulArgs>::Config;

/// Runtime (host-side) representation of [`InputArg`] for runtime `R`.
pub type InputRuntimeArg<MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Runtime (host-side) representation of [`ConfigArg`] for runtime `R`.
pub type ConfigRuntimeArg<MA, R> = <ConfigArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Runtime (host-side) representation of [`OutputArg`] for runtime `R`.
pub type OutputRuntimeArg<MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Coordinates into a batched matrix view: `(batch, row, col)`.
pub type BatchedCoords = (usize, u32, u32);
53
/// Factory turning concrete input bindings into the launch argument
/// (`Self::RuntimeArg`) expected by a [`MatmulArgs`] implementation.
pub trait ConcreteInputsFactory<A: Routine<()>>: LaunchArg {
    /// Builds the runtime input argument for the lhs/rhs operands from their
    /// bindings plus the blueprint, problem and vectorization descriptions.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
67
/// Factory turning a concrete output tensor binding into the launch argument
/// (`Self::RuntimeArg`) expected by a [`MatmulArgs`] implementation.
pub trait ConcreteOutputFactory<A: Routine<()>>: LaunchArg {
    /// Builds the runtime output argument from the output binding plus the
    /// blueprint, problem and vectorization descriptions.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
80
/// Alias-style trait for configs that can be passed to a kernel at launch;
/// the blanket impl makes every qualifying type a `RuntimeConfig`.
pub trait RuntimeConfig: LaunchArg + CubeType + Clone + Send + Sync {}
impl<T: LaunchArg + CubeType + Clone + Send + Sync> RuntimeConfig for T {}
83
/// Abstraction over how matmul inputs/outputs are declared at launch and how
/// the kernel obtains global-memory views from them.
///
/// Implementations (e.g. [`TensorArgs`], [`TensorMapArgs`]) bundle launch
/// arguments into a `State` built once by [`MatmulArgs::init_state`]; the
/// remaining methods project operand views and batch mappings out of it.
#[cube]
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Launch argument type carrying the lhs/rhs (and optional accumulator) inputs.
    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: LaunchArg + CubeType;

    /// Launch argument type carrying the output.
    type Output<EO: CubePrimitive>: LaunchArg + CubeType;

    /// Extra runtime configuration forwarded to the kernel.
    type Config: RuntimeConfig;

    /// Kernel-side state from which views and batch mappings are read.
    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: CubeType;

    /// Builds the state from the launch arguments. The comptime layout
    /// configs let implementations reinterpret the raw buffers if needed.
    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] lhs_layout_config: GlobalLayoutConfig,
        #[comptime] rhs_layout_config: GlobalLayoutConfig,
        #[comptime] out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO>;

    // NOTE: the `unexpanded!()` default bodies below are placeholders; the
    // `#[cube]` macro replaces them during expansion and they must never be
    // executed directly.

    /// Global-memory view of the lhs operand.
    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Lhs, BatchedCoords> {
        unexpanded!()
    }
    /// Maps a flat batch index to the lhs source batch position.
    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Global-memory view of the rhs operand.
    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Rhs, BatchedCoords> {
        unexpanded!()
    }
    /// Maps a flat batch index to the rhs source batch position.
    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Optional view of an accumulator initializer (`None` when the matmul
    /// starts from zero).
    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> ComptimeOption<View<EO, BatchedCoords>> {
        unexpanded!()
    }
    /// Maps a flat batch index to the accumulator source batch position.
    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Writable global-memory view of the output.
    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<EO, BatchedCoords, ReadWrite> {
        unexpanded!()
    }
    /// Maps a flat batch index to the output batch position.
    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }

    /// Returns the runtime config stored in the state.
    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        unexpanded!()
    }
}
161
/// Identifies which matmul input operand a tensor corresponds to.
#[derive(Clone, Copy)]
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}
168
/// [`MatmulArgs`] implementation backed by plain global-memory tensors
/// (no tensor-map / TMA hardware path).
#[derive(Clone)]
pub struct TensorArgs<Config: RuntimeConfig = ()> {
    // Zero-sized marker tying the args to their runtime config type.
    _config: PhantomData<Config>,
}
176
/// Input views for a plain tensor-backed matmul: per-operand global views
/// plus layouts mapping flat batch indices to source batch positions.
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive> {
    /// Batch index mapping for the lhs operand.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// Global-memory view of the lhs operand.
    lhs: View<Lhs, BatchedCoords>,
    /// Batch index mapping for the rhs operand.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// Global-memory view of the rhs operand.
    rhs: View<Rhs, BatchedCoords>,
    /// Optional accumulator-initializer batch mapping.
    acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
    /// Optional accumulator-initializer view.
    acc: ComptimeOption<View<Acc, BatchedCoords>>,
}
190
191impl<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive, A: Routine<()>>
192 ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, Acc>
193{
194 fn create<R: Runtime>(
195 lhs: InputBinding<R>,
196 rhs: InputBinding<R>,
197 blueprint: &A::Blueprint,
198 problem: &MatmulProblem,
199 vector_sizes: &MatmulVectorSizes,
200 _dtypes: &MatmulElems,
201 ) -> Self::RuntimeArg<R> {
202 let view = |handle: InputBinding<R>, config: GlobalLayoutConfig, vector_size| match handle {
203 InputBinding::Normal(handle, _dtype) => {
204 let layout = GlobalLayoutLaunch::from_handle(&handle, vector_size, config);
205 ViewArg::new_tensor::<GlobalLayout>(handle.into_tensor_arg(), layout)
206 }
207 InputBinding::Quantized {
208 data,
209 scale,
210 shape,
211 scheme,
212 ..
213 } => {
214 let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
215 &data,
216 &scale,
217 &shape,
218 problem,
219 scheme,
220 vector_size,
221 config,
222 );
223 let data_view =
224 ViewArg::new_tensor::<GlobalLayout>(data.into_tensor_arg(), data_layout);
225 let scales_view = ViewArg::new_tensor::<GlobalScaleLayout>(
226 scale.into_tensor_arg(),
227 scales_layout,
228 );
229 ViewArg::new_quantized(data_view, scales_view, scheme)
230 }
231 };
232 let batch_layout = |handle: &InputBinding<R>| match handle {
233 InputBinding::Normal(handle, _dtype) => {
234 let layout = BatchLayoutLaunch::from_handle(handle, problem);
235 VirtualLayoutLaunch::new::<BatchLayout>(layout)
236 }
237 InputBinding::Quantized { .. } => {
238 VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
239 }
240 };
241
242 TensorInputsLaunch::new(
243 batch_layout(&lhs),
244 view(lhs, blueprint.lhs_global_layout_config(), vector_sizes.lhs),
245 batch_layout(&rhs),
246 view(rhs, blueprint.rhs_global_layout_config(), vector_sizes.rhs),
247 ComptimeOptionArgs::None,
248 ComptimeOptionArgs::None,
249 )
250 }
251}
252
/// Output of a tensor-backed matmul: a writable global view plus the layout
/// mapping flat batch indices to output batch positions.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct TensorOutput<EG: CubePrimitive> {
    /// Writable global-memory view of the output.
    view: View<EG, BatchedCoords, ReadWrite>,
    /// Batch index mapping for the output.
    batch: VirtualLayout<Coords1d, Coords1d>,
}
258
259impl<EG: CubePrimitive, A: Routine<()>> ConcreteOutputFactory<A> for TensorOutput<EG> {
260 fn create<R: Runtime>(
261 out: TensorBinding<R>,
262 blueprint: &A::Blueprint,
263 problem: &MatmulProblem,
264 vector_sizes: &MatmulVectorSizes,
265 _dtypes: &MatmulElems,
266 ) -> Self::RuntimeArg<R> {
267 let layout = GlobalLayoutLaunch::from_handle(
268 &out,
269 vector_sizes.out,
270 blueprint.out_global_layout_config(),
271 );
272 let batch = BatchLayoutLaunch::from_handle(&out, problem);
273 let view = ViewArg::new_tensor::<GlobalLayout>(out.into_tensor_arg(), layout);
274 TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
275 }
276}
277
278#[cube]
279impl<Config: RuntimeConfig> MatmulArgs for TensorArgs<Config> {
280 type Output<EO: CubePrimitive> = TensorOutput<EO>;
281 type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
282 TensorInputs<Lhs, Rhs, EO>;
283 type Config = Config;
284 type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
285 (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);
286
287 fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
288 input: &Self::Input<Lhs, Rhs, EO>,
289 output: &mut Self::Output<EO>,
290 config: Self::Config,
291 #[comptime] _lhs_layout_config: GlobalLayoutConfig,
292 #[comptime] _rhs_layout_config: GlobalLayoutConfig,
293 #[comptime] _out_layout_config: GlobalLayoutConfig,
294 ) -> Self::State<Lhs, Rhs, EO> {
295 (*input, *output, config)
296 }
297
298 fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
299 state: &Self::State<Lhs, Rhs, EO>,
300 ) -> View<Lhs, BatchedCoords> {
301 state.0.lhs
302 }
303
304 fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
305 state: &Self::State<Lhs, Rhs, EO>,
306 batch: usize,
307 ) -> usize {
308 state.0.lhs_batch.to_source_pos(batch)
309 }
310
311 fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
312 state: &Self::State<Lhs, Rhs, EO>,
313 ) -> View<Rhs, BatchedCoords> {
314 state.0.rhs
315 }
316
317 fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
318 state: &Self::State<Lhs, Rhs, EO>,
319 batch: usize,
320 ) -> usize {
321 state.0.rhs_batch.to_source_pos(batch)
322 }
323
324 fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
325 state: &Self::State<Lhs, Rhs, EO>,
326 ) -> ComptimeOption<View<EO, BatchedCoords>> {
327 state.0.acc
328 }
329
330 fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
331 state: &Self::State<Lhs, Rhs, EO>,
332 batch: usize,
333 ) -> usize {
334 #[comptime]
335 #[comptime]
336 match state.0.acc_batch {
337 ComptimeOption::Some(layout) => layout.to_source_pos(batch),
338 ComptimeOption::None => batch,
339 }
340 }
341
342 fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
343 state: &mut Self::State<Lhs, Rhs, EO>,
344 ) -> View<EO, BatchedCoords, ReadWrite> {
345 state.1.view
346 }
347
348 fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
349 state: &Self::State<Lhs, Rhs, EO>,
350 batch: usize,
351 ) -> usize {
352 state.1.batch.to_source_pos(batch)
353 }
354
355 fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
356 state: &Self::State<Lhs, Rhs, EO>,
357 ) -> Self::Config {
358 state.2.clone()
359 }
360}
361
/// [`MatmulArgs`] implementation backed by hardware tensor maps (TMA),
/// where batch/tile addressing is encoded in the tensor-map descriptor.
#[derive(Clone)]
pub struct TensorMapArgs<Config: RuntimeConfig = ()> {
    // Zero-sized marker tying the args to their runtime config type.
    _config: PhantomData<Config>,
}
369
/// Input views for a tensor-map (TMA) backed matmul. Unlike [`TensorInputs`]
/// there are no lhs/rhs batch layouts: batch indexing is part of the
/// tensor-map coordinates themselves.
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorMapInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> {
    /// Tensor-map view of the lhs operand.
    pub lhs: View<Lhs, BatchedCoords>,
    /// Tensor-map view of the rhs operand.
    pub rhs: View<Rhs, BatchedCoords>,
    /// Optional accumulator-initializer view.
    pub acc: ComptimeOption<View<EO, BatchedCoords>>,
    /// Optional accumulator-initializer batch mapping.
    pub acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
}
382
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<()>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds TMA (tensor-map) launch arguments for the lhs/rhs operands.
    ///
    /// Collapses all batch dimensions into one, derives the TMA box (tile)
    /// size from the tiling scheme and swizzle mode, normalizes shapes and
    /// strides to rank-3 `[batch, rows, cols]`, and handles col-major
    /// operands by swapping the two inner dimensions.
    fn create<R: Runtime>(
        lhs_handle: InputBinding<R>,
        rhs_handle: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        _vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        // NOTE(review): `into_data()` suggests quantized bindings are not
        // supported on the TMA path — confirm against `InputBinding`.
        let lhs = lhs_handle.into_data();
        let rhs = rhs_handle.into_data();

        let tiling_scheme = blueprint.tiling_scheme();
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // TMA box size for lhs: without swizzling the box spans a single tile
        // along the contiguous axis; with swizzling the whole stage is one
        // box. Presumably the swizzle period constrains the inner extent —
        // TODO confirm against the TMA swizzle rules.
        let stage_size_lhs = match blueprint.swizzle_modes().lhs {
            SwizzleMode::None => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, tiling_scheme.tile_size.k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.m as usize]
                }
            },
            _ => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, stage_k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, stage_m as usize]
                }
            },
        };
        // Same box-size selection for rhs.
        let stage_size_rhs = match blueprint.swizzle_modes().rhs {
            SwizzleMode::None => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, tiling_scheme.tile_size.k as usize]
                }
            },
            _ => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, stage_n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, stage_k as usize]
                }
            },
        };

        // Collapse every leading batch dimension into one: shape becomes
        // [prod(batches), rows, cols]; strides keep at most the innermost
        // three entries.
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = shape![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].into()
        } else {
            strides![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = shape![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].into()
        } else {
            strides![rhs.strides[0], rhs.strides[1]]
        };

        let mut lhs_transposed = false;
        let mut rhs_transposed = false;

        // Ranks are recomputed from the truncated stride vectors (2 or 3).
        let lhs_rank = lhs_strides.len();
        let rhs_rank = rhs_strides.len();

        // Col-major operands: swap the two inner dims of both shape (always
        // rank 3, so indices 1/2) and strides (last two entries), and record
        // the transposition for the view layout below.
        if matches!(problem.lhs_layout, MatrixLayout::ColMajor) {
            lhs_shape.swap(2, 1);
            lhs_strides.swap(lhs_rank - 1, lhs_rank - 2);
            lhs_transposed = true;
        }
        if matches!(problem.rhs_layout, MatrixLayout::ColMajor) {
            rhs_shape.swap(2, 1);
            rhs_strides.swap(rhs_rank - 1, rhs_rank - 2);
            rhs_transposed = true;
        }

        // Rank-2 inputs lack a batch stride: duplicate the outer stride so
        // the metadata is rank 3. The batch extent is presumably 1 in this
        // case, making the value inert — TODO confirm.
        if lhs_rank == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_rank == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        // f32 stage elements are encoded as tf32 in the tensor-map
        // descriptor. NOTE(review): presumably required because the MMA path
        // consumes tf32 for 32-bit floats — confirm against the backend.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked().storage_type() {
            tf32::as_type_native_unchecked().storage_type()
        } else {
            dtypes.lhs_stage
        };
        let rhs_elem = if dtypes.rhs_stage == f32::as_type_native_unchecked().storage_type() {
            tf32::as_type_native_unchecked().storage_type()
        } else {
            dtypes.rhs_stage
        };

        // Full tensor-map descriptor for lhs: tiled format with the box size
        // chosen above, dense element stride, zero-fill out of bounds.
        let meta_lhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_lhs,
            }),
            metadata: Metadata::new(lhs_shape.clone(), lhs_strides),
            elem_stride: strides![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: blueprint.swizzle_modes().lhs.into(),
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: lhs_elem,
        };

        let meta_rhs = TensorMapMeta {
            format: TensorMapFormat::Tiled(TiledArgs {
                tile_size: stage_size_rhs,
            }),
            metadata: Metadata::new(rhs_shape.clone(), rhs_strides),
            elem_stride: strides![1, 1, 1],
            interleave: TensorMapInterleave::None,
            swizzle: blueprint.swizzle_modes().rhs.into(),
            prefetch: TensorMapPrefetch::None,
            oob_fill: OobFill::Zero,
            storage_ty: rhs_elem,
        };

        let lhs = TensorMapArg {
            tensor: lhs.into_tensor_arg(),
            metadata: meta_lhs,
            _kind: PhantomData,
        };
        let rhs = TensorMapArg {
            tensor: rhs.into_tensor_arg(),
            metadata: meta_rhs,
            _kind: PhantomData,
        };

        // Wraps a tensor-map arg in a view; rows/cols are read from the
        // (possibly swapped) shape, un-swapping when transposed so the layout
        // receives logical (rows, cols).
        let view = |buffer, shape: &[usize], transposed| {
            let batches = shape[0];
            let (rows, cols) = match transposed {
                true => (shape[2] as u32, shape[1] as u32),
                false => (shape[1] as u32, shape[2] as u32),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map_tiled::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        )
    }
}
564
565#[cube]
566impl<Config: RuntimeConfig> MatmulArgs for TensorMapArgs<Config> {
567 type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
568 TensorMapInputs<Lhs, Rhs, EO>;
569 type Output<EO: CubePrimitive> = TensorOutput<EO>;
570 type Config = Config;
571 type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
572 (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);
573
574 fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
575 input: &Self::Input<Lhs, Rhs, EO>,
576 output: &mut Self::Output<EO>,
577 config: Self::Config,
578 #[comptime] _lhs_layout_config: GlobalLayoutConfig,
579 #[comptime] _rhs_layout_config: GlobalLayoutConfig,
580 #[comptime] _out_layout_config: GlobalLayoutConfig,
581 ) -> Self::State<Lhs, Rhs, EO> {
582 (*input, *output, config)
583 }
584
585 fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
586 state: &Self::State<Lhs, Rhs, EO>,
587 ) -> View<Lhs, BatchedCoords> {
588 state.0.lhs
589 }
590
591 fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
592 _state: &Self::State<Lhs, Rhs, EO>,
593 batch: usize,
594 ) -> usize {
595 batch
596 }
597
598 fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
599 state: &Self::State<Lhs, Rhs, EO>,
600 ) -> View<Rhs, BatchedCoords> {
601 state.0.rhs
602 }
603
604 fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
605 _state: &Self::State<Lhs, Rhs, EO>,
606 batch: usize,
607 ) -> usize {
608 batch
609 }
610
611 fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
612 state: &Self::State<Lhs, Rhs, EO>,
613 ) -> ComptimeOption<View<EO, BatchedCoords>> {
614 state.0.acc
615 }
616
617 fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
618 state: &Self::State<Lhs, Rhs, EO>,
619 batch: usize,
620 ) -> usize {
621 #[comptime]
622 #[comptime]
623 match state.0.acc_batch {
624 ComptimeOption::Some(layout) => layout.to_source_pos(batch),
625 ComptimeOption::None => batch,
626 }
627 }
628
629 fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
630 state: &mut Self::State<Lhs, Rhs, EO>,
631 ) -> View<EO, BatchedCoords, ReadWrite> {
632 state.1.view
633 }
634
635 fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
636 state: &Self::State<Lhs, Rhs, EO>,
637 batch: usize,
638 ) -> usize {
639 state.1.batch.to_source_pos(batch)
640 }
641
642 fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
643 state: &Self::State<Lhs, Rhs, EO>,
644 ) -> Self::Config {
645 state.2.clone()
646 }
647}