cubek_convolution/kernels/backward_data/selector/
select_kernel.rs1use crate::{
2 backward_data::args::{ConcreteInputsFactory, ConcreteOutputFactory},
3 components::{ConvGemmConfig as _, global::args::RuntimeArgsLaunch},
4};
5use cubecl::prelude::TensorHandleRef;
6use cubecl::{Runtime, client::ComputeClient};
7use cubek_matmul::{
8 definition::{MatmulElems, MatmulLineSizes, TilingBlueprint},
9 launch::{
10 InputArg, InputRuntimeArg, MatmulArgs, MatmulInputHandleRef, OutputArg, OutputRuntimeArg,
11 },
12};
13
14use crate::{
15 components::{ConvSetupError, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
16 kernels::forward::algorithm::Algorithm,
17};
18
19#[allow(clippy::result_large_err, clippy::too_many_arguments)]
23pub fn launch_kernel_concrete<R: Runtime, A: Algorithm>(
24 client: &ComputeClient<R>,
25 out_grad: &MatmulInputHandleRef<'_, R>,
26 weights: &MatmulInputHandleRef<'_, R>,
27 in_grad: &TensorHandleRef<'_, R>,
28 problem: ConvolutionProblem,
29 line_sizes: MatmulLineSizes,
30 selection: TilingBlueprint,
31 dtypes: &MatmulElems,
32) -> Result<(), ConvSetupError>
33where
34 InputArg<A::Args>: ConcreteInputsFactory,
35 OutputArg<A::Args>: ConcreteOutputFactory,
36{
37 let config = A::expand_config(
38 client.properties(),
39 &problem,
40 &selection,
41 &line_sizes,
42 dtypes,
43 )?;
44
45 let (input, runtime_args) = <InputArg<A::Args> as ConcreteInputsFactory>::create(
46 client,
47 out_grad,
48 weights,
49 &selection,
50 &problem,
51 &line_sizes,
52 config,
53 dtypes,
54 );
55 let output = <OutputArg<A::Args> as ConcreteOutputFactory>::create(
56 client,
57 in_grad,
58 &selection,
59 &problem,
60 &line_sizes,
61 config,
62 );
63
64 let result = unsafe {
65 A::GlobalConvolution::launch_unchecked::<A::Args, R>(
66 client,
67 config.cube_dim(),
68 A::cube_count(&selection, &problem),
69 input,
70 output,
71 runtime_args,
72 config,
73 dtypes,
74 )
75 };
76
77 match result {
78 Ok(_) => Ok(()),
79 Err(err) => Err(ConvSetupError::Launch(err)),
80 }
81}
82
83#[allow(clippy::too_many_arguments)]
85pub fn launch_kernel_virtual<'a, MA: MatmulArgs, R: Runtime, A: Algorithm>(
86 client: &ComputeClient<R>,
87 input: InputRuntimeArg<'a, MA, R>,
88 output: OutputRuntimeArg<'a, MA, R>,
89 runtime_args: RuntimeArgsLaunch<'a, R>,
90 problem: ConvolutionProblem,
91 line_sizes: MatmulLineSizes,
92 selection: TilingBlueprint,
93 dtypes: &MatmulElems,
94) -> Result<(), ConvSetupError> {
95 let config = A::expand_config(
96 client.properties(),
97 &problem,
98 &selection,
99 &line_sizes,
100 dtypes,
101 )?;
102
103 let result = unsafe {
104 A::GlobalConvolution::launch_unchecked::<MA, R>(
105 client,
106 config.cube_dim(),
107 A::cube_count(&selection, &problem),
108 input,
109 output,
110 runtime_args,
111 config,
112 dtypes,
113 )
114 };
115
116 match result {
117 Ok(_) => Ok(()),
118 Err(err) => Err(ConvSetupError::Launch(err)),
119 }
120}