Skip to main content

cubek_convolution/kernels/forward/selector/
select_kernel.rs

1use crate::{
2    components::global::args::RuntimeArgs,
3    forward::args::{ConcreteArgs, ConcreteInputsFactory, ConcreteOutputFactory},
4};
5use cubecl::{
6    prelude::TensorBinding,
7    {Runtime, client::ComputeClient},
8};
9use cubek_matmul::{
10    definition::{MatmulElems, MatmulVectorSizes},
11    routines::Routine,
12};
13use cubek_matmul::{
14    launch::{InputArg, OutputArg},
15    routines::BlueprintStrategy,
16};
17use cubek_std::InputBinding;
18
19use crate::components::{ConvSetupError, ConvolutionProblem};
20
21/// Select which kernel to launch for the given Algorithm.
22///
23/// Only works for concrete tensor inputs and output.
24#[allow(clippy::result_large_err, clippy::too_many_arguments)]
25pub fn launch_kernel_concrete<R: Runtime, Args: ConcreteArgs<A>, A: Routine<RuntimeArgs>>(
26    client: &ComputeClient<R>,
27    input: InputBinding<R>,
28    weight: InputBinding<R>,
29    bias: Option<InputBinding<R>>,
30    out: TensorBinding<R>,
31    problem: ConvolutionProblem,
32    vector_sizes: MatmulVectorSizes,
33    blueprint_strategy: &BlueprintStrategy<Args::Config, A>,
34    dtypes: &MatmulElems,
35) -> Result<(), ConvSetupError> {
36    let mut view_vector_sizes = vector_sizes;
37
38    if let InputBinding::Quantized { scheme, .. } = input {
39        view_vector_sizes.lhs *= scheme.num_quants();
40    }
41    if let InputBinding::Quantized { scheme, .. } = weight {
42        view_vector_sizes.rhs *= scheme.num_quants();
43    }
44
45    let device_settings = A::device_settings(client, view_vector_sizes);
46    let expand_info = A::expand_blueprint(
47        &problem.as_matmul_problem(),
48        &device_settings,
49        blueprint_strategy,
50    )?;
51
52    let problem = Args::adjust_problem(client, problem, &expand_info.blueprint, dtypes);
53    let launch_info = A::prepare(&problem.as_matmul_problem(), &device_settings, expand_info)?;
54
55    let (input, runtime_args) = <InputArg<Args> as ConcreteInputsFactory<A>>::create(
56        input,
57        weight,
58        bias,
59        &launch_info.blueprint,
60        &problem,
61        dtypes,
62    );
63    let output = <OutputArg<Args> as ConcreteOutputFactory<A>>::create(
64        out,
65        &launch_info.blueprint,
66        &problem,
67        dtypes,
68    );
69
70    cubek_matmul::launch::launch_kernel::<Args, R, A>(
71        client,
72        input,
73        output,
74        runtime_args,
75        launch_info,
76    )
77    .map_err(ConvSetupError::Matmul)
78}