cubek_convolution/kernels/backward_data/
args.rs

use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{
        CubeOptionArgs, FastDivmodArgs,
        tensor::{
            launch::ViewArg,
            layout::{
                VirtualLayoutLaunch,
                chain::{Chain, ChainLaunch},
            },
        },
    },
};
use cubek_matmul::{
    components::{
        global::{
            GlobalConfig as _,
            memory::{GlobalMemoryConfig, NoopLayout, NoopLayoutLaunch, ViewDirection},
        },
        stage::StageConfig as _,
    },
    definition::{MatmulElems, MatmulLineSizes, MatrixLayout, TilingBlueprint},
    launch::{
        MatmulArgs, MatmulInputHandleRef, TensorArgs, TensorInputs, TensorInputsLaunch,
        TensorMapArgs, TensorMapInputs, TensorMapInputsLaunch, TensorOutput, TensorOutputLaunch,
    },
};
use enumset::EnumSet;

use crate::components::{
    ConvGemmConfig, ConvolutionParams, ConvolutionProblem,
    global::{
        args::RuntimeArgsLaunch,
        layout::{
            Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
            OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
            WeightLayoutLaunch,
        },
    },
};

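/// Argument family whose inputs and output can be built from concrete tensor handles, and which
/// can adjust the convolution problem to its own alignment requirements before launch.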
pub trait ConcreteArgs:
    MatmulArgs<
        Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>: ConcreteInputsFactory,
        Output<NumericExpand<2>>: ConcreteOutputFactory,
    >
{
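    /// Adjust the problem shape to satisfy this argument family's alignment constraints, e.g.
    /// by padding the channel count and recomputing the reduction dim `k`.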
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}

impl ConcreteArgs for TensorArgs {
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        _selection: &TilingBlueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
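        // The hardware load width (in bits) divided by the element size gives the channel
        // alignment in elements; pad channels to that alignment and recompute the reduction dim.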
        let load_width = client.properties().hardware.load_width;
        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

impl ConcreteArgs for TensorMapArgs {
    fn adjust_problem<R: Runtime>(
        _client: &ComputeClient<R>,
        mut problem: ConvolutionProblem,
        selection: &TilingBlueprint,
        _dtypes: &MatmulElems,
    ) -> ConvolutionProblem {
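        // The TMA im2col descriptor loads `tile_size.k` channels per pixel, so align the channel
        // dim to the k tile size rather than the hardware load width.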
        let channel_align = selection.tiling_scheme.tile_size.k() as usize;
        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;

        problem.k = shape_k;
        problem.padded_channels = padded_channels;

        problem
    }
}

/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory: LaunchArg {
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>);
}

/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory: LaunchArg {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R>;
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorInputs<Lhs, Rhs, EO> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        _dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
        type RhsLayout = Chain<NhwcLayout, WeightLayout>;

        let padded_channels = problem.padded_channels as u32;

        let layout_nhwc =
            |handle, line_size, checks| NhwcLayoutLaunch::from_handle(handle, line_size, checks);

        let layout_lhs = Im2colLayoutLaunch::from_args(
            client,
            problem,
            config.params(),
            config.lhs_global_memory_config(),
        );
        let layout_rhs =
            WeightLayoutLaunch::from_args(client, problem, config.rhs_global_memory_config());

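        // Chain each gemm-facing layout behind a base NHWC layout, enabling only the bounds
        // checks this problem actually requires.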
        let layout_lhs = {
            let mut checks = EnumSet::empty();
            if problem.should_check_spatial_bounds() {
                checks.insert(NhwcCheck::Spatial);
            }
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Channel);
            }
            let global = layout_nhwc(out_grad.data(), line_sizes.lhs, checks);
            ChainLaunch::new(global, layout_lhs)
        };
        let layout_rhs = {
            let mut checks = EnumSet::empty();
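            // For backward data the weight tensor's leading (batch) dim holds `out_channels`,
            // so padding the reduction channels maps to a batch-bounds check on the weights.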
            if problem.should_check_channel() {
                checks.insert(NhwcCheck::Batch);
            }
            let global = layout_nhwc(weights.data(), line_sizes.rhs, checks);
            ChainLaunch::new(global, layout_rhs)
        };

        let inputs = TensorInputsLaunch::new(
            ViewArg::new::<LhsLayout>(out_grad.data().as_array_arg(line_sizes.lhs), layout_lhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ViewArg::new::<RhsLayout>(weights.data().as_array_arg(line_sizes.rhs), layout_rhs),
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(problem.k as u32),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}

impl<EG: Numeric> ConcreteOutputFactory for TensorOutput<EG> {
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        _selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
    ) -> Self::RuntimeArg<'a, R> {
        type Layout = Chain<NhwcLayout, OutLayout>;

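        // Map gemm output coordinates back onto the NHWC output tensor; the batch layout is a
        // no-op since convolution has no gemm batching.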
        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out, EnumSet::empty());
        let layout = OutLayoutLaunch::from_args(client, problem, config.rhs_global_memory_config());
        let layout = ChainLaunch::new(global, layout);
        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
        TensorOutputLaunch::new(view, batch)
    }
}

impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
    for TensorMapInputs<Lhs, Rhs, EO>
{
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        selection: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        config: impl ConvGemmConfig,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        let tiling_scheme = selection.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

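        // TMA tile shape for the weights: `stage_k` along the leading (out-channel) dim, 1 along
        // each kernel spatial dim, and `stage_n` along the trailing (in-channel) dim.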
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k);
        stage_size_rhs.push(stage_n);

        // f32 is remapped to tf32 for the tensor map to ensure CUDA loads the values correctly.
        // It shouldn't make a difference in practice, but it's safer to be explicit.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

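        // The pixel box corners bound the dilated kernel's reach in padded space; each pixel
        // carries `tile_size_k` channels and each im2col column holds `stage_m` pixels.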
        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.data().as_tensor_arg(line_sizes.rhs),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        let shape_out = problem
            .out_shape
            .iter()
            .map(|it| FastDivmodArgs::<u32>::new(client, *it as u32))
            .collect();

        // Im2col needs extra bounds checking: when `k` is out of bounds it wraps around the
        // kernel, so a load can be in bounds for the tensor while outside the kernel window.
        // Other TMA layouts always fall outside the shape when any matrix dim is out of bounds.
        let stages_lhs = config.stage_config().lhs_smem_config().num_stages;
        let stages_size_k = selection.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let lhs_layout = TmaIm2colLayoutLaunch::new(
            shape_out,
            FastDivmodArgs::<u32>::new(client, padded_channels),
            ConvolutionParams::from_problem(problem),
            !shape_k.is_multiple_of(stages_size_k),
        );
        let rhs_layout = WeightLayoutLaunch::from_args(
            client,
            problem,
            GlobalMemoryConfig {
                line_size: line_sizes.rhs,
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
                view_direction: ViewDirection::default(),
                dtype: dtypes.rhs_global,
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(shape_k),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            config.operation(),
        );

        (inputs, runtime_args)
    }
}

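/// Lower corner of the TMA im2col pixel box: the padding offset minus the dilated kernel
/// extent, per spatial dim.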
#[allow(clippy::needless_range_loop)]
fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
    let mut out = vec![0; problem.padding.len()];
    for i in 0..problem.padding.len() {
        out[i] =
            problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
    }
    out
}

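/// Upper corner of the TMA im2col pixel box: the lower corner shifted by the difference
/// between input and output shape, per spatial dim.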
#[allow(clippy::needless_range_loop)]
fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
    let mut out = vec![0; problem.padding.len()];
    for i in 0..problem.padding.len() {
        out[i] = problem.padding[i]
            - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
            + problem.in_shape[i] as i32
            - problem.out_shape[i] as i32;
    }
    out
}