// cubek_convolution/kernels/backward_data/selector/select_kernel.rs

1use crate::{
2    backward_data::args::{ConcreteInputsFactory, ConcreteOutputFactory},
3    components::{ConvGemmConfig as _, global::args::RuntimeArgsLaunch},
4};
5use cubecl::prelude::TensorHandleRef;
6use cubecl::{Runtime, client::ComputeClient};
7use cubek_matmul::{
8    definition::{MatmulElems, MatmulLineSizes, TilingBlueprint},
9    launch::{
10        InputArg, InputRuntimeArg, MatmulArgs, MatmulInputHandleRef, OutputArg, OutputRuntimeArg,
11    },
12};
13
14use crate::{
15    components::{ConvSetupError, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
16    kernels::forward::algorithm::Algorithm,
17};
18
19/// Select which kernel to launch for the given Algorithm.
20///
21/// Only works for concrete tensor inputs and output.
22#[allow(clippy::result_large_err, clippy::too_many_arguments)]
23pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
24    client: &ComputeClient<R>,
25    out_grad: &MatmulInputHandleRef<'_, R>,
26    weights: &MatmulInputHandleRef<'_, R>,
27    in_grad: &TensorHandleRef<'_, R>,
28    problem: ConvolutionProblem,
29    line_sizes: MatmulLineSizes,
30    selection: TilingBlueprint,
31    dtypes: &MatmulElems,
32) -> Result<(), ConvSetupError>
33where
34    InputArg<A::Args>: ConcreteInputsFactory,
35    OutputArg<A::Args>: ConcreteOutputFactory,
36{
37    let config = A::expand_config(
38        client.properties(),
39        &problem,
40        &selection,
41        &line_sizes,
42        dtypes,
43    )?;
44
45    let (input, runtime_args) = <InputArg<A::Args> as ConcreteInputsFactory>::create(
46        client,
47        out_grad,
48        weights,
49        &selection,
50        &problem,
51        &line_sizes,
52        config,
53        dtypes,
54    );
55    let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
56        client,
57        in_grad,
58        &selection,
59        &problem,
60        &line_sizes,
61        config,
62    );
63
64    let result = unsafe {
65        A::GlobalConvolution::launch_unchecked::<A::Args, R>(
66            client,
67            config.cube_dim(),
68            A::cube_count(&selection, &problem),
69            input,
70            output,
71            runtime_args,
72            config,
73            dtypes,
74        )
75    };
76
77    match result {
78        Ok(_) => Ok(()),
79        Err(err) => Err(ConvSetupError::Launch(err)),
80    }
81}
82
83/// Select which kernel to launch for the given Algorithm.
84#[allow(clippy::too_many_arguments)]
85pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
86    client: &ComputeClient<R>,
87    input: InputRuntimeArg<'a, MA, R>,
88    output: OutputRuntimeArg<'a, MA, R>,
89    runtime_args: RuntimeArgsLaunch<'a, R>,
90    problem: ConvolutionProblem,
91    line_sizes: MatmulLineSizes,
92    selection: TilingBlueprint,
93    dtypes: &MatmulElems,
94) -> Result<(), ConvSetupError> {
95    let config = A::expand_config(
96        client.properties(),
97        &problem,
98        &selection,
99        &line_sizes,
100        dtypes,
101    )?;
102
103    let result = unsafe {
104        A::GlobalConvolution::launch_unchecked::<MA, R>(
105            client,
106            config.cube_dim(),
107            A::cube_count(&selection, &problem),
108            input,
109            output,
110            runtime_args,
111            config,
112            dtypes,
113        )
114    };
115
116    match result {
117        Ok(_) => Ok(()),
118        Err(err) => Err(ConvSetupError::Launch(err)),
119    }
120}