Skip to main content

cubek_convolution/kernels/backward_data/
args.rs

1use cubecl::{
2    Runtime,
3    client::ComputeClient,
4    prelude::*,
5    std::tensor::{
6        launch::ViewArg,
7        layout::{
8            VirtualLayoutLaunch,
9            chain::{Chain, ChainLaunch},
10        },
11    },
12    zspace::{shape, strides},
13};
14use cubek_matmul::{
15    components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16    definition::{Blueprint, MatmulElems, TilingBlueprint},
17    launch::*,
18    routines::Routine,
19};
20use cubek_std::{InputBinding, MatrixLayout};
21use enumset::EnumSet;
22
23use crate::components::{
24    ConvolutionParams, ConvolutionProblem,
25    global::{
26        args::{RuntimeArgs, RuntimeArgsLaunch},
27        layout::{
28            Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
29            OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
30            WeightLayoutLaunch,
31        },
32    },
33};
34
/// Selects the concrete launch-argument flavour for the backward-data convolution kernel.
///
/// The implementor's `MatmulArgs` input/output types must be buildable from concrete tensors
/// through [`ConcreteInputsFactory`] / [`ConcreteOutputFactory`], and its runtime config is
/// [`RuntimeArgs`]. In this file it is implemented for `TensorArgs` and `TensorMapArgs`.
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Returns `problem` adjusted for this argument flavour before launch.
    ///
    /// The implementations in this file pad `out_channels` up to an alignment and store the
    /// result in `padded_channels`. They also recompute `k` as
    /// `product(kernel_size) * padded_channels`.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
49
50impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
51    fn adjust_problem<R: Runtime>(
52        client: &ComputeClient<R>,
53        mut problem: ConvolutionProblem,
54        _blueprint: &A::Blueprint,
55        dtypes: &MatmulElems,
56    ) -> ConvolutionProblem {
57        let load_width = client.properties().hardware.load_width;
58        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
59        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
60        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
61
62        problem.k = shape_k;
63        problem.padded_channels = padded_channels;
64
65        problem
66    }
67}
68
69impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
70    for TensorMapArgs<RuntimeArgs>
71{
72    fn adjust_problem<R: Runtime>(
73        _client: &ComputeClient<R>,
74        mut problem: ConvolutionProblem,
75        blueprint: &TilingBlueprint,
76        _dtypes: &MatmulElems,
77    ) -> ConvolutionProblem {
78        let channel_align = blueprint.tiling_scheme.tile_size.k() as usize;
79        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
80        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
81
82        problem.k = shape_k;
83        problem.padded_channels = padded_channels;
84
85        problem
86    }
87}
88
/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
///
/// In the implementations in this file, the output gradient becomes the lhs operand and the
/// weights become the rhs operand.
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch argument for both matmul inputs, together with the
    /// [`RuntimeArgsLaunch`] that carries the problem's `k`, channel count and channel padding.
    // NOTE(review): this method takes five parameters, below clippy's default
    // `too_many_arguments` threshold, so the `allow` below may be stale.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
