cubecl_convolution/kernels/layered/selector/
select_kernel.rs1use cubecl_core::prelude::TensorHandleRef;
2use cubecl_core::{Runtime, client::ComputeClient};
3use cubecl_matmul::{
4    MatmulInputHandleRef,
5    components::{
6        InputArg, InputRuntimeArg, MatmulLineSizes, MatmulSelection, MatmulSpec, OutputArg,
7        OutputRuntimeArg, global::GlobalConfig as _,
8    },
9};
10
11use crate::{
12    components::{
13        ConvSetupError, ConvolutionProblem,
14        global::{
15            args::{ConcreteInputsFactory, ConcreteOutputFactory},
16            entry_point::ConvolutionLaunch,
17        },
18    },
19    kernels::layered::algorithm::Algorithm,
20};
21
22#[allow(clippy::result_large_err, clippy::too_many_arguments)]
26pub fn launch_kernel_concrete<MS: MatmulSpec, R: Runtime, A: Algorithm>(
27    client: &ComputeClient<R::Server>,
28    input: &MatmulInputHandleRef<'_, R>,
29    weight: &MatmulInputHandleRef<'_, R>,
30    bias: &Option<TensorHandleRef<'_, R>>,
31    out: &TensorHandleRef<'_, R>,
32    problem: ConvolutionProblem,
33    line_sizes: MatmulLineSizes,
34    selection: MatmulSelection,
35) -> Result<(), ConvSetupError>
36where
37    InputArg<MS>: ConcreteInputsFactory,
38    OutputArg<MS>: ConcreteOutputFactory,
39{
40    let config = A::setup::<R, MS::Precision>(client, &problem, &selection, &line_sizes)?;
41
42    let input = <InputArg<MS> as ConcreteInputsFactory>::create(
43        client,
44        input,
45        weight,
46        bias.as_ref(),
47        &selection,
48        &problem,
49        &line_sizes,
50        config,
51    );
52    let output = <OutputArg<MS> as ConcreteOutputFactory>::create(
53        client,
54        out,
55        &selection,
56        &problem,
57        &line_sizes,
58        config,
59    );
60
61    unsafe {
62        A::GlobalConvolution::launch_unchecked::<MS, R>(
63            client,
64            config.cube_dim(),
65            A::cube_count(&selection, &problem),
66            input,
67            output,
68            &problem,
69            config,
70        );
71    }
72
73    Ok(())
74}
75
76pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>(
78    client: &ComputeClient<R::Server>,
79    input: InputRuntimeArg<'a, MS, R>,
80    output: OutputRuntimeArg<'a, MS, R>,
81    problem: ConvolutionProblem,
82    line_sizes: MatmulLineSizes,
83    selection: MatmulSelection,
84) -> Result<(), ConvSetupError> {
85    let config = A::setup::<R, MS::Precision>(client, &problem, &selection, &line_sizes)?;
86
87    unsafe {
88        A::GlobalConvolution::launch_unchecked::<MS, R>(
89            client,
90            config.cube_dim(),
91            A::cube_count(&selection, &problem),
92            input,
93            output,
94            &problem,
95            config,
96        );
97    }
98
99    Ok(())
100}