use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_core::{Runtime, client::ComputeClient};
use cubecl_matmul::components::MatmulElems;
use cubecl_matmul::components::global::GlobalConfig;
use cubecl_matmul::components::stage::StageConfig as _;
use cubecl_matmul::components::{
    InputRuntimeArg, OutputRuntimeArg, batch::SliceIndex, global::args::MatmulArgs,
};
use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};

use crate::components::{ConvGemmConfig, global::args::RuntimeArgs};
use crate::components::{
    ConvolutionProblem,
    global::{GlobalConvolution, GlobalConvolutionFamily},
};

type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;

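/// Host-side launcher for an implicit GEMM convolution kernel.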
pub trait ConvolutionLaunch<Config> {
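    /// Launches the convolution kernel with the given launch arguments.
    ///
    /// # Safety
    ///
    /// No validation is performed on the launch arguments; the caller must ensure
    /// that `input`, `output`, `cube_count` and `cube_dim` are consistent with
    /// `problem` and `config`.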
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
        client: &ComputeClient<R>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MA, R>,
        output: OutputRuntimeArg<'a, MA, R>,
        problem: &ConvolutionProblem,
        config: Config,
        dtypes: &MatmulElems,
    ) -> Result<(), LaunchError>;
}

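/// Implicit GEMM convolution kernel: each cube computes one `(stage_m, stage_n)`
/// tile of the output, iterating over the full reduction (`k`) dimension.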
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    LhsG: Numeric,
    RhsG: Numeric,
    AccG: Numeric,
    LhsS: Numeric,
    RhsS: Numeric,
    AccS: Numeric,
    LhsR: Numeric,
    RhsR: Numeric,
    AccR: Numeric,
    GMM: GlobalConvolutionFamily,
>(
    inputs: &Input<Args, LhsG, RhsG, AccG>,
    output: &mut Output<Args, AccG>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
    #[define(LhsG, RhsG, AccG)] _global: [StorageType; 3],
    #[define(LhsS, RhsS, AccS)] _stage: [StorageType; 3],
    #[define(LhsR, RhsR, AccR)] _register: [StorageType; 3],
) {
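    // Set up the argument state from which the lhs/rhs/bias/output views are derived.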
    let mut state = Args::init_state::<
        LhsG,
        RhsG,
        AccG,
        <GMM::Config as ConvGemmConfig>::GlobalMatmulConfig,
    >(inputs, output, config.matmul_config());

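    // Global-memory views over both operands, the optional bias, and the output.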
    let lhs = Args::view_lhs(&state);
    let rhs = Args::view_rhs(&state);
    let bias = Args::view_acc(&state);
    let out = Args::view_out(&mut state);

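    // Tile sizes handled by a single stage, taken from the comptime stage configuration.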
    let stage_m = config
        .matmul_config()
        .stage_config()
        .elements_in_stage_m()
        .runtime();
    let stage_n = config
        .matmul_config()
        .stage_config()
        .elements_in_stage_n()
        .runtime();

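    // Each cube computes one (stage_m, stage_n) output tile: CUBE_POS_X selects the
    // row block and CUBE_POS_Y the column block.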
    let m_offset = CUBE_POS_X * stage_m;
    let n_offset = CUBE_POS_Y * stage_n;

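    // Every cube traverses the full reduction (k) dimension.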
    let k_range = (0, runtime_args.shape_k);
    let k_size = runtime_args.shape_k;

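    // View each tensor through a batch slice covering its full shape; the bias, when
    // present, is narrowed to the (1, stage_n) slice at this cube's column offset.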
    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
    let bias = match bias {
        CubeOption::Some(bias) => {
            let view = bias.view(SliceIndex::new(0, bias.shape()));
            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
        }
        CubeOption::None => CubeOption::new_None(),
    };
    let out = out.view_mut(SliceIndex::new(0, out.shape()));

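    // Build the global readers and writer for this cube's tile and run the matmul
    // loop with a freshly initialized accumulator.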
    GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::execute(
        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_lhs_global_reader(
            lhs,
            (m_offset, k_range.0),
            (stage_m, k_size),
            &runtime_args,
            config,
        ),
        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_rhs_global_reader(
            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
            &runtime_args,
            config,
        ),
        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_bias_global_reader(
            bias, config,
        ),
        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_global_writer(
            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
            config,
        ),
        &mut GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_accumulator(
            config,
        ),
        k_range,
        config,
    );
}

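/// Builds `FastDivmod` launch arguments for every dimension of `shape`, so the
/// kernel can perform cheap division/modulo by these sizes.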
pub(crate) fn shape_divmod<'a, R: Runtime>(
    client: &ComputeClient<R>,
    shape: &[usize],
) -> SequenceArg<'a, R, FastDivmod> {
    shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect()
}