cubecl_convolution/kernels/layered/selector/
select_kernel.rs

1use crate::components::ConvGemmConfig as _;
2use cubecl_core::prelude::TensorHandleRef;
3use cubecl_core::{Runtime, client::ComputeClient};
4use cubecl_matmul::components::MatmulElems;
5use cubecl_matmul::components::global::args::MatmulArgs;
6use cubecl_matmul::{
7    MatmulInputHandleRef,
8    components::{
9        InputArg, InputRuntimeArg, MatmulLineSizes, MatmulSelection, OutputArg, OutputRuntimeArg,
10    },
11};
12
13use crate::{
14    components::{
15        ConvSetupError, ConvolutionProblem,
16        global::{
17            args::{ConcreteInputsFactory, ConcreteOutputFactory},
18            entry_point::ConvolutionLaunch,
19        },
20    },
21    kernels::layered::algorithm::Algorithm,
22};
23
24/// Select which kernel to launch for the given Algorithm.
25///
26/// Only works for concrete tensor inputs and output.
27#[allow(clippy::result_large_err, clippy::too_many_arguments)]
28pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
29    client: &ComputeClient<R>,
30    input: &MatmulInputHandleRef<'_, R>,
31    weight: &MatmulInputHandleRef<'_, R>,
32    bias: &Option<TensorHandleRef<'_, R>>,
33    out: &TensorHandleRef<'_, R>,
34    problem: ConvolutionProblem,
35    line_sizes: MatmulLineSizes,
36    selection: MatmulSelection,
37    dtypes: &MatmulElems,
38) -> Result<(), ConvSetupError>
39where
40    InputArg<A::Args>: ConcreteInputsFactory,
41    OutputArg<A::Args>: ConcreteOutputFactory,
42{
43    let config = A::setup(client, &problem, &selection, &line_sizes, dtypes)?;
44
45    let input = <InputArg<A::Args> as ConcreteInputsFactory>::create(
46        client,
47        input,
48        weight,
49        bias.as_ref(),
50        &selection,
51        &problem,
52        &line_sizes,
53        config,
54        dtypes,
55    );
56    let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
57        client,
58        out,
59        &selection,
60        &problem,
61        &line_sizes,
62        config,
63        dtypes,
64    );
65
66    let result = unsafe {
67        A::GlobalConvolution::launch_unchecked::<A::Args, R>(
68            client,
69            config.cube_dim(),
70            A::cube_count(&selection, &problem),
71            input,
72            output,
73            &problem,
74            config,
75            dtypes,
76        )
77    };
78
79    match result {
80        Ok(_) => Ok(()),
81        Err(err) => Err(ConvSetupError::Launch(err)),
82    }
83}
84
85/// Select which kernel to launch for the given Algorithm.
86pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
87    client: &ComputeClient<R>,
88    input: InputRuntimeArg<'a, MA, R>,
89    output: OutputRuntimeArg<'a, MA, R>,
90    problem: ConvolutionProblem,
91    line_sizes: MatmulLineSizes,
92    selection: MatmulSelection,
93    dtypes: &MatmulElems,
94) -> Result<(), ConvSetupError> {
95    let config = A::setup(client, &problem, &selection, &line_sizes, dtypes)?;
96
97    let result = unsafe {
98        A::GlobalConvolution::launch_unchecked::<MA, R>(
99            client,
100            config.cube_dim(),
101            A::cube_count(&selection, &problem),
102            input,
103            output,
104            &problem,
105            config,
106            dtypes,
107        )
108    };
109
110    match result {
111        Ok(_) => Ok(()),
112        Err(err) => Err(ConvSetupError::Launch(err)),
113    }
114}