101
/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch argument for the output tensor `out`. The layout is derived from the
    /// routine's `blueprint` and the convolution `problem`.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<R>;
}
111
112impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
113    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
114{
115    fn create<R: Runtime>(
116        out_grad: InputBinding<R>,
117        weights: InputBinding<R>,
118        blueprint: &A::Blueprint,
119        problem: &ConvolutionProblem,
120        _dtypes: &MatmulElems,
121    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
122        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
123        type RhsLayout = Chain<NhwcLayout, WeightLayout>;
124
125        let padded_channels = problem.padded_channels as u32;
126        let params = ConvolutionParams::from_problem(problem);
127
128        let layout_lhs =
129            Im2colLayoutLaunch::from_args(problem, params, blueprint.lhs_global_layout_config());
130        let layout_rhs =
131            WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());
132
133        let layout_lhs = {
134            let mut checks = EnumSet::empty();
135            if problem.should_check_spatial_bounds() {
136                checks.insert(NhwcCheck::Spatial);
137            }
138            if problem.should_check_channel() {
139                checks.insert(NhwcCheck::Channel);
140            }
141            let global = NhwcLayoutLaunch::checked(checks);
142            ChainLaunch::new(global, layout_lhs)
143        };
144        let layout_rhs = {
145            let mut checks = EnumSet::empty();
146            if problem.should_check_channel() {
147                checks.insert(NhwcCheck::Batch);
148            }
149            let global = NhwcLayoutLaunch::checked(checks);
150            ChainLaunch::new(global, layout_rhs)
151        };
152
153        let inputs = TensorInputsLaunch::new(
154            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
155            ViewArg::new_tensor::<LhsLayout>(out_grad.into_data().into_tensor_arg(), layout_lhs),
156            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
157            ViewArg::new_tensor::<RhsLayout>(weights.into_data().into_tensor_arg(), layout_rhs),
158            ComptimeOptionArgs::None,
159            ComptimeOptionArgs::None,
160        );
161
162        let runtime_args = RuntimeArgsLaunch::new(
163            problem.k as u32,
164            problem.out_channels as u32,
165            padded_channels,
166            problem.operation,
167        );
168
169        (inputs, runtime_args)
170    }
171}
172
173impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
174    fn create<R: Runtime>(
175        out: TensorBinding<R>,
176        blueprint: &A::Blueprint,
177        problem: &ConvolutionProblem,
178    ) -> Self::RuntimeArg<R> {
179        type Layout = Chain<NhwcLayout, OutLayout>;
180
181        let global = NhwcLayoutLaunch::unchecked();
182        let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
183        let layout = ChainLaunch::new(global, layout);
184        let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
185        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
186        TensorOutputLaunch::new(view, batch)
187    }
188}
189
190impl<
191    Lhs: CubePrimitive,
192    Rhs: CubePrimitive,
193    EO: CubePrimitive,
194    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
195> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
196{
197    fn create<R: Runtime>(
198        out_grad: InputBinding<R>,
199        weights: InputBinding<R>,
200        blueprint: &TilingBlueprint,
201        problem: &ConvolutionProblem,
202        dtypes: &MatmulElems,
203    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
204        type LhsLayout = TmaIm2colLayout;
205        type RhsLayout = WeightLayout;
206
207        let tiling_scheme = blueprint.tiling_scheme;
208        let stage_m = tiling_scheme.elements_per_stage_along_m();
209        let stage_n = tiling_scheme.elements_per_stage_along_n();
210        let stage_k = tiling_scheme.elements_per_stage_along_k();
211        let tile_size_k = tiling_scheme.tile_size.k;
212
213        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
214        stage_size_rhs.insert(0, stage_k as usize);
215        stage_size_rhs.push(stage_n as usize);
216
217        // f32 gets remapped to tf32 for the tensor map just to ensure CUDA loads them correctly.
218        // It shouldn't matter, but it's better to be safe.
219        let lhs_elem = if dtypes.lhs_stage == f32::as_type_native_unchecked().storage_type() {
220            tf32::as_type_native_unchecked().storage_type()
221        } else {
222            dtypes.lhs_stage
223        };
224
225        let mut elem_stride = strides![1; 2 + problem.stride.len()];
226
227        for (i, stride) in problem.stride.iter().enumerate() {
228            elem_stride[i + 1] = *stride as usize;
229        }
230
231        let lhs = TensorMapArg::new(
232            Im2colArgs {
233                pixel_box_lower_corner: calculate_lower_corner(problem),
234                pixel_box_upper_corner: calculate_upper_corner(problem),
235                channels_per_pixel: tile_size_k,
236                pixels_per_column: stage_m,
237            },
238            out_grad.into_data().into_tensor_arg(),
239            lhs_elem,
240        )
241        .with_elem_stride(elem_stride);
242
243        let rhs = TensorMapArg::new(
244            TiledArgs {
245                tile_size: stage_size_rhs,
246            },
247            weights.into_data().into_tensor_arg(),
248            dtypes.rhs_global,
249        );
250
251        let padded_channels = problem.padded_channels as u32;
252        let shape_k = problem.k as u32;
253
254        // Im2col needs extra checking because if `k` is OOB it wraps around the kernel and can load
255        // in-bounds but not in-kernel elements. Other TMA layouts are always outside the shape if
256        // any matrix dim is out of bounds.
257        let stages_lhs = A::num_stages().lhs;
258        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
259        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
260        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
261        let rhs_layout = WeightLayoutLaunch::from_args(
262            problem,
263            GlobalLayoutConfig {
264                check_row_bounds: false,
265                check_col_bounds: false,
266                matrix_layout: MatrixLayout::default(),
267            },
268        );
269
270        let inputs = TensorMapInputsLaunch::new(
271            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
272            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
273            ComptimeOptionArgs::None,
274            ComptimeOptionArgs::None,
275        );
276
277        let runtime_args = RuntimeArgsLaunch::new(
278            shape_k,
279            problem.out_channels as u32,
280            padded_channels,
281            problem.operation,
282        );
283
284        (inputs, runtime_args)
285    }
286}
287
288#[allow(clippy::needless_range_loop)]
289fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
290    let mut out = vec![0; problem.padding.len()];
291    for i in 0..problem.padding.len() {
292        out[i] =
293            problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
294    }
295    out
296}
297
298#[allow(clippy::needless_range_loop)]
299fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
300    let mut out = vec![0; problem.padding.len()];
301    for i in 0..problem.padding.len() {
302        out[i] = problem.padding[i]
303            - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
304            + problem.in_shape[i] as i32
305            - problem.out_shape[i] as i32;
306    }
307    out
308}