cubecl_matmul/kernels/layered/selector/
select_kernel.rs

1use crate::MatmulInputHandleRef;
2use crate::components::batch::BatchConfig;
3use crate::components::{
4    InputArg, InputRuntimeArg, MatmulElems, MatmulLineSizes, MatmulSetupError, OutputRuntimeArg,
5};
6use crate::components::{
7    MatmulProblem, MatmulSpec, OutputArg,
8    global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
9};
10use crate::kernels::layered::base::Selection;
11use crate::kernels::layered::{Algorithm, launch_with_config};
12use cubecl_core::prelude::TensorHandleRef;
13use cubecl_core::{Runtime, client::ComputeClient};
14
15/// Select which kernel to launch for the given Algorithm.
16///
17/// Only works for concrete tensor inputs and output.
18#[allow(clippy::result_large_err, clippy::too_many_arguments)]
19pub fn launch_kernel_concrete<MS: MatmulSpec, R: Runtime, A: Algorithm>(
20    client: &ComputeClient<R::Server>,
21    lhs: &MatmulInputHandleRef<'_, R>,
22    rhs: &MatmulInputHandleRef<'_, R>,
23    out: &TensorHandleRef<'_, R>,
24    problem: MatmulProblem,
25    line_sizes: MatmulLineSizes,
26    plane_dim: u32,
27    selection: &Selection<A::SelectionArgs>,
28) -> Result<(), MatmulSetupError>
29where
30    InputArg<MS>: ConcreteInputsFactory,
31    OutputArg<MS>: ConcreteOutputFactory,
32{
33    let elems = MatmulElems::new::<MS::Precision>();
34
35    let mut view_line_sizes = line_sizes;
36
37    if let MatmulInputHandleRef::Quantized { scheme, .. } = lhs {
38        view_line_sizes.lhs *= scheme.num_quants() as u8;
39    }
40    if let MatmulInputHandleRef::Quantized { scheme, .. } = rhs {
41        view_line_sizes.rhs *= scheme.num_quants() as u8;
42    }
43
44    let selection = match selection {
45        Selection::Forced(selection) => selection.clone(),
46        Selection::Inferred(args) => {
47            A::selection::<R>(client, &problem, plane_dim, &view_line_sizes, elems, args)?
48        }
49    };
50    let config = A::setup::<MS::Precision, R>(client, &problem, &selection, &view_line_sizes)?;
51    let cube_count_plan = config.hypercube_config().cube_count_plan(
52        &problem,
53        client.properties().hardware.max_cube_count.clone(),
54    );
55
56    launch_with_config::<MS, R, A>(
57        client,
58        config.cube_dim(),
59        cube_count_plan.resolve(),
60        <InputArg<MS> as ConcreteInputsFactory>::create(
61            client,
62            lhs,
63            rhs,
64            &selection,
65            &problem,
66            &line_sizes,
67            config,
68        ),
69        <OutputArg<MS> as ConcreteOutputFactory>::create(
70            client,
71            out,
72            &selection,
73            &problem,
74            &line_sizes,
75            config,
76        ),
77        cube_count_plan.as_args(),
78        config,
79    )
80}
81
82/// Select which kernel to launch for the given Algorithm.
83pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>(
84    client: &ComputeClient<R::Server>,
85    input: InputRuntimeArg<'a, MS, R>,
86    output: OutputRuntimeArg<'a, MS, R>,
87    problem: MatmulProblem,
88    view_line_sizes: MatmulLineSizes,
89    plane_dim: u32,
90    selection: &Selection<A::SelectionArgs>,
91) -> Result<(), MatmulSetupError> {
92    let elems = MatmulElems::new::<MS::Precision>();
93
94    let selection = match selection {
95        Selection::Forced(selection) => selection.clone(),
96        Selection::Inferred(args) => {
97            A::selection::<R>(client, &problem, plane_dim, &view_line_sizes, elems, args)?
98        }
99    };
100    let config = A::setup::<MS::Precision, R>(client, &problem, &selection, &view_line_sizes)?;
101
102    let cube_count_plan = config.hypercube_config().cube_count_plan(
103        &problem,
104        client.properties().hardware.max_cube_count.clone(),
105    );
106
107    launch_with_config::<MS, R, A>(
108        client,
109        config.cube_dim(),
110        cube_count_plan.resolve(),
111        input,
112        output,
113        cube_count_plan.as_args(),
114        config,
115    )
116}