cubecl_convolution/kernels/layered/selector/
select_kernel.rs1use cubecl_core::prelude::TensorHandleRef;
2use cubecl_core::{Runtime, client::ComputeClient};
3use cubecl_matmul::components::MatmulElems;
4use cubecl_matmul::components::global::args::MatmulArgs;
5use cubecl_matmul::{
6 MatmulInputHandleRef,
7 components::{
8 InputArg, InputRuntimeArg, MatmulLineSizes, MatmulSelection, OutputArg, OutputRuntimeArg,
9 global::GlobalConfig as _,
10 },
11};
12
13use crate::{
14 components::{
15 ConvSetupError, ConvolutionProblem,
16 global::{
17 args::{ConcreteInputsFactory, ConcreteOutputFactory},
18 entry_point::ConvolutionLaunch,
19 },
20 },
21 kernels::layered::algorithm::Algorithm,
22};
23
24#[allow(clippy::result_large_err, clippy::too_many_arguments)]
28pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
29 client: &ComputeClient<R::Server>,
30 input: &MatmulInputHandleRef<'_, R>,
31 weight: &MatmulInputHandleRef<'_, R>,
32 bias: &Option<TensorHandleRef<'_, R>>,
33 out: &TensorHandleRef<'_, R>,
34 problem: ConvolutionProblem,
35 line_sizes: MatmulLineSizes,
36 selection: MatmulSelection,
37 dtypes: &MatmulElems,
38) -> Result<(), ConvSetupError>
39where
40 InputArg<A::Args>: ConcreteInputsFactory,
41 OutputArg<A::Args>: ConcreteOutputFactory,
42{
43 let config = A::setup::<R>(client, &problem, &selection, &line_sizes, dtypes)?;
44
45 let input = <InputArg<A::Args> as ConcreteInputsFactory>::create(
46 client,
47 input,
48 weight,
49 bias.as_ref(),
50 &selection,
51 &problem,
52 &line_sizes,
53 config,
54 dtypes,
55 );
56 let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
57 client,
58 out,
59 &selection,
60 &problem,
61 &line_sizes,
62 config,
63 dtypes,
64 );
65
66 unsafe {
67 A::GlobalConvolution::launch_unchecked::<A::Args, R>(
68 client,
69 config.cube_dim(),
70 A::cube_count(&selection, &problem),
71 input,
72 output,
73 &problem,
74 config,
75 dtypes,
76 );
77 }
78
79 Ok(())
80}
81
82pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
84 client: &ComputeClient<R::Server>,
85 input: InputRuntimeArg<'a, MA, R>,
86 output: OutputRuntimeArg<'a, MA, R>,
87 problem: ConvolutionProblem,
88 line_sizes: MatmulLineSizes,
89 selection: MatmulSelection,
90 dtypes: &MatmulElems,
91) -> Result<(), ConvSetupError> {
92 let config = A::setup::<R>(client, &problem, &selection, &line_sizes, dtypes)?;
93
94 unsafe {
95 A::GlobalConvolution::launch_unchecked::<MA, R>(
96 client,
97 config.cube_dim(),
98 A::cube_count(&selection, &problem),
99 input,
100 output,
101 &problem,
102 config,
103 dtypes,
104 );
105 }
106
107 Ok(())
108}