//! cubek_convolution/kernels/backward_data/args.rs
//!
//! Launch-argument construction for the backward-data convolution kernels:
//! adapts a [`ConvolutionProblem`] to the matmul argument families (plain
//! tensor args and TMA tensor-map args) and builds their runtime launch args.

1use cubecl::{
2    Runtime,
3    client::ComputeClient,
4    prelude::*,
5    std::{
6        CubeOptionArgs, FastDivmodArgs,
7        tensor::{
8            launch::ViewArg,
9            layout::{
10                VirtualLayoutLaunch,
11                chain::{Chain, ChainLaunch},
12            },
13        },
14    },
15};
16use cubek_matmul::{
17    components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
18    definition::{Blueprint, MatmulElems, MatmulLineSizes, MatrixLayout, TilingBlueprint},
19    launch::{
20        MatmulArgs, MatmulInputHandleRef, TensorArgs, TensorInputs, TensorInputsLaunch,
21        TensorMapArgs, TensorMapInputs, TensorMapInputsLaunch, TensorOutput, TensorOutputLaunch,
22    },
23    routines::Routine,
24};
25use enumset::EnumSet;
26
27use crate::components::{
28    ConvolutionParams, ConvolutionProblem,
29    global::{
30        args::{RuntimeArgs, RuntimeArgsLaunch},
31        layout::{
32            Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
33            OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
34            WeightLayoutLaunch,
35        },
36    },
37};
38
39pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
40    MatmulArgs<
41        Input<NumericExpand<0>, NumericExpand<1>, NumericExpand<2>>: ConcreteInputsFactory<A>,
42        Output<NumericExpand<2>>: ConcreteOutputFactory<A>,
43        Config = RuntimeArgs,
44    >
45{
46    fn adjust_problem<R: Runtime>(
47        client: &ComputeClient<R>,
48        problem: ConvolutionProblem,
49        selection: &A::Blueprint,
50        dtypes: &MatmulElems,
51    ) -> ConvolutionProblem;
52}
53
54impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
55    fn adjust_problem<R: Runtime>(
56        client: &ComputeClient<R>,
57        mut problem: ConvolutionProblem,
58        _blueprint: &A::Blueprint,
59        dtypes: &MatmulElems,
60    ) -> ConvolutionProblem {
61        let load_width = client.properties().hardware.load_width;
62        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
63        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
64        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
65
66        problem.k = shape_k;
67        problem.padded_channels = padded_channels;
68
69        problem
70    }
71}
72
73impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
74    for TensorMapArgs<RuntimeArgs>
75{
76    fn adjust_problem<R: Runtime>(
77        _client: &ComputeClient<R>,
78        mut problem: ConvolutionProblem,
79        blueprint: &TilingBlueprint,
80        _dtypes: &MatmulElems,
81    ) -> ConvolutionProblem {
82        let channel_align = blueprint.tiling_scheme.tile_size.k() as usize;
83        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
84        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
85
86        problem.k = shape_k;
87        problem.padded_channels = padded_channels;
88
89        problem
90    }
91}
92
/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch arguments for the two matmul inputs of the
    /// backward-data convolution (`out_grad` is used as lhs, `weights` as rhs
    /// by the impls below), plus the scalar [`RuntimeArgsLaunch`] shared by
    /// the whole kernel.
    #[allow(clippy::too_many_arguments)]
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>);
}
107
/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch argument for the matmul output tensor `out`,
    /// wrapping it in whatever layout chain this factory's output type needs.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out: &'a TensorHandleRef<'a, R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
    ) -> Self::RuntimeArg<'a, R>;
}
119
120impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric, A: Routine<RuntimeArgs>> ConcreteInputsFactory<A>
121    for TensorInputs<Lhs, Rhs, EO>
122{
123    fn create<'a, R: Runtime>(
124        client: &ComputeClient<R>,
125        out_grad: &'a MatmulInputHandleRef<'a, R>,
126        weights: &'a MatmulInputHandleRef<'a, R>,
127        blueprint: &A::Blueprint,
128        problem: &ConvolutionProblem,
129        line_sizes: &MatmulLineSizes,
130        _dtypes: &MatmulElems,
131    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
132        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
133        type RhsLayout = Chain<NhwcLayout, WeightLayout>;
134
135        let padded_channels = problem.padded_channels as u32;
136        let params = ConvolutionParams::from_problem(problem);
137
138        let layout_nhwc =
139            |handle, line_size, checks| NhwcLayoutLaunch::from_handle(handle, line_size, checks);
140
141        let layout_lhs = Im2colLayoutLaunch::from_args(
142            client,
143            problem,
144            params,
145            blueprint.lhs_global_layout_config(),
146        );
147        let layout_rhs =
148            WeightLayoutLaunch::from_args(client, problem, blueprint.rhs_global_layout_config());
149
150        let layout_lhs = {
151            let mut checks = EnumSet::empty();
152            if problem.should_check_spatial_bounds() {
153                checks.insert(NhwcCheck::Spatial);
154            }
155            if problem.should_check_channel() {
156                checks.insert(NhwcCheck::Channel);
157            }
158            let global = layout_nhwc(out_grad.data(), line_sizes.lhs, checks);
159            ChainLaunch::new(global, layout_lhs)
160        };
161        let layout_rhs = {
162            let mut checks = EnumSet::empty();
163            if problem.should_check_channel() {
164                checks.insert(NhwcCheck::Batch);
165            }
166            let global = layout_nhwc(weights.data(), line_sizes.rhs, checks);
167            ChainLaunch::new(global, layout_rhs)
168        };
169
170        let inputs = TensorInputsLaunch::new(
171            ViewArg::new::<LhsLayout>(out_grad.data().as_array_arg(line_sizes.lhs), layout_lhs),
172            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
173            ViewArg::new::<RhsLayout>(weights.data().as_array_arg(line_sizes.rhs), layout_rhs),
174            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
175            CubeOptionArgs::None,
176            CubeOptionArgs::None,
177        );
178
179        let runtime_args = RuntimeArgsLaunch::new(
180            ScalarArg::new(problem.k as u32),
181            ScalarArg::new(problem.out_channels as u32),
182            FastDivmodArgs::<u32>::new(client, padded_channels),
183            problem.operation,
184        );
185
186        (inputs, runtime_args)
187    }
188}
189
190impl<EG: Numeric, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
191    fn create<'a, R: Runtime>(
192        client: &ComputeClient<R>,
193        out: &'a TensorHandleRef<'a, R>,
194        blueprint: &A::Blueprint,
195        problem: &ConvolutionProblem,
196        line_sizes: &MatmulLineSizes,
197    ) -> Self::RuntimeArg<'a, R> {
198        type Layout = Chain<NhwcLayout, OutLayout>;
199
200        let global = NhwcLayoutLaunch::from_handle(out, line_sizes.out, EnumSet::empty());
201        let layout =
202            OutLayoutLaunch::from_args(client, problem, blueprint.out_global_layout_config());
203        let layout = ChainLaunch::new(global, layout);
204        let view = ViewArg::new::<Layout>(out.as_array_arg(line_sizes.out), layout);
205        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
206        TensorOutputLaunch::new(view, batch)
207    }
208}
209
impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric, A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>>
    ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds TMA tensor-map inputs: the output gradient (lhs) is described
    /// through an im2col tensor-map descriptor, the weights (rhs) through a
    /// plain tiled descriptor.
    fn create<'a, R: Runtime>(
        client: &ComputeClient<R>,
        out_grad: &'a MatmulInputHandleRef<'a, R>,
        weights: &'a MatmulInputHandleRef<'a, R>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        line_sizes: &MatmulLineSizes,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<'a, R>, RuntimeArgsLaunch<'a, R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

        // Rhs box shape ends up as [stage_k, 1, .., 1, stage_n]: k first,
        // n last, with a 1 for each spatial dimension in between.
        let mut stage_size_rhs = vec![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k);
        stage_size_rhs.push(stage_n);

        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
        // It shouldn't matter, but it's better to be safe.
        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked() {
            tf32::as_type_native_unchecked()
        } else {
            dtypes.lhs_stage
        };

        // Element strides for the im2col descriptor: the convolution stride
        // for each spatial dim, with the first and last entries left at 1
        // (presumably the batch and channel axes of the NHWC tensor — confirm).
        let mut elem_stride = vec![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        let lhs = TensorMapArg::new(
            Im2colArgs {
                // Pixel box corners bound the kernel's spatial reach around
                // each output position; see `calculate_lower_corner` /
                // `calculate_upper_corner` at the bottom of this file.
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.data().as_tensor_arg(line_sizes.lhs),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.data().as_tensor_arg(line_sizes.rhs),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Im2col needs extra checking because if `k` is OOB it wraps around the kernel and can load
        // in-bounds but not in-kernel elements. Other TMA layouts are always outside the shape if
        // any matrix dim is out of bounds.
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        // Only enable kernel checking when `k` is not a whole number of
        // per-load stages (i.e. the last load can run past `shape_k`).
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(client, problem, check_kernel);
        // NOTE(review): both software bounds checks are disabled for rhs,
        // presumably because the TMA descriptor handles out-of-bounds
        // accesses itself — confirm against WeightLayout's TMA path.
        let rhs_layout = WeightLayoutLaunch::from_args(
            client,
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            CubeOptionArgs::None,
            CubeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            ScalarArg::new(shape_k),
            ScalarArg::new(problem.out_channels as u32),
            FastDivmodArgs::<u32>::new(client, padded_channels),
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
306
307#[allow(clippy::needless_range_loop)]
308fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
309    let mut out = vec![0; problem.padding.len()];
310    for i in 0..problem.padding.len() {
311        out[i] =
312            problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
313    }
314    out
315}
316
317#[allow(clippy::needless_range_loop)]
318fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
319    let mut out = vec![0; problem.padding.len()];
320    for i in 0..problem.padding.len() {
321        out[i] = problem.padding[i]
322            - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
323            + problem.in_shape[i] as i32
324            - problem.out_shape[i] as i32;
325    }
326    out
327}