use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    CubeOptionArgs, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulLineSizes, MatmulSelection,
        global::{
            GlobalConfig,
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
        stage::StageConfig as _,
    },
};

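/// Builds the concrete runtime input arguments for a convolution launch from the lhs/rhs
/// tensor handles, an optional bias handle, and the problem, selection and config descriptions.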
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

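/// Builds the concrete runtime output argument for a convolution launch from the output
/// tensor handle and the problem, selection and config descriptions.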
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
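        // Both operands are read through chained layouts: a base NHWC layout over the raw
        // handle, composed with the im2col view (lhs) or the weight view (rhs) of the
        // implicit GEMM.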
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let layout_nhwc = |handle, line_size, check| {
            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.lhs_global_memory_config(),
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.rhs_global_memory_config(),
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

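        // Assemble the launch arguments: lhs/rhs views with their chained layouts plus an
        // optional bias view; the `NoopLayout` arguments are identity placeholders.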
        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
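        // The output is written through an NHWC layout chained with the convolution output
        // layout; no extra batch remapping is needed, hence the `NoopLayout`.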
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
        let layout =
            OutLayoutLaunch::from_args(client, problem, config.out_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
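        // TMA path: operands are described as tensor maps (im2col format for the lhs, tiled
        // format for the weights) rather than plain array views.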
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let tile_size_k = tiling_scheme.tile_size.k;
        let stage_size_rhs = vec![stage_n, 1, tile_size_k];

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

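        // Pick the TMA prefetch width from the number of contiguous bytes loaded along the
        // innermost (k) dimension.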
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

        let prefetch_lhs = prefetch(tile_size_k as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

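        // f32 stage elements are loaded as tf32 by the lhs tensor map.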
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

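        // Per-dimension element strides for the im2col tensor map: 1 for batch and channels,
        // the convolution stride along each spatial dimension.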
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

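        // The lhs tensor map uses the im2col format; the pixel box corners are derived from
        // the padding, kernel size and dilation.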
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_prefetch(prefetch_lhs);

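        // The weight tensor map uses a plain tiled format, sized to one stage along n and one
        // tile along k.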
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            dtypes.rhs_global,
        )
        .with_prefetch(prefetch_rhs);

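        // The weight layout divides by the channel count padded up to a multiple of the tile
        // size along k.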
        let padded_channels = (problem.channels as u32)
            .next_multiple_of(config.matmul_config().stage_config().elements_in_tile_k());

        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}