cubecl_convolution/kernels/layered/selector/
select_kernel.rs1use cubecl_core::prelude::TensorHandleRef;
2use cubecl_core::{Runtime, client::ComputeClient};
3use cubecl_matmul::{
4 MatmulInputHandleRef,
5 components::{
6 InputArg, InputRuntimeArg, MatmulLineSizes, MatmulSelection, MatmulSpec, OutputArg,
7 OutputRuntimeArg, global::GlobalConfig as _,
8 },
9};
10
11use crate::{
12 components::{
13 ConvSetupError, ConvolutionProblem,
14 global::{
15 args::{ConcreteInputsFactory, ConcreteOutputFactory},
16 entry_point::ConvolutionLaunch,
17 },
18 },
19 kernels::layered::algorithm::Algorithm,
20};
21
22#[allow(clippy::result_large_err, clippy::too_many_arguments)]
26pub fn launch_kernel_concrete<MS: MatmulSpec, R: Runtime, A: Algorithm>(
27 client: &ComputeClient<R::Server>,
28 input: &MatmulInputHandleRef<'_, R>,
29 weight: &MatmulInputHandleRef<'_, R>,
30 bias: &Option<TensorHandleRef<'_, R>>,
31 out: &TensorHandleRef<'_, R>,
32 problem: ConvolutionProblem,
33 line_sizes: MatmulLineSizes,
34 selection: MatmulSelection,
35) -> Result<(), ConvSetupError>
36where
37 InputArg<MS>: ConcreteInputsFactory,
38 OutputArg<MS>: ConcreteOutputFactory,
39{
40 let config = A::setup::<R, MS::Precision>(client, &problem, &selection, &line_sizes)?;
41
42 let input = <InputArg<MS> as ConcreteInputsFactory>::create(
43 client,
44 input,
45 weight,
46 bias.as_ref(),
47 &selection,
48 &problem,
49 &line_sizes,
50 config,
51 );
52 let output = <OutputArg<MS> as ConcreteOutputFactory>::create(
53 client,
54 out,
55 &selection,
56 &problem,
57 &line_sizes,
58 config,
59 );
60
61 unsafe {
62 A::GlobalConvolution::launch_unchecked::<MS, R>(
63 client,
64 config.cube_dim(),
65 A::cube_count(&selection, &problem),
66 input,
67 output,
68 &problem,
69 config,
70 );
71 }
72
73 Ok(())
74}
75
76pub fn launch_kernel_virtual<'a, MS: MatmulSpec, R: Runtime, A: Algorithm>(
78 client: &ComputeClient<R::Server>,
79 input: InputRuntimeArg<'a, MS, R>,
80 output: OutputRuntimeArg<'a, MS, R>,
81 problem: ConvolutionProblem,
82 line_sizes: MatmulLineSizes,
83 selection: MatmulSelection,
84) -> Result<(), ConvSetupError> {
85 let config = A::setup::<R, MS::Precision>(client, &problem, &selection, &line_sizes)?;
86
87 unsafe {
88 A::GlobalConvolution::launch_unchecked::<MS, R>(
89 client,
90 config.cube_dim(),
91 A::cube_count(&selection, &problem),
92 input,
93 output,
94 &problem,
95 config,
96 );
97 }
98
99 Ok(())
100}