cubecl_convolution/kernels/layered/selector/select_kernel.rs

1use cubecl_core::prelude::TensorHandleRef;
2use cubecl_core::{Runtime, client::ComputeClient};
3use cubecl_matmul::components::MatmulElems;
4use cubecl_matmul::components::global::args::MatmulArgs;
5use cubecl_matmul::{
6    MatmulInputHandleRef,
7    components::{
8        InputArg, InputRuntimeArg, MatmulLineSizes, MatmulSelection, OutputArg, OutputRuntimeArg,
9        global::GlobalConfig as _,
10    },
11};
12
13use crate::{
14    components::{
15        ConvSetupError, ConvolutionProblem,
16        global::{
17            args::{ConcreteInputsFactory, ConcreteOutputFactory},
18            entry_point::ConvolutionLaunch,
19        },
20    },
21    kernels::layered::algorithm::Algorithm,
22};
23
24/// Select which kernel to launch for the given Algorithm.
25///
26/// Only works for concrete tensor inputs and output.
27#[allow(clippy::result_large_err, clippy::too_many_arguments)]
28pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
29    client: &ComputeClient<R::Server>,
30    input: &MatmulInputHandleRef<'_, R>,
31    weight: &MatmulInputHandleRef<'_, R>,
32    bias: &Option<TensorHandleRef<'_, R>>,
33    out: &TensorHandleRef<'_, R>,
34    problem: ConvolutionProblem,
35    line_sizes: MatmulLineSizes,
36    selection: MatmulSelection,
37    dtypes: &MatmulElems,
38) -> Result<(), ConvSetupError>
39where
40    InputArg<A::Args>: ConcreteInputsFactory,
41    OutputArg<A::Args>: ConcreteOutputFactory,
42{
43    let config = A::setup::<R>(client, &problem, &selection, &line_sizes, dtypes)?;
44
45    let input = <InputArg<A::Args> as ConcreteInputsFactory>::create(
46        client,
47        input,
48        weight,
49        bias.as_ref(),
50        &selection,
51        &problem,
52        &line_sizes,
53        config,
54        dtypes,
55    );
56    let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
57        client,
58        out,
59        &selection,
60        &problem,
61        &line_sizes,
62        config,
63        dtypes,
64    );
65
66    unsafe {
67        A::GlobalConvolution::launch_unchecked::<A::Args, R>(
68            client,
69            config.cube_dim(),
70            A::cube_count(&selection, &problem),
71            input,
72            output,
73            &problem,
74            config,
75            dtypes,
76        );
77    }
78
79    Ok(())
80}
81
82/// Select which kernel to launch for the given Algorithm.
83pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
84    client: &ComputeClient<R::Server>,
85    input: InputRuntimeArg<'a, MA, R>,
86    output: OutputRuntimeArg<'a, MA, R>,
87    problem: ConvolutionProblem,
88    line_sizes: MatmulLineSizes,
89    selection: MatmulSelection,
90    dtypes: &MatmulElems,
91) -> Result<(), ConvSetupError> {
92    let config = A::setup::<R>(client, &problem, &selection, &line_sizes, dtypes)?;
93
94    unsafe {
95        A::GlobalConvolution::launch_unchecked::<MA, R>(
96            client,
97            config.cube_dim(),
98            A::cube_count(&selection, &problem),
99            input,
100            output,
101            &problem,
102            config,
103            dtypes,
104        );
105    }
106
107    Ok(())
108}