cubecl_convolution/components/global/
entry_point.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_core::{Runtime, client::ComputeClient};
4use cubecl_matmul::components::{
5 InputRuntimeArg, MatmulSpec, OutputRuntimeArg,
6 batch::SliceIndex,
7 global::{GlobalConfig as _, args::MatmulArgs},
8};
9use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};
10
11use crate::{
12 components::{
13 ConvolutionProblem,
14 global::{GlobalConvolution, GlobalConvolutionFamily},
15 },
16 kernels::layered::selector::RuntimeArgs,
17};
18
19type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
20type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
21
22pub trait ConvolutionLaunch<Config> {
24 #[allow(clippy::too_many_arguments)]
30 unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
31 client: &ComputeClient<<R as Runtime>::Server>,
32 cube_dim: CubeDim,
33 cube_count: CubeCount,
34 input: InputRuntimeArg<'a, MS, R>,
35 output: OutputRuntimeArg<'a, MS, R>,
36 problem: &ConvolutionProblem,
37 config: Config,
38 );
39}
40
41#[cube(launch_unchecked)]
42pub(crate) fn implicit_conv<
43 Args: MatmulArgs,
44 LhsG: Numeric,
45 RhsG: Numeric,
46 AccG: Numeric,
47 LhsS: Numeric,
48 RhsS: Numeric,
49 AccS: Numeric,
50 GMM: GlobalConvolutionFamily,
51>(
52 inputs: &Input<Args, LhsG, RhsG, AccG>,
53 output: &mut Output<Args, AccG>,
54 runtime_args: RuntimeArgs,
55 #[comptime] config: GMM::Config,
56) {
57 let mut state = Args::init_state::<LhsG, RhsG, AccG, GMM::Config>(inputs, output, config);
58
59 let lhs = Args::view_lhs(&state);
60 let rhs = Args::view_rhs(&state);
61 let bias = Args::view_acc(&state);
62 let out = Args::view_out(&mut state);
63
64 let stage_m = config.tiling_scheme().elements_in_stage_m().runtime();
65 let stage_n = config.tiling_scheme().elements_in_stage_n().runtime();
66
67 let m_offset = CUBE_POS_X * stage_m;
68 let n_offset = CUBE_POS_Y * stage_n;
69
70 let k_range = (0, runtime_args.shape_k);
71 let k_size = runtime_args.shape_k;
72
73 let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
74 let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
75 let bias = match bias {
76 CubeOption::Some(bias) => {
77 let view = bias.view(SliceIndex::new(0, bias.shape()));
78 CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
79 }
80 CubeOption::None => CubeOption::new_None(),
81 };
82 let out = out.view_mut(SliceIndex::new(0, out.shape()));
83
84 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute(
85 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader(
86 lhs,
87 (m_offset, k_range.0),
88 (stage_m, k_size),
89 &runtime_args,
90 config,
91 ),
92 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader(
93 rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
94 config,
95 ),
96 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader(
97 bias, config,
98 ),
99 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer(
100 out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
101 config,
102 ),
103 &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config),
104 k_range,
105 config,
106 );
107}
108
109pub(crate) fn shape_divmod<'a, R: Runtime>(
110 client: &ComputeClient<R::Server>,
111 shape: &[usize],
112) -> SequenceArg<'a, R, FastDivmod> {
113 shape
114 .iter()
115 .map(|s| FastDivmodArgs::new(client, *s as u32))
116 .collect()
117}