1use std::marker::PhantomData;
2
3use cubecl::std::tensor::{
4 View,
5 launch::ViewArg,
6 layout::{Coords1d, VirtualLayout, VirtualLayoutLaunch},
7};
8use cubecl::unexpanded;
9use cubecl::{
10 prelude::*,
11 zspace::{metadata::Metadata, shape, strides},
12};
13use cubek_std::launch::tma::{remap_storage_for_tma, tma_meta_tiled, transpose_inner_for_tma};
14use cubek_std::{InputBinding, MatrixLayout, stage::SwizzleMode};
15
16use crate::components::global::memory::{
17 BatchLayout, BatchLayoutLaunch, GlobalLayout, GlobalLayoutConfig, GlobalLayoutLaunch,
18 GlobalScaleLayout, NoopLayout, NoopLayoutLaunch, SimpleTmaGlobalLayout,
19 SimpleTmaGlobalLayoutLaunch,
20};
21use crate::{
22 definition::{Blueprint as _, MatmulElems, MatmulProblem, MatmulVectorSizes},
23 routines::Routine,
24};
25
// Marker types for the three matmul operands: left-hand side, right-hand side,
// and accumulator. `define_scalar!` / `define_size!` are project macros —
// presumably they emit zero-sized marker types used as generic parameters for
// element type and vectorization width respectively (TODO confirm against the
// macro definitions).
define_scalar!(pub Lhs);
define_scalar!(pub Rhs);
define_scalar!(pub Acc);

define_size!(pub LhsSize);
define_size!(pub RhsSize);
define_size!(pub AccSize);
33
/// Input launch argument for a [`MatmulArgs`] family, instantiated with the
/// crate's lhs/rhs/accumulator scalar and vector-size marker types.
pub type InputArg<MA> =
    <MA as MatmulArgs>::Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>;

/// Output launch argument for a [`MatmulArgs`] family.
pub type OutputArg<MA> = <MA as MatmulArgs>::Output<Vector<Acc, AccSize>>;

/// Runtime configuration type carried by a [`MatmulArgs`] family.
pub type ConfigArg<MA> = <MA as MatmulArgs>::Config;

/// Host-side runtime form of [`InputArg`] for runtime `R`.
pub type InputRuntimeArg<MA, R> = <InputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Host-side runtime form of [`ConfigArg`] for runtime `R`.
pub type ConfigRuntimeArg<MA, R> = <ConfigArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Host-side runtime form of [`OutputArg`] for runtime `R`.
pub type OutputRuntimeArg<MA, R> = <OutputArg<MA> as LaunchArg>::RuntimeArg<R>;

/// Coordinates into a batched matrix view: `(batch, rows, cols)`
/// (axis meaning as constructed in `TensorMapInputs::create`).
pub type BatchedCoords = (usize, u32, u32);
54
/// Factory turning concrete host-side input bindings into the runtime launch
/// argument of an input [`LaunchArg`] type.
pub trait ConcreteInputsFactory<A: Routine<()>>: LaunchArg {
    /// Builds the runtime argument for the lhs/rhs inputs of a matmul launch.
    ///
    /// * `lhs`, `rhs` — host bindings for the two input operands.
    /// * `blueprint` — the routine's blueprint, queried for layout configs.
    /// * `problem` — problem description (shapes, layouts, batches).
    /// * `vector_sizes` — per-operand vectorization widths.
    /// * `dtypes` — element types of the matmul operands.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
68
/// Factory turning a concrete host-side output binding into the runtime launch
/// argument of an output [`LaunchArg`] type.
pub trait ConcreteOutputFactory<A: Routine<()>>: LaunchArg {
    /// Builds the runtime argument for the output of a matmul launch.
    ///
    /// Parameters mirror [`ConcreteInputsFactory::create`], with a single
    /// output tensor binding instead of two input bindings.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R>;
}
81
/// Alias trait for types usable as a runtime matmul configuration.
/// The blanket impl makes this automatic for any type meeting the bounds.
pub trait RuntimeConfig: LaunchArg + CubeType + Clone + Send + Sync {}
impl<T: LaunchArg + CubeType + Clone + Send + Sync> RuntimeConfig for T {}
84
/// Abstraction over how matmul kernel arguments are bound and accessed.
///
/// Implementors (e.g. [`TensorArgs`], [`TensorMapArgs`]) define the launch
/// argument types and expose the operands to kernel code as [`View`]s plus
/// per-operand batch-index remapping. The provided method bodies call
/// `unexpanded!()`: they are placeholders only valid during `#[cube]` kernel
/// expansion and panic if called directly on the host.
#[cube]
pub trait MatmulArgs: Send + Sync + 'static + Clone {
    /// Launch argument bundling the lhs/rhs (and optional accumulator) inputs.
    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: LaunchArg + CubeType;

