cubek_convolution/kernels/forward/selector/
select_kernel.rs1use crate::{
2 components::{ConvGemmConfig as _, global::args::RuntimeArgsLaunch},
3 forward::args::{ConcreteInputsFactory, ConcreteOutputFactory},
4};
5use cubecl::prelude::TensorHandleRef;
6use cubecl::{Runtime, client::ComputeClient};
7use cubek_matmul::definition::{MatmulElems, MatmulLineSizes, TilingBlueprint};
8use cubek_matmul::launch::{
9 InputArg, InputRuntimeArg, MatmulArgs, MatmulInputHandleRef, OutputArg, OutputRuntimeArg,
10};
11
12use crate::{
13 components::{ConvSetupError, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
14 kernels::forward::algorithm::Algorithm,
15};
16
17#[allow(clippy::result_large_err, clippy::too_many_arguments)]
21pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
22 client: &ComputeClient<R>,
23 input: &MatmulInputHandleRef<'_, R>,
24 weight: &MatmulInputHandleRef<'_, R>,
25 bias: &Option<MatmulInputHandleRef<'_, R>>,
26 out: &TensorHandleRef<'_, R>,
27 problem: ConvolutionProblem,
28 line_sizes: MatmulLineSizes,
29 selection: TilingBlueprint,
30 dtypes: &MatmulElems,
31) -> Result<(), ConvSetupError>
32where
33 InputArg<A::Args>: ConcreteInputsFactory,
34 OutputArg<A::Args>: ConcreteOutputFactory,
35{
36 let config = A::expand_config(
37 client.properties(),
38 &problem,
39 &selection,
40 &line_sizes,
41 dtypes,
42 )?;
43
44 let (input, runtime_args) = <InputArg<A::Args> as ConcreteInputsFactory>::create(
45 client,
46 input,
47 weight,
48 bias.as_ref(),
49 &selection,
50 &problem,
51 &line_sizes,
52 config,
53 dtypes,
54 );
55 let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
56 client,
57 out,
58 &selection,
59 &problem,
60 &line_sizes,
61 config,
62 dtypes,
63 );
64
65 let result = unsafe {
66 A::GlobalConvolution::launch_unchecked::<A::Args, R>(
67 client,
68 config.cube_dim(),
69 A::cube_count(&selection, &problem),
70 input,
71 output,
72 runtime_args,
73 config,
74 dtypes,
75 )
76 };
77
78 match result {
79 Ok(_) => Ok(()),
80 Err(err) => Err(ConvSetupError::Launch(err)),
81 }
82}
83
84#[allow(clippy::too_many_arguments)]
86pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
87 client: &ComputeClient<R>,
88 input: InputRuntimeArg<'a, MA, R>,
89 output: OutputRuntimeArg<'a, MA, R>,
90 runtime_args: RuntimeArgsLaunch<'a, R>,
91 problem: ConvolutionProblem,
92 line_sizes: MatmulLineSizes,
93 selection: TilingBlueprint,
94 dtypes: &MatmulElems,
95) -> Result<(), ConvSetupError> {
96 let config = A::expand_config(
97 client.properties(),
98 &problem,
99 &selection,
100 &line_sizes,
101 dtypes,
102 )?;
103
104 let result = unsafe {
105 A::GlobalConvolution::launch_unchecked::<MA, R>(
106 client,
107 config.cube_dim(),
108 A::cube_count(&selection, &problem),
109 input,
110 output,
111 runtime_args,
112 config,
113 dtypes,
114 )
115 };
116
117 match result {
118 Ok(_) => Ok(()),
119 Err(err) => Err(ConvSetupError::Launch(err)),
120 }
121}