use std::any::TypeId;

use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::{
    FastDivmodArgs,
    tensor::{
        View,
        launch::ViewArg,
        layout::{
            Coords3d,
            chain::{Chain, ChainLaunch},
        },
    },
};

use crate::{
    components::{
        ConvGemmConfig, ConvolutionProblem,
        global::{
            layout::{
                BiasLayout, BiasLayoutLaunch, Im2colLayout, Im2colLayoutLaunch, NhwcLayout,
                NhwcLayoutLaunch, OutLayout, OutLayoutLaunch, WeightLayout, WeightLayoutLaunch,
            },
            read::layout::{
                TmaDummyLayout, TmaDummyLayoutLaunch, TmaWeightLayout, TmaWeightLayoutLaunch,
            },
        },
    },
    kernels::layered::algorithm::simple_tma::{calculate_lower_corner, calculate_upper_corner},
};
use cubecl_matmul::{
    MatmulInputHandleRef,
    components::{
        MatmulIdent, MatmulLineSizes, MatmulSelection,
        global::args::{TensorInputs, TensorInputsLaunch, TensorMapInputs, TensorMapInputsLaunch},
    },
};

pub trait ConcreteInputsFactory: LaunchArg {
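    /// Build the runtime input arguments for a launch, wrapping the raw `lhs`, `rhs` and
    /// optional `bias` handles in views laid out for the given convolution problem.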
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}

pub trait ConcreteOutputFactory: LaunchArg {
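    /// Build the runtime output argument for a launch, wrapping the raw `out` handle in a
    /// view laid out for the given convolution problem.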
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}

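// Plain (non-TMA) inputs for the implicit GEMM path: rather than materializing the im2col
// matrix, each view chains a base NHWC layout over the global tensor with a virtual layout
// that maps matmul coordinates back onto convolution coordinates.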
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let layout_nhwc = |handle, line_size, check| {
            NhwcLayoutLaunch::from_handle(handle, line_size as u32, check)
        };
        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Lhs),
        );
        let layout_rhs = WeightLayoutLaunch::from_args(
            client,
            problem,
            config.convolution_params(),
            config.global_memory_config(MatmulIdent::Rhs),
        );
        let layout_bias =
            BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);

        let layout_lhs = {
            let global = layout_nhwc(lhs.data(), line_sizes.lhs, config.check_spatial_bounds());
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let global = layout_nhwc(rhs.data(), line_sizes.rhs, false);
            ChainLaunch::new(global, layout_rhs)
        };

        TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(lhs.data().as_array_arg(line_sizes.lhs), layout_lhs),
            ViewArg::new::<RhsLayout>(rhs.data().as_array_arg(line_sizes.rhs), layout_rhs),
            bias.map(|bias| {
                ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout_bias)
            })
            .into(),
        )
    }
}

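// The output view maps matmul output coordinates onto the NHWC output tensor through the
// same layout-chaining mechanism.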
impl<EG: Numeric> ConcreteOutputFactory for View<Line<EG>, Coords3d, ReadWrite> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out as u32, false);
        let layout = OutLayoutLaunch::from_args(
            client,
            problem,
            config.global_memory_config(MatmulIdent::Out),
        );
        let layout = ChainLaunch::new(global, layout);
        ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout)
    }
}

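// TMA inputs: lhs and rhs are described by hardware tensor maps rather than plain views,
// using the dedicated im2col format for lhs and a tiled format for rhs.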
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R::Server>,
        lhs: &'a MatmulInputHandleRef<'a, R>,
        rhs: &'a MatmulInputHandleRef<'a, R>,
        bias: Option<&'a TensorHandleRef<'a, R>>,
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_in_stage_m();
        let stage_n = tiling_scheme.elements_in_stage_n();
        let tile_size_k = tiling_scheme.elements_in_tile_k();
        let stage_size_rhs = vec![stage_n, 1, tile_size_k];

        let lhs_elem_size = size_of::<Lhs>();
        let rhs_elem_size = size_of::<Rhs>();

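        // Pick the TMA prefetch width from the number of contiguous bytes per loaded row:
        // e.g. a k-tile of 32 f16 elements spans 64 bytes, which selects
        // `TensorMapPrefetch::B64`.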
        fn prefetch(bytes: usize) -> TensorMapPrefetch {
            match bytes {
                ..64 => TensorMapPrefetch::None,
                64..128 => TensorMapPrefetch::B64,
                128..256 => TensorMapPrefetch::B128,
                256.. => TensorMapPrefetch::B256,
            }
        }

        let prefetch_lhs = prefetch(tile_size_k as usize * lhs_elem_size);
        let prefetch_rhs = prefetch(stage_size_rhs[2] as usize * rhs_elem_size);

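        // `f32` is declared as `tf32` in the tensor map descriptor (same 32-bit storage);
        // this only changes how TMA interprets the data when loading for the tensor cores.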
        let lhs_elem = if TypeId::of::<Lhs>() == TypeId::of::<f32>() {
            tf32::as_type_native_unchecked()
        } else {
            Lhs::as_type_native_unchecked()
        };

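        // Element strides bake the convolution stride into the spatial dimensions so the
        // im2col traversal steps over input pixels accordingly; batch and channel strides
        // stay at 1.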
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

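        // The im2col tensor map gathers a box of input pixels per load: the box corners are
        // derived from padding, kernel size and dilation, with `tile_size_k` channels per
        // pixel and `stage_m` pixels per column.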
        let lhs = TensorMapArg::new(
            TensorMapFormat::Im2col {
                pixel_box_lower_corner: calculate_lower_corner(&problem.padding),
                pixel_box_upper_corner: calculate_upper_corner(
                    &problem.padding,
                    &problem.kernel_size,
                    &problem.dilation,
                ),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            lhs.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride)
        .with_prefetch(prefetch_lhs);

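        // The weight tensor map is plain tiled: each box covers one rhs stage of
        // `stage_n x 1 x tile_size_k` elements, loaded without vectorization (line size 1).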
        let rhs = TensorMapArg::new(
            TensorMapFormat::Tiled {
                tile_size: stage_size_rhs,
            },
            rhs.data().as_tensor_arg(1),
            Rhs::as_type_native_unchecked(),
        )
        .with_prefetch(prefetch_rhs);

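        // The k decomposition below uses the channel count padded up to a whole number of
        // k-tiles.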
        let padded_channels =
            (problem.channels as u32).next_multiple_of(config.tiling_scheme().elements_in_tile_k());

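        // With TMA, the im2col mapping is handled by the hardware descriptor itself, so the
        // lhs view only needs a dummy layout. The weight layout splits the linearized k
        // index back into kernel and channel coordinates via the padded channel count.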
        let lhs_layout = TmaDummyLayoutLaunch::new();
        let rhs_layout = TmaWeightLayoutLaunch::new(FastDivmodArgs::new(client, padded_channels));

        let bias = bias.map(|bias| {
            let layout =
                BiasLayoutLaunch::new(ScalarArg::new(problem.n as u32), line_sizes.out as u32);
            ViewArg::new::<BiasLayout>(bias.as_array_arg(line_sizes.out), layout)
        });

        TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map::<TmaDummyLayout>(lhs, lhs_layout),
            ViewArg::new_tensor_map::<TmaWeightLayout>(rhs, rhs_layout),
            bias.into(),
        )
    }
}