// cubecl_convolution/components/global/entry_point.rs
use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_core::{Runtime, client::ComputeClient};
use cubecl_matmul::components::MatmulElems;
use cubecl_matmul::components::global::GlobalConfig;
use cubecl_matmul::components::stage::StageConfig as _;
use cubecl_matmul::components::{
    InputRuntimeArg, OutputRuntimeArg, batch::SliceIndex, global::args::MatmulArgs,
};
use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};

use crate::components::ConvGemmConfig;
use crate::{
    components::{
        ConvolutionProblem,
        global::{GlobalConvolution, GlobalConvolutionFamily},
    },
    kernels::layered::selector::RuntimeArgs,
};

21type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
22type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
23
24pub trait ConvolutionLaunch<Config> {
26 #[allow(clippy::too_many_arguments)]
32 unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
33 client: &ComputeClient<<R as Runtime>::Server>,
34 cube_dim: CubeDim,
35 cube_count: CubeCount,
36 input: InputRuntimeArg<'a, MA, R>,
37 output: OutputRuntimeArg<'a, MA, R>,
38 problem: &ConvolutionProblem,
39 config: Config,
40 dtypes: &MatmulElems,
41 );
42}
43
44#[cube(launch_unchecked)]
45pub(crate) fn implicit_conv<
46 Args: MatmulArgs,
47 LhsG: Numeric,
48 RhsG: Numeric,
49 AccG: Numeric,
50 LhsS: Numeric,
51 RhsS: Numeric,
52 AccS: Numeric,
53 GMM: GlobalConvolutionFamily,
54>(
55 inputs: &Input<Args, LhsG, RhsG, AccG>,
56 output: &mut Output<Args, AccG>,
57 runtime_args: RuntimeArgs,
58 #[comptime] config: GMM::Config,
59 #[define(LhsG)] _lhs_global: StorageType,
60 #[define(RhsG)] _rhs_global: StorageType,
61 #[define(AccG)] _acc_global: StorageType,
62 #[define(LhsS)] _lhs_stage: StorageType,
63 #[define(RhsS)] _rhs_stage: StorageType,
64 #[define(AccS)] _acc_stage: StorageType,
65) {
66 let mut state = Args::init_state::<
67 LhsG,
68 RhsG,
69 AccG,
70 <GMM::Config as ConvGemmConfig>::GlobalMatmulConfig,
71 >(inputs, output, config.matmul_config());
72
73 let lhs = Args::view_lhs(&state);
74 let rhs = Args::view_rhs(&state);
75 let bias = Args::view_acc(&state);
76 let out = Args::view_out(&mut state);
77
78 let stage_m = config
79 .matmul_config()
80 .stage_config()
81 .elements_in_stage_m()
82 .runtime();
83 let stage_n = config
84 .matmul_config()
85 .stage_config()
86 .elements_in_stage_n()
87 .runtime();
88
89 let m_offset = CUBE_POS_X * stage_m;
90 let n_offset = CUBE_POS_Y * stage_n;
91
92 let k_range = (0, runtime_args.shape_k);
93 let k_size = runtime_args.shape_k;
94
95 let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
96 let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
97 let bias = match bias {
98 CubeOption::Some(bias) => {
99 let view = bias.view(SliceIndex::new(0, bias.shape()));
100 CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
101 }
102 CubeOption::None => CubeOption::new_None(),
103 };
104 let out = out.view_mut(SliceIndex::new(0, out.shape()));
105
106 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute(
107 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader(
108 lhs,
109 (m_offset, k_range.0),
110 (stage_m, k_size),
111 &runtime_args,
112 config,
113 ),
114 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader(
115 rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
116 config,
117 ),
118 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader(
119 bias, config,
120 ),
121 GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer(
122 out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
123 config,
124 ),
125 &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config),
126 k_range,
127 config,
128 );
129}
130
131pub(crate) fn shape_divmod<'a, R: Runtime>(
132 client: &ComputeClient<R::Server>,
133 shape: &[usize],
134) -> SequenceArg<'a, R, FastDivmod> {
135 shape
136 .iter()
137 .map(|s| FastDivmodArgs::new(client, *s as u32))
138 .collect()
139}