cubecl_matmul/kernels/layered/selector/
select_kernel.rs1use crate::MatmulInputHandleRef;
2use crate::components::batch::BatchConfig;
3use crate::components::{
4 InputArg, InputRuntimeArg, MatmulElems, MatmulLineSizes, MatmulSetupError, OutputRuntimeArg,
5};
6use crate::components::{
7 MatmulProblem, MatmulSpec, OutputArg,
8 global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
9};
10use crate::kernels::layered::base::Selection;
11use crate::kernels::layered::{Algorithm, launch_with_config};
12use cubecl_core::prelude::TensorHandleRef;
13use cubecl_core::{Runtime, client::ComputeClient};
14
15#[allow(clippy::result_large_err, clippy::too_many_arguments)]
19pub fn launch_kernel_concrete<MS: MatmulSpec, R: Runtime, A: Algorithm>(
20 client: &ComputeClient<R::Server>,
21 lhs: &MatmulInputHandleRef<'_, R>,
22 rhs: &MatmulInputHandleRef<'_, R>,
23 out: &TensorHandleRef<'_, R>,
24 problem: MatmulProblem,
25 line_sizes: MatmulLineSizes,
26 plane_dim: u32,
27 selection: &Selection<A::SelectionArgs>,
28) -> Result<(), MatmulSetupError>
29where
30 InputArg<MS>: ConcreteInputsFactory,
31 OutputArg<MS>: ConcreteOutputFactory,
32{
33 let elems = MatmulElems::new::<MS::Precision>();
34
35 let mut view_line_sizes = line_sizes;
36
37 if let MatmulInputHandleRef::Quantized { scheme, .. } = lhs {
38 view_line_sizes.lhs *= scheme.num_quants() as u8;
39 }
40 if let MatmulInputHandleRef::Quantized { scheme, .. } = rhs {
41 view_line_sizes.rhs *= scheme.num_quants() as u8;
42 }
43
44 let selection = match selection {
45 Selection::Forced(selection) => selection.clone(),
46 Selection::Inferred(args) => {
47 A::selection::<R>(client, &problem, plane_dim, &view_line_sizes, elems, args)?
48 }
49 };
50 let config = A::setup::<MS::Precision, R>(client, &problem, &selection, &view_line_sizes)?;
51 let cube_count_plan = config.hypercube_config().cube_count_plan(
52 &problem,
53 client.properties().hardware.max_cube_count.clone(),
54 );
55
56 launch_with_config::<MS, R, A>(
57 client,
58 config.cube_dim(),
59 cube_count_plan.resolve(),
60 <InputArg<MS> as ConcreteInputsFactory>::create(
61 client,
62 lhs,
63 rhs,
64 &selection,
65 &problem,
66 &line_sizes,
67 config,
68 ),
69 <OutputArg<MS> as ConcreteOutputFactory>::create(
70 client,
71 out,
72 &selection,
73 &problem,
74 &line_sizes,
75 config,
76 ),
77 cube_count_plan.as_args(),
78 config,
79 )
80}
81
82pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>(
84 client: &ComputeClient<R::Server>,
85 input: InputRuntimeArg<'a, MS, R>,
86 output: OutputRuntimeArg<'a, MS, R>,
87 problem: MatmulProblem,
88 view_line_sizes: MatmulLineSizes,
89 plane_dim: u32,
90 selection: &Selection<A::SelectionArgs>,
91) -> Result<(), MatmulSetupError> {
92 let elems = MatmulElems::new::<MS::Precision>();
93
94 let selection = match selection {
95 Selection::Forced(selection) => selection.clone(),
96 Selection::Inferred(args) => {
97 A::selection::<R>(client, &problem, plane_dim, &view_line_sizes, elems, args)?
98 }
99 };
100 let config = A::setup::<MS::Precision, R>(client, &problem, &selection, &view_line_sizes)?;
101
102 let cube_count_plan = config.hypercube_config().cube_count_plan(
103 &problem,
104 client.properties().hardware.max_cube_count.clone(),
105 );
106
107 launch_with_config::<MS, R, A>(
108 client,
109 config.cube_dim(),
110 cube_count_plan.resolve(),
111 input,
112 output,
113 cube_count_plan.as_args(),
114 config,
115 )
116}