use cubecl::prelude::*;
use cubecl::std::{
    CubeOptionArgs, FastDivmod, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubek_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulLineSizes, MatmulSelection,
        global::{
            GlobalConfig,
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
        stage::StageConfig as _,
    },
};

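/// Runtime scalar arguments for the implicit GEMM convolution kernel: the GEMM
/// problem shape (m, n, k), the channel count, and the output spatial shape.
/// The `FastDivmod` values let the kernel divide and take remainders by these
/// runtime sizes cheaply when decomposing linear indices.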
#[derive(CubeType, CubeLaunch, Clone)]
pub struct RuntimeArgs {
    pub shape_m: u32,
    pub shape_n: u32,
    pub shape_k: u32,
    pub channels: u32,
    pub padded_channels: FastDivmod,
    pub shape_out: Sequence<FastDivmod>,
}

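/// Builds the concrete launch arguments for the matmul inputs (lhs, rhs, and the
/// optional bias) from convolution tensor handles.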
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

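/// Builds the concrete launch arguments for the matmul output from the convolution
/// output tensor handle.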
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

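/// Plain (non-TMA) inputs: the lhs is viewed through an NHWC layout chained with an
/// im2col layout, and the rhs through an NHWC layout chained with a weight layout,
/// so the GEMM reads the convolution operands as ordinary matrices.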
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

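        // Vectorized loads are `load_width` bits wide; derive the alignment in
        // elements. If the channel count is not a multiple of it, the NHWC layout
        // must bounds-check the channel dimension.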
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();

        let layout_nhwc = |handle, line_size, check_spatial| {
            NhwcLayoutLaunch::from_handle(
                handle,
                line_size as u32,
                check_spatial,
                !problem.channels.is_multiple_of(channel_align),
            )
        };
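        // Virtual layouts mapping GEMM coordinates onto the convolution tensors:
        // (m, k) -> input pixels via im2col, and (k, n) -> weight elements.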
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.lhs_global_memory_config(),
            dtypes,
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.rhs_global_memory_config(),
            dtypes,
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

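        // Chain each virtual layout on top of its base NHWC layout: the im2col or
        // weight stage produces NHWC coordinates, which the base layout then
        // linearizes (with the bounds checks configured above).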
        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

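/// Output view: an NHWC layout chained with `OutLayout`, mapping GEMM (m, n)
/// coordinates back to output pixels.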
impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false, false);
        let layout = OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

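/// TMA inputs: the lhs is described with the hardware im2col tensor-map format and
/// the rhs with a tiled format, so the tensor memory accelerator performs the
/// im2col gather and the tiling during the asynchronous copies.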
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;

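        // Rhs tile shape for the tensor map: `[stage_n, 1, ..., 1, tile_size_k]`,
        // with one unit dimension per spatial kernel dimension.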
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims() as usize];
        stage_size_rhs.insert(0, stage_n);
        stage_size_rhs.push(tile_size_k);

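        // Tensor cores consume f32 operands as tf32, so declare the tensor map
        // with the element type the stage actually uses.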
        let lhs_elem = if *dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            *dtypes.lhs_stage
        };

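        // Element strides for the im2col descriptor: unit stride on the batch and
        // channel dimensions, the convolution stride on each spatial dimension.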
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

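        // The im2col format takes the pixel box (derived from padding, kernel size,
        // and dilation), how many channels each pixel contributes per k tile, and
        // how many pixels make up one stage column along m.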
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            *dtypes.rhs_global,
        );

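        // Round the channel count up to the k tile size so k iterations tile evenly.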
        let padded_channels = (problem.channels as u32)
            .next_multiple_of(config.matmul_config().stage_config().elements_in_tile_k());

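        // The lhs layout is a dummy since the im2col tensor map already handles the
        // addressing; the rhs layout decomposes k into (channel, kernel position)
        // using the padded channel count.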
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(
            FastDivmodArgs::new(client, padded_channels),
            problem.kernel_size.clone(),
        );

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}