use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    CubeOptionArgs, FastDivmodArgs,
    tensor::{
        launch::ViewArg,
        layout::{
            VirtualLayoutLaunch,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulElems, MatmulIdent, MatmulLineSizes, MatmulSelection,
        global::{
            args::{
                TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch,
                TensorOutput, TensorOutputLaunch,
            },
            memory::{NoopLayout, NoopLayoutLaunch},
        },
    },
};

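/// Factory trait for building the concrete launch arguments of a convolution's
/// input tensors (lhs, rhs/weights, and optional bias) from their handles.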
pub trait ConcreteInputsFactory: LaunchArg {
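    /// Builds the runtime argument used to launch the kernel with these inputs.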
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

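/// Factory trait for building the concrete launch argument of a convolution's
/// output tensor from its handle.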
pub trait ConcreteOutputFactory: LaunchArg {
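    /// Builds the runtime argument used to launch the kernel with this output.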
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R>;
}

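// Plain (non-TMA) inputs: the raw NHWC tensors are exposed through virtual
// layouts that perform the im2col / weight-matrix indexing on the fly.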
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
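        // Each operand view chains a raw NHWC layout (bounds-checked where
        // needed) with the virtual layout that presents it as a GEMM matrix.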
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let layout_nhwc = |handle, line_size, check| {
            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Lhs),
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Rhs),
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

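        // The companion layouts (cf. `batch` in the output factory below) are
        // no-ops here; no remapping is needed beyond the views themselves.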
        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
            bias.map(|_| VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()))
                .into(),
        )
    }
}

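// Plain tensor output: an NHWC view remapped to the GEMM output matrix.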
impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
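        // Mirror the input views: chain the raw NHWC layout with the layout
        // that maps GEMM output coordinates back to the output tensor.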
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
        let layout = OutLayoutLaunch::from_args(
            client,
            problem,
            config.global_memory_config(MatmulIdent::Out),
        );
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

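// TMA inputs: lhs is loaded through the hardware `Im2col` tensor map format,
// rhs through a `Tiled` tensor map over the weight tensor.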
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let tile_size_k = tiling_scheme.elements_in_tile_k();
        let stage_size_rhs = vec![stage_n, 1, tile_size_k];

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

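        // Choose the TMA prefetch width from how many contiguous bytes the
        // innermost (k) dimension of each tile spans.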
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

        let prefetch_lhs = prefetch(tile_size_k as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
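            // An `f32` stage is declared as `tf32` in the TMA descriptor; the
            // two share the same 4-byte storage layout.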
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

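        // Element strides for the im2col tensor map: unit stride on the outer
        // (batch) and inner (channel) dimensions, convolution strides on the
        // spatial dimensions in between.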
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

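        // The lhs tensor map uses the hardware im2col format: the pixel box
        // corners encode padding and dilation, with `tile_size_k` channels
        // loaded per pixel and `stage_m` pixels per column.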
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_prefetch(prefetch_lhs);

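        // The rhs tensor map is a plain tiled layout over the weight tensor,
        // loading `stage_n x 1 x tile_size_k` elements per tile.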
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            dtypes.rhs_global,
        )
        .with_prefetch(prefetch_rhs);

        let padded_channels =
            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());

        let lhs_layout = TmaDummyLayoutLaunch::new();
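        // The weight layout carries a fast divmod by the channel count, padded
        // up to a whole k-tile so the division aligns with stage tiles.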
        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
            CubeOptionArgs::Some(VirtualLayoutLaunch::new::<NoopLayout>(
                NoopLayoutLaunch::new(),
            )),
        )
    }
}