cubek_convolution/kernels/forward/selector/
select_kernel.rs

1use crate::{
2    components::{ConvGemmConfig as _, global::args::RuntimeArgsLaunch},
3    forward::args::{ConcreteInputsFactory, ConcreteOutputFactory},
4};
5use cubecl::prelude::TensorHandleRef;
6use cubecl::{Runtime, client::ComputeClient};
7use cubek_matmul::definition::{MatmulElems, MatmulLineSizes, TilingBlueprint};
8use cubek_matmul::launch::{
9    InputArg, InputRuntimeArg, MatmulArgs, MatmulInputHandleRef, OutputArg, OutputRuntimeArg,
10};
11
12use crate::{
13    components::{ConvSetupError, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
14    kernels::forward::algorithm::Algorithm,
15};
16
17/// Select which kernel to launch for the given Algorithm.
18///
19/// Only works for concrete tensor inputs and output.
20#[allow(clippy::result_large_err, clippy::too_many_arguments)]
21pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
22    client: &ComputeClient<R>,
23    input: &MatmulInputHandleRef<'_, R>,
24    weight: &MatmulInputHandleRef<'_, R>,
25    bias: &Option<MatmulInputHandleRef<'_, R>>,
26    out: &TensorHandleRef<'_, R>,
27    problem: ConvolutionProblem,
28    line_sizes: MatmulLineSizes,
29    selection: TilingBlueprint,
30    dtypes: &MatmulElems,
31) -> Result<(), ConvSetupError>
32where
33    InputArg<A::Args>: ConcreteInputsFactory,
34    OutputArg<A::Args>: ConcreteOutputFactory,
35{
36    let config = A::expand_config(
37        client.properties(),
38        &problem,
39        &selection,
40        &line_sizes,
41        dtypes,
42    )?;
43
44    let (input, runtime_args) = <InputArg<A::Args> as ConcreteInputsFactory>::create(
45        client,
46        input,
47        weight,
48        bias.as_ref(),
49        &selection,
50        &problem,
51        &line_sizes,
52        config,
53        dtypes,
54    );
55    let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
56        client,
57        out,
58        &selection,
59        &problem,
60        &line_sizes,
61        config,
62        dtypes,
63    );
64
65    let result = unsafe {
66        A::GlobalConvolution::launch_unchecked::<A::Args, R>(
67            client,
68            config.cube_dim(),
69            A::cube_count(&selection, &problem),
70            input,
71            output,
72            runtime_args,
73            config,
74            dtypes,
75        )
76    };
77
78    match result {
79        Ok(_) => Ok(()),
80        Err(err) => Err(ConvSetupError::Launch(err)),
81    }
82}
83
84/// Select which kernel to launch for the given Algorithm.
85#[allow(clippy::too_many_arguments)]
86pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
87    client: &ComputeClient<R>,
88    input: InputRuntimeArg<'a, MA, R>,
89    output: OutputRuntimeArg<'a, MA, R>,
90    runtime_args: RuntimeArgsLaunch<'a, R>,
91    problem: ConvolutionProblem,
92    line_sizes: MatmulLineSizes,
93    selection: TilingBlueprint,
94    dtypes: &MatmulElems,
95) -> Result<(), ConvSetupError> {
96    let config = A::expand_config(
97        client.properties(),
98        &problem,
99        &selection,
100        &line_sizes,
101        dtypes,
102    )?;
103
104    let result = unsafe {
105        A::GlobalConvolution::launch_unchecked::<MA, R>(
106            client,
107            config.cube_dim(),
108            A::cube_count(&selection, &problem),
109            input,
110            output,
111            runtime_args,
112            config,
113            dtypes,
114        )
115    };
116
117    match result {
118        Ok(_) => Ok(()),
119        Err(err) => Err(ConvSetupError::Launch(err)),
120    }
121}