cubek_convolution/kernels/backward_data/selector/
select_kernel.rs1use crate::{
2 backward_data::args::{ConcreteArgs, ConcreteInputsFactory, ConcreteOutputFactory},
3 components::global::args::RuntimeArgs,
4};
5use cubecl::{
6 prelude::TensorBinding,
7 {Runtime, client::ComputeClient},
8};
9use cubek_matmul::{
10 definition::{MatmulElems, MatmulVectorSizes},
11 launch::{InputArg, OutputArg},
12 routines::{BlueprintStrategy, Routine},
13};
14use cubek_std::InputBinding;
15
16use crate::components::{ConvSetupError, ConvolutionProblem};
17
18#[allow(clippy::result_large_err, clippy::too_many_arguments)]
22pub fn launch_kernel_concrete<R: Runtime, Args: ConcreteArgs<A>, A: Routine<RuntimeArgs>>(
23 client: &ComputeClient<R>,
24 out_grad: InputBinding<R>,
25 weights: InputBinding<R>,
26 in_grad: TensorBinding<R>,
27 problem: ConvolutionProblem,
28 vector_sizes: MatmulVectorSizes,
29 blueprint_strategy: &BlueprintStrategy<RuntimeArgs, A>,
30 dtypes: &MatmulElems,
31) -> Result<(), ConvSetupError> {
32 let mut view_vector_sizes = vector_sizes;
33
34 if let InputBinding::Quantized { scheme, .. } = out_grad {
35 view_vector_sizes.lhs *= scheme.num_quants();
36 }
37 if let InputBinding::Quantized { scheme, .. } = weights {
38 view_vector_sizes.rhs *= scheme.num_quants();
39 }
40
41 let device_settings = A::device_settings(client, view_vector_sizes);
42 let expand_info = A::expand_blueprint(
43 &problem.as_matmul_problem(),
44 &device_settings,
45 blueprint_strategy,
46 )?;
47
48 let problem = Args::adjust_problem(client, problem, &expand_info.blueprint, dtypes);
49 let launch_info = A::prepare(&problem.as_matmul_problem(), &device_settings, expand_info)?;
50
51 let (input, runtime_args) = <InputArg<Args> as ConcreteInputsFactory<A>>::create(
52 out_grad,
53 weights,
54 &launch_info.blueprint,
55 &problem,
56 dtypes,
57 );
58 let output = <OutputArg<Args> as ConcreteOutputFactory<A>>::create(
59 in_grad,
60 &launch_info.blueprint,
61 &problem,
62 );
63
64 let result = cubek_matmul::launch::launch_kernel::<Args, R, A>(
65 client,
66 input,
67 output,
68 runtime_args,
69 launch_info,
70 );
71
72 result.map_err(ConvSetupError::Matmul)
73}