use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_std::{
    CubeOptionArgs, FastDivmod, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulLineSizes, MatmulSelection,
        global::{
            GlobalConfig,
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
        stage::StageConfig as _,
    },
};

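/// Runtime-resolved scalar arguments for the convolution kernel: the implicit
/// GEMM shape (`m`, `n`, `k`), the raw and padded channel counts, and the
/// output spatial shape, with `FastDivmod` precomputed for cheap on-device
/// division.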
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    pub shape_m: u32,
    pub shape_n: u32,
    pub shape_k: u32,
    pub channels: u32,
    pub padded_channels: FastDivmod,
    pub shape_out: Sequence<FastDivmod>,
}

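/// Builds the concrete input launch arguments for a convolution kernel from
/// raw tensor handles.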
pub trait ConcreteInputsFactory: LaunchArg {
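    /// Assembles the runtime argument from the `lhs` (activation), `rhs`
    /// (weight) and optional `bias` handles for the given problem and config.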
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

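/// Builds the concrete output launch argument for a convolution kernel from a
/// raw tensor handle.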
pub trait ConcreteOutputFactory: LaunchArg {
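    /// Assembles the runtime argument for the output tensor `out`.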
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

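/// Plain global-memory inputs: both operands are exposed as matrix views over
/// the raw buffers through virtual layouts, without TMA.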
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
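        // Each input is read through a layout chain: `NhwcLayout` linearizes the
        // raw NHWC tensor, then a virtual matrix layout maps GEMM coordinates
        // onto it (im2col for the activations, a flattened filter matrix for the
        // weights).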
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

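        // Channel alignment (in elements) needed for full-width loads; if the
        // channel count isn't a multiple of it, the NHWC layout is flagged as
        // unaligned below.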
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();

        let layout_nhwc = |handle, line_size, check_spatial| {
            NhwcLayoutLaunch::from_handle(
                handle,
                line_size as u32,
                check_spatial,
                !problem.channels.is_multiple_of(channel_align),
            )
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.lhs_global_memory_config(),
            dtypes,
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.rhs_global_memory_config(),
            dtypes,
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

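        // The `NoopLayout` arguments fill the batch-layout slots: im2col already
        // folds the batch dimension into GEMM `m`, so no extra remapping is
        // needed.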
        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

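/// Plain global-memory output: the NHWC output buffer is viewed as a matrix
/// through `NhwcLayout` chained with `OutLayout`.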
impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false, false);
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

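/// TMA inputs: activations are loaded through a hardware im2col tensor map and
/// weights through a tiled tensor map, so the gather indexing is done by the
/// TMA engine instead of in-kernel address math.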
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;

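        // Weight tile shape: `stage_n` along the leading (output-channel) dim,
        // one element per kernel spatial dim, `tile_size_k` along the trailing
        // channel dim.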
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims() as usize];
        stage_size_rhs.insert(0, stage_n);
        stage_size_rhs.push(tile_size_k);

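        // An `f32` stage is declared as `tf32` for the tensor map; both are
        // 32 bits wide, and the accelerated matmul consumes `tf32` fragments.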
        let lhs_elem = if *dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            *dtypes.lhs_stage
        };

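        // Element strides for im2col: batch and channel advance one element at a
        // time, while each spatial dim advances by the convolution stride.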
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

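        // Hardware im2col descriptor: the pixel box corners come from padding,
        // kernel size and dilation; each column gathers `stage_m` pixels with
        // `tile_size_k` channels per pixel.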
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

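        // Weights need no gather, so a plain tiled tensor map with the stage
        // shape above (and line size 1) suffices.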
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            *dtypes.rhs_global,
        );

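        // Round channels up to a whole number of k-tiles so K decomposes evenly.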
        let padded_channels = (problem.channels as u32)
            .next_multiple_of(config.matmul_config().stage_config().elements_in_tile_k());

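        // The lhs layout is a placeholder (TMA im2col resolves addresses in
        // hardware); the weight layout uses a fast divmod by the padded channel
        // count to split K into kernel-position and channel components.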
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(
            FastDivmodArgs::new(client, padded_channels),
            problem.kernel_size.clone(),
        );

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

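        // As in the non-TMA path, the batch-layout slot is a no-op.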
        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}