cubecl_convolution/components/global/entry_point.rs

use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_core::{Runtime, client::ComputeClient};
use cubecl_matmul::components::MatmulElems;
use cubecl_matmul::components::{
    InputRuntimeArg, OutputRuntimeArg,
    batch::SliceIndex,
    global::{GlobalConfig as _, args::MatmulArgs},
};
use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};

use crate::{
    components::{
        ConvolutionProblem,
        global::{GlobalConvolution, GlobalConvolutionFamily},
    },
    kernels::layered::selector::RuntimeArgs,
};

type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;

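/// Launch interface for global convolution kernels.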
pub trait ConvolutionLaunch<Config> {
    /// Launches the convolution kernel.
    ///
    /// # Safety
    ///
    /// The launch is unchecked: the caller must ensure that `input`, `output`, `problem`
    /// and `config` are consistent with each other.
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MA, R>,
        output: OutputRuntimeArg<'a, MA, R>,
        problem: &ConvolutionProblem,
        config: Config,
        dtypes: &MatmulElems,
    );
}

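/// Implicit GEMM convolution entry point: each cube computes one `(stage_m, stage_n)`
/// tile of the output, iterating over the full reduction dimension `k`.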
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    LhsG: Numeric,
    RhsG: Numeric,
    AccG: Numeric,
    LhsS: Numeric,
    RhsS: Numeric,
    AccS: Numeric,
    GMM: GlobalConvolutionFamily,
>(
    inputs: &Input<Args, LhsG, RhsG, AccG>,
    output: &mut Output<Args, AccG>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
    #[define(LhsG)] _lhs_global: StorageType,
    #[define(RhsG)] _rhs_global: StorageType,
    #[define(AccG)] _acc_global: StorageType,
    #[define(LhsS)] _lhs_stage: StorageType,
    #[define(RhsS)] _rhs_stage: StorageType,
    #[define(AccS)] _acc_stage: StorageType,
) {
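    // Build the argument state and obtain global views over lhs, rhs, the optional bias
    // and the output.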
    let mut state = Args::init_state::<LhsG, RhsG, AccG, GMM::Config>(inputs, output, config);

    let lhs = Args::view_lhs(&state);
    let rhs = Args::view_rhs(&state);
    let bias = Args::view_acc(&state);
    let out = Args::view_out(&mut state);

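    // Each cube handles one (stage_m, stage_n) tile of the output; the cube position
    // selects the tile offsets.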
    let stage_m = config.tiling_scheme().elements_in_stage_m().runtime();
    let stage_n = config.tiling_scheme().elements_in_stage_n().runtime();

    let m_offset = CUBE_POS_X * stage_m;
    let n_offset = CUBE_POS_Y * stage_n;

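    // Every cube traverses the full reduction (k) dimension.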
    let k_range = (0, runtime_args.shape_k);
    let k_size = runtime_args.shape_k;

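    // Wrap lhs/rhs as full views; the bias, when present, is sliced to this tile's columns.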
    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
    let bias = match bias {
        CubeOption::Some(bias) => {
            let view = bias.view(SliceIndex::new(0, bias.shape()));
            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
        }
        CubeOption::None => CubeOption::new_None(),
    };
    let out = out.view_mut(SliceIndex::new(0, out.shape()));

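    // Set up the global readers and writer for this tile, then run the global convolution.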
    GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute(
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader(
            lhs,
            (m_offset, k_range.0),
            (stage_m, k_size),
            &runtime_args,
            config,
        ),
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader(
            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
            config,
        ),
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader(
            bias, config,
        ),
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer(
            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
            config,
        ),
        &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config),
        k_range,
        config,
    );
}

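/// Converts a tensor shape into `FastDivmod` launch arguments, one per dimension.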
pub(crate) fn shape_divmod<'a, R: Runtime>(
    client: &ComputeClient<R::Server>,
    shape: &[usize],
) -> SequenceArg<'a, R, FastDivmod> {
    shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect()
}