    /// Launch argument for the output tensor.
    type Output<EO: CubePrimitive>: LaunchArg + CubeType;

    /// Runtime configuration carried alongside the tensor arguments.
    type Config: RuntimeConfig;

    /// Kernel-side state built from the inputs/output at initialization.
    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>: CubeType;

    /// Builds the kernel-side state from the launch arguments.
    ///
    /// The `*_layout_config` parameters are comptime layout configurations for
    /// each operand (unused by the tensor-based implementations in this file).
    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        input: &Self::Input<Lhs, Rhs, EO>,
        output: &mut Self::Output<EO>,
        config: Self::Config,
        #[comptime] lhs_layout_config: GlobalLayoutConfig,
        #[comptime] rhs_layout_config: GlobalLayoutConfig,
        #[comptime] out_layout_config: GlobalLayoutConfig,
    ) -> Self::State<Lhs, Rhs, EO>;

    /// Batched view of the lhs operand.
    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Lhs, BatchedCoords> {
        unexpanded!()
    }
    /// Maps a nominal batch index to the lhs source tensor's batch position.
    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Batched view of the rhs operand.
    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> View<Rhs, BatchedCoords> {
        unexpanded!()
    }
    /// Maps a nominal batch index to the rhs source tensor's batch position.
    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Optional batched view of the accumulator input, resolved at comptime.
    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> ComptimeOption<View<EO, BatchedCoords>> {
        unexpanded!()
    }
    /// Maps a nominal batch index to the accumulator's batch position.
    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }
    /// Writable batched view of the output tensor.
    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &mut Self::State<Lhs, Rhs, EO>,
    ) -> View<EO, BatchedCoords, ReadWrite> {
        unexpanded!()
    }
    /// Maps a nominal batch index to the output tensor's batch position.
    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
        _batch: usize,
    ) -> usize {
        unexpanded!()
    }

    /// Returns a copy of the runtime configuration stored in the state.
    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
        _state: &Self::State<Lhs, Rhs, EO>,
    ) -> Self::Config {
        unexpanded!()
    }
}
162
/// Identifies which of the two matmul input operands is being referred to.
#[derive(Clone, Copy)]
pub enum TensorInputIdent {
    Lhs,
    Rhs,
}
169
/// [`MatmulArgs`] implementation backed by plain tensors and
/// [`GlobalLayout`]-based views. `Config` is the runtime configuration type
/// threaded through the state (defaults to `()`).
#[derive(Clone)]
pub struct TensorArgs<Config: RuntimeConfig = ()> {
    // Zero-sized: `Config` only parameterizes the associated types.
    _config: PhantomData<Config>,
}
177
/// Kernel-side inputs for [`TensorArgs`]: one view per operand plus a 1D→1D
/// layout remapping batch indices into each source tensor.
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive> {
    /// Batch-index remapping for the lhs tensor.
    lhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// Batched view over the lhs tensor data.
    lhs: View<Lhs, BatchedCoords>,
    /// Batch-index remapping for the rhs tensor.
    rhs_batch: VirtualLayout<Coords1d, Coords1d>,
    /// Batched view over the rhs tensor data.
    rhs: View<Rhs, BatchedCoords>,
    /// Optional batch-index remapping for the accumulator input.
    acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
    /// Optional batched view over the accumulator input.
    acc: ComptimeOption<View<Acc, BatchedCoords>>,
}
191
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, Acc: CubePrimitive, A: Routine<()>>
    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, Acc>
{
    /// Builds the launch arguments for plain (non-TMA) tensor inputs.
    ///
    /// Each input becomes a tensor view over a [`GlobalLayout`]; quantized
    /// inputs get a data view paired with a scales view. The accumulator
    /// slots are left empty (`ComptimeOptionArgs::None`).
    fn create<R: Runtime>(
        lhs: InputBinding<R>,
        rhs: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        // Turns one input binding into a view argument, per binding kind.
        let view = |handle: InputBinding<R>, config: GlobalLayoutConfig, vector_size| match handle {
            InputBinding::Normal(handle, _dtype) => {
                let layout = GlobalLayoutLaunch::from_handle(&handle, vector_size, config);
                ViewArg::new_tensor::<GlobalLayout>(handle.into_tensor_arg(), layout)
            }
            InputBinding::Quantized {
                data,
                scale,
                shape,
                scheme,
                ..
            } => {
                // Quantized inputs need separate layouts for data and scales.
                let (data_layout, scales_layout) = GlobalLayoutLaunch::from_quantized_handle(
                    &data,
                    &scale,
                    &shape,
                    problem,
                    scheme,
                    vector_size,
                    config,
                );
                let data_view =
                    ViewArg::new_tensor::<GlobalLayout>(data.into_tensor_arg(), data_layout);
                let scales_view = ViewArg::new_tensor::<GlobalScaleLayout>(
                    scale.into_tensor_arg(),
                    scales_layout,
                );
                ViewArg::new_quantized(data_view, scales_view, scheme)
            }
        };
        // Batch remapping: derived from the tensor handle for normal inputs;
        // a no-op (identity) layout for quantized inputs.
        let batch_layout = |handle: &InputBinding<R>| match handle {
            InputBinding::Normal(handle, _dtype) => {
                let layout = BatchLayoutLaunch::from_handle(handle, problem);
                VirtualLayoutLaunch::new::<BatchLayout>(layout)
            }
            InputBinding::Quantized { .. } => {
                VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new())
            }
        };

        TensorInputsLaunch::new(
            batch_layout(&lhs),
            view(lhs, blueprint.lhs_global_layout_config(), vector_sizes.lhs),
            batch_layout(&rhs),
            view(rhs, blueprint.rhs_global_layout_config(), vector_sizes.rhs),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        )
    }
}
253
/// Kernel-side output: a writable batched view plus the batch-index
/// remapping into the output tensor.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct TensorOutput<EG: CubePrimitive> {
    /// Writable view over the output tensor data.
    view: View<EG, BatchedCoords, ReadWrite>,
    /// Batch-index remapping for the output tensor.
    batch: VirtualLayout<Coords1d, Coords1d>,
}
259
impl<EG: CubePrimitive, A: Routine<()>> ConcreteOutputFactory<A> for TensorOutput<EG> {
    /// Builds the output launch argument: a [`GlobalLayout`]-backed view over
    /// the output tensor plus its [`BatchLayout`] batch remapping.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        vector_sizes: &MatmulVectorSizes,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        let layout = GlobalLayoutLaunch::from_handle(
            &out,
            vector_sizes.out,
            blueprint.out_global_layout_config(),
        );
        let batch = BatchLayoutLaunch::from_handle(&out, problem);
        let view = ViewArg::new_tensor::<GlobalLayout>(out.into_tensor_arg(), layout);
        TensorOutputLaunch::new(view, VirtualLayoutLaunch::new::<BatchLayout>(batch))
    }
}
278
279#[cube]
280impl<Config: RuntimeConfig> MatmulArgs for TensorArgs<Config> {
281 type Output<EO: CubePrimitive> = TensorOutput<EO>;
282 type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
283 TensorInputs<Lhs, Rhs, EO>;
284 type Config = Config;
285 type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
286 (TensorInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);
287
288 fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
289 input: &Self::Input<Lhs, Rhs, EO>,
290 output: &mut Self::Output<EO>,
291 config: Self::Config,
292 #[comptime] _lhs_layout_config: GlobalLayoutConfig,
293 #[comptime] _rhs_layout_config: GlobalLayoutConfig,
294 #[comptime] _out_layout_config: GlobalLayoutConfig,
295 ) -> Self::State<Lhs, Rhs, EO> {
296 (*input, *output, config)
297 }
298
299 fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
300 state: &Self::State<Lhs, Rhs, EO>,
301 ) -> View<Lhs, BatchedCoords> {
302 state.0.lhs
303 }
304
305 fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
306 state: &Self::State<Lhs, Rhs, EO>,
307 batch: usize,
308 ) -> usize {
309 state.0.lhs_batch.to_source_pos(batch)
310 }
311
312 fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
313 state: &Self::State<Lhs, Rhs, EO>,
314 ) -> View<Rhs, BatchedCoords> {
315 state.0.rhs
316 }
317
318 fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
319 state: &Self::State<Lhs, Rhs, EO>,
320 batch: usize,
321 ) -> usize {
322 state.0.rhs_batch.to_source_pos(batch)
323 }
324
325 fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
326 state: &Self::State<Lhs, Rhs, EO>,
327 ) -> ComptimeOption<View<EO, BatchedCoords>> {
328 state.0.acc
329 }
330
331 fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
332 state: &Self::State<Lhs, Rhs, EO>,
333 batch: usize,
334 ) -> usize {
335 #[comptime]
336 #[comptime]
337 match state.0.acc_batch {
338 ComptimeOption::Some(layout) => layout.to_source_pos(batch),
339 ComptimeOption::None => batch,
340 }
341 }
342
343 fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
344 state: &mut Self::State<Lhs, Rhs, EO>,
345 ) -> View<EO, BatchedCoords, ReadWrite> {
346 state.1.view
347 }
348
349 fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
350 state: &Self::State<Lhs, Rhs, EO>,
351 batch: usize,
352 ) -> usize {
353 state.1.batch.to_source_pos(batch)
354 }
355
356 fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
357 state: &Self::State<Lhs, Rhs, EO>,
358 ) -> Self::Config {
359 state.2.clone()
360 }
361}
362
/// [`MatmulArgs`] implementation backed by TMA tensor maps
/// (hardware tiled loads); see [`TensorMapInputs`].
#[derive(Clone)]
pub struct TensorMapArgs<Config: RuntimeConfig = ()> {
    // Zero-sized: `Config` only parameterizes the associated types.
    _config: PhantomData<Config>,
}
370
/// Kernel-side inputs for [`TensorMapArgs`]: tensor-map-backed views for
/// lhs/rhs, plus the optional accumulator view and its batch remapping.
/// No lhs/rhs batch layouts here — TMA batch indices are used as-is.
#[derive(CubeLaunch, CubeType, Clone, Copy)]
pub struct TensorMapInputs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> {
    /// Batched view over the lhs tensor map.
    pub lhs: View<Lhs, BatchedCoords>,
    /// Batched view over the rhs tensor map.
    pub rhs: View<Rhs, BatchedCoords>,
    /// Optional batched view over the accumulator input.
    pub acc: ComptimeOption<View<EO, BatchedCoords>>,
    /// Optional batch-index remapping for the accumulator input.
    pub acc_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,
}
383
impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<()>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds TMA tensor-map launch arguments for the lhs/rhs inputs.
    ///
    /// Normalizes each tensor to rank-3 `(batch, rows, cols)`, transposes the
    /// inner dims for TMA when the operand is column-major, then constructs
    /// tiled tensor-map metadata sized to the stage. Accumulator slots are
    /// left empty (`ComptimeOptionArgs::None`).
    fn create<R: Runtime>(
        lhs_handle: InputBinding<R>,
        rhs_handle: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &MatmulProblem,
        _vector_sizes: &MatmulVectorSizes,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<R> {
        let lhs = lhs_handle.into_data();
        let rhs = rhs_handle.into_data();

        let tiling_scheme = blueprint.tiling_scheme();
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();

        // Stage box shape for the lhs tensor map, `[batch, outer, inner]`.
        // Without swizzling the inner dim is a single tile; with swizzling it
        // spans the whole stage — presumably because the swizzle pattern
        // operates across the stage (TODO confirm against tma_meta_tiled).
        let stage_size_lhs = match blueprint.swizzle_modes().lhs {
            SwizzleMode::None => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, tiling_scheme.tile_size.k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.m as usize]
                }
            },
            _ => match problem.lhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_m as usize, stage_k as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_k as usize, stage_m as usize]
                }
            },
        };
        // Same construction for the rhs stage box.
        let stage_size_rhs = match blueprint.swizzle_modes().rhs {
            SwizzleMode::None => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, tiling_scheme.tile_size.n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, tiling_scheme.tile_size.k as usize]
                }
            },
            _ => match problem.rhs_layout {
                MatrixLayout::RowMajor => {
                    shape![1, stage_k as usize, stage_n as usize]
                }
                MatrixLayout::ColMajor => {
                    shape![1, stage_n as usize, stage_k as usize]
                }
            },
        };

        // Collapse all leading batch dims into one: shape becomes
        // [prod(batches), rows, cols].
        let lhs_rank = lhs.shape.len();
        let mut lhs_shape = shape![
            problem.lhs_batches.iter().product(),
            lhs.shape[lhs_rank - 2],
            lhs.shape[lhs_rank - 1],
        ];
        let mut lhs_strides = if lhs_rank > 2 {
            lhs.strides[lhs_rank - 3..].into()
        } else {
            strides![lhs.strides[0], lhs.strides[1]]
        };

        let rhs_rank = rhs.shape.len();
        let mut rhs_shape = shape![
            problem.rhs_batches.iter().product(),
            rhs.shape[rhs_rank - 2],
            rhs.shape[rhs_rank - 1],
        ];
        let mut rhs_strides = if rhs_rank > 2 {
            rhs.strides[rhs_rank - 3..].into()
        } else {
            strides![rhs.strides[0], rhs.strides[1]]
        };

        // TMA requires a contiguous innermost dim; these may swap the inner
        // two dims (returning whether a transpose happened).
        let lhs_transposed =
            transpose_inner_for_tma(&mut lhs_shape, &mut lhs_strides, problem.lhs_layout);
        let rhs_transposed =
            transpose_inner_for_tma(&mut rhs_shape, &mut rhs_strides, problem.rhs_layout);

        // Pad rank-2 strides up to rank 3 by repeating the outer stride as
        // the (size-1) batch stride.
        if lhs_strides.len() == 2 {
            let stride = lhs_strides[0];
            lhs_strides.insert(0, stride);
        }
        if rhs_strides.len() == 2 {
            let stride = rhs_strides[0];
            rhs_strides.insert(0, stride);
        }

        let meta_lhs = tma_meta_tiled(
            Metadata::new(lhs_shape.clone(), lhs_strides),
            stage_size_lhs,
            remap_storage_for_tma(dtypes.lhs_stage),
            blueprint.swizzle_modes().lhs.into(),
        );

        let meta_rhs = tma_meta_tiled(
            Metadata::new(rhs_shape.clone(), rhs_strides),
            stage_size_rhs,
            remap_storage_for_tma(dtypes.rhs_stage),
            blueprint.swizzle_modes().rhs.into(),
        );

        let lhs = TensorMapArg {
            tensor: lhs.into_tensor_arg(),
            metadata: meta_lhs,
            _kind: PhantomData,
        };
        let rhs = TensorMapArg {
            tensor: rhs.into_tensor_arg(),
            metadata: meta_rhs,
            _kind: PhantomData,
        };

        // Wraps a tensor map in a view; `transposed` swaps the logical
        // (rows, cols) the layout reports, matching the earlier inner swap.
        let view = |buffer, shape: &[usize], transposed| {
            let batches = shape[0];
            let (rows, cols) = match transposed {
                true => (shape[2] as u32, shape[1] as u32),
                false => (shape[1] as u32, shape[2] as u32),
            };
            let shape = (batches, rows, cols);
            let layout = SimpleTmaGlobalLayoutLaunch::new(transposed, shape);
            ViewArg::new_tensor_map_tiled::<SimpleTmaGlobalLayout>(buffer, layout)
        };

        TensorMapInputsLaunch::new(
            view(lhs, &lhs_shape, lhs_transposed),
            view(rhs, &rhs_shape, rhs_transposed),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        )
    }
}
526
527#[cube]
528impl<Config: RuntimeConfig> MatmulArgs for TensorMapArgs<Config> {
529 type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
530 TensorMapInputs<Lhs, Rhs, EO>;
531 type Output<EO: CubePrimitive> = TensorOutput<EO>;
532 type Config = Config;
533 type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> =
534 (TensorMapInputs<Lhs, Rhs, EO>, TensorOutput<EO>, Config);
535
536 fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
537 input: &Self::Input<Lhs, Rhs, EO>,
538 output: &mut Self::Output<EO>,
539 config: Self::Config,
540 #[comptime] _lhs_layout_config: GlobalLayoutConfig,
541 #[comptime] _rhs_layout_config: GlobalLayoutConfig,
542 #[comptime] _out_layout_config: GlobalLayoutConfig,
543 ) -> Self::State<Lhs, Rhs, EO> {
544 (*input, *output, config)
545 }
546
547 fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
548 state: &Self::State<Lhs, Rhs, EO>,
549 ) -> View<Lhs, BatchedCoords> {
550 state.0.lhs
551 }
552
553 fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
554 _state: &Self::State<Lhs, Rhs, EO>,
555 batch: usize,
556 ) -> usize {
557 batch
558 }
559
560 fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
561 state: &Self::State<Lhs, Rhs, EO>,
562 ) -> View<Rhs, BatchedCoords> {
563 state.0.rhs
564 }
565
566 fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
567 _state: &Self::State<Lhs, Rhs, EO>,
568 batch: usize,
569 ) -> usize {
570 batch
571 }
572
573 fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
574 state: &Self::State<Lhs, Rhs, EO>,
575 ) -> ComptimeOption<View<EO, BatchedCoords>> {
576 state.0.acc
577 }
578
579 fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
580 state: &Self::State<Lhs, Rhs, EO>,
581 batch: usize,
582 ) -> usize {
583 #[comptime]
584 #[comptime]
585 match state.0.acc_batch {
586 ComptimeOption::Some(layout) => layout.to_source_pos(batch),
587 ComptimeOption::None => batch,
588 }
589 }
590
591 fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
592 state: &mut Self::State<Lhs, Rhs, EO>,
593 ) -> View<EO, BatchedCoords, ReadWrite> {
594 state.1.view
595 }
596
597 fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
598 state: &Self::State<Lhs, Rhs, EO>,
599 batch: usize,
600 ) -> usize {
601 state.1.batch.to_source_pos(batch)
602 }
603
604 fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(
605 state: &Self::State<Lhs, Rhs, EO>,
606 ) -> Self::Config {
607 state.2.clone()
608 }
609}