Skip to main content

cubek_convolution/kernels/backward_data/selector/select_kernel.rs

1use crate::{
2    backward_data::args::{ConcreteArgs, ConcreteInputsFactory, ConcreteOutputFactory},
3    components::global::args::RuntimeArgs,
4};
5use cubecl::{
6    prelude::TensorBinding,
7    {Runtime, client::ComputeClient},
8};
9use cubek_matmul::{
10    definition::{MatmulElems, MatmulVectorSizes},
11    launch::{InputArg, OutputArg},
12    routines::{BlueprintStrategy, Routine},
13};
14use cubek_std::InputBinding;
15
16use crate::components::{ConvSetupError, ConvolutionProblem};
17
18/// Select which kernel to launch for the given Algorithm.
19///
20/// Only works for concrete tensor inputs and output.
21#[allow(clippy::result_large_err, clippy::too_many_arguments)]
22pub fn launch_kernel_concrete<R: Runtime, Args: ConcreteArgs<A>, A: Routine<RuntimeArgs>>(
23    client: &ComputeClient<R>,
24    out_grad: InputBinding<R>,
25    weights: InputBinding<R>,
26    in_grad: TensorBinding<R>,
27    problem: ConvolutionProblem,
28    vector_sizes: MatmulVectorSizes,
29    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, A>,
30    dtypes: &MatmulElems,
31) -> Result<(), ConvSetupError> {
32    let mut view_vector_sizes = vector_sizes;
33
34    if let InputBinding::Quantized { scheme, .. } = out_grad {
35        view_vector_sizes.lhs *= scheme.num_quants();
36    }
37    if let InputBinding::Quantized { scheme, .. } = weights {
38        view_vector_sizes.rhs *= scheme.num_quants();
39    }
40
41    let device_settings = A::device_settings(client, view_vector_sizes);
42    let expand_info = A::expand_blueprint(
43        &problem.as_matmul_problem(),
44        &device_settings,
45        blueprint_strategy,
46    )?;
47
48    let problem = Args::adjust_problem(client, problem, &expand_info.blueprint, dtypes);
49    let launch_info = A::prepare(&problem.as_matmul_problem(), &device_settings, expand_info)?;
50
51    let (input, runtime_args) = <InputArg<Args> as ConcreteInputsFactory<A>>::create(
52        out_grad,
53        weights,
54        &launch_info.blueprint,
55        &problem,
56        dtypes,
57    );
58    let output = <OutputArg<Args> as ConcreteOutputFactory<A>>::create(
59        in_grad,
60        &launch_info.blueprint,
61        &problem,
62    );
63
64    let result = cubek_matmul::launch::launch_kernel::<Args, R, A>(
65        client,
66        input,
67        output,
68        runtime_args,
69        launch_info,
70    );
71
72    result.map_err(ConvSetupError::Matmul)
73}