cubek_convolution/kernels/forward/selector/
select_kernel.rs1use crate::{
2 components::global::args::RuntimeArgs,
3 forward::args::{ConcreteArgs, ConcreteInputsFactory, ConcreteOutputFactory},
4};
5use cubecl::{
6 prelude::TensorBinding,
7 {Runtime, client::ComputeClient},
8};
9use cubek_matmul::{
10 definition::{MatmulElems, MatmulVectorSizes},
11 routines::Routine,
12};
13use cubek_matmul::{
14 launch::{InputArg, OutputArg},
15 routines::BlueprintStrategy,
16};
17use cubek_std::InputBinding;
18
19use crate::components::{ConvSetupError, ConvolutionProblem};
20
21#[allow(clippy::result_large_err, clippy::too_many_arguments)]
25pub fn launch_kernel_concrete<R: Runtime, Args: ConcreteArgs<A>, A: Routine<RuntimeArgs>>(
26 client: &ComputeClient<R>,
27 input: InputBinding<R>,
28 weight: InputBinding<R>,
29 bias: Option<InputBinding<R>>,
30 out: TensorBinding<R>,
31 problem: ConvolutionProblem,
32 vector_sizes: MatmulVectorSizes,
33 blueprint_strategy: &BlueprintStrategy<Args::Config, A>,
34 dtypes: &MatmulElems,
35) -> Result<(), ConvSetupError> {
36 let mut view_vector_sizes = vector_sizes;
37
38 if let InputBinding::Quantized { scheme, .. } = input {
39 view_vector_sizes.lhs *= scheme.num_quants();
40 }
41 if let InputBinding::Quantized { scheme, .. } = weight {
42 view_vector_sizes.rhs *= scheme.num_quants();
43 }
44
45 let device_settings = A::device_settings(client, view_vector_sizes);
46 let expand_info = A::expand_blueprint(
47 &problem.as_matmul_problem(),
48 &device_settings,
49 blueprint_strategy,
50 )?;
51
52 let problem = Args::adjust_problem(client, problem, &expand_info.blueprint, dtypes);
53 let launch_info = A::prepare(&problem.as_matmul_problem(), &device_settings, expand_info)?;
54
55 let (input, runtime_args) = <InputArg<Args> as ConcreteInputsFactory<A>>::create(
56 input,
57 weight,
58 bias,
59 &launch_info.blueprint,
60 &problem,
61 dtypes,
62 );
63 let output = <OutputArg<Args> as ConcreteOutputFactory<A>>::create(
64 out,
65 &launch_info.blueprint,
66 &problem,
67 dtypes,
68 );
69
70 cubek_matmul::launch::launch_kernel::<Args, R, A>(
71 client,
72 input,
73 output,
74 runtime_args,
75 launch_info,
76 )
77 .map_err(ConvSetupError::Matmul)
78}