Skip to main content

cubek_convolution/kernels/backward_data/
args.rs

1use cubecl::{
2    Runtime,
3    client::ComputeClient,
4    prelude::*,
5    std::tensor::{
6        launch::ViewArg,
7        layout::{
8            VirtualLayoutLaunch,
9            chain::{Chain, ChainLaunch},
10        },
11    },
12    zspace::{shape, strides},
13};
14use cubek_matmul::{
15    components::global::memory::{GlobalLayoutConfig, NoopLayout, NoopLayoutLaunch},
16    definition::{Blueprint, MatmulElems, TilingBlueprint},
17    launch::*,
18    routines::Routine,
19};
20use cubek_std::launch::tma::remap_storage_for_tma;
21use cubek_std::{InputBinding, MatrixLayout};
22use enumset::EnumSet;
23
24use crate::components::{
25    ConvolutionParams, ConvolutionProblem,
26    global::{
27        args::{RuntimeArgs, RuntimeArgsLaunch},
28        layout::{
29            Im2colLayout, Im2colLayoutLaunch, NhwcCheck, NhwcLayout, NhwcLayoutLaunch, OutLayout,
30            OutLayoutLaunch, TmaIm2colLayout, TmaIm2colLayoutLaunch, WeightLayout,
31            WeightLayoutLaunch,
32        },
33    },
34};
35
/// Matmul argument family usable for the backward-data convolution kernel with
/// concrete (non-fused) inputs and output.
///
/// The supertrait bounds require that this argument type's input bundle can be
/// built by a [`ConcreteInputsFactory`] and its output by a
/// [`ConcreteOutputFactory`], with [`RuntimeArgs`] as the runtime config.
pub trait ConcreteArgs<A: Routine<RuntimeArgs>>:
    MatmulArgs<
        Input<Vector<Lhs, LhsSize>, Vector<Rhs, RhsSize>, Vector<Acc, AccSize>>: ConcreteInputsFactory<A>,
        Output<Vector<Acc, AccSize>>: ConcreteOutputFactory<A>,
        Config = RuntimeArgs,
    >
{
    /// Adjusts `problem` to this argument type's alignment requirements
    /// (both impls in this file pad the channel count and recompute `k`)
    /// before the kernel is launched.
    ///
    /// NOTE(review): impls name this parameter `_blueprint`/`blueprint`;
    /// `selection` here is inconsistent naming — consider unifying.
    fn adjust_problem<R: Runtime>(
        client: &ComputeClient<R>,
        problem: ConvolutionProblem,
        selection: &A::Blueprint,
        dtypes: &MatmulElems,
    ) -> ConvolutionProblem;
}
50
51impl<A: Routine<RuntimeArgs>> ConcreteArgs<A> for TensorArgs<RuntimeArgs> {
52    fn adjust_problem<R: Runtime>(
53        client: &ComputeClient<R>,
54        mut problem: ConvolutionProblem,
55        _blueprint: &A::Blueprint,
56        dtypes: &MatmulElems,
57    ) -> ConvolutionProblem {
58        let load_width = client.properties().hardware.load_width;
59        let channel_align = load_width as usize / dtypes.lhs_global.size_bits();
60        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
61        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
62
63        problem.k = shape_k;
64        problem.padded_channels = padded_channels;
65
66        problem
67    }
68}
69
70impl<A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>> ConcreteArgs<A>
71    for TensorMapArgs<RuntimeArgs>
72{
73    fn adjust_problem<R: Runtime>(
74        _client: &ComputeClient<R>,
75        mut problem: ConvolutionProblem,
76        blueprint: &TilingBlueprint,
77        _dtypes: &MatmulElems,
78    ) -> ConvolutionProblem {
79        let channel_align = blueprint.tiling_scheme.tile_size.k() as usize;
80        let padded_channels = problem.out_channels.next_multiple_of(channel_align);
81        let shape_k = problem.kernel_size.iter().product::<u32>() as usize * padded_channels;
82
83        problem.k = shape_k;
84        problem.padded_channels = padded_channels;
85
86        problem
87    }
88}
89
/// Create the input runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteInputsFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch-time input bundle plus the kernel's runtime args.
    ///
    /// * `out_grad` - output-gradient tensor binding (bound as the matmul lhs).
    /// * `weights` - weight tensor binding (bound as the matmul rhs).
    /// * `blueprint` - routine blueprint supplying layout/tiling configuration.
    /// * `problem` - convolution problem description.
    /// * `dtypes` - element types of the matmul operands.
    #[allow(clippy::too_many_arguments)]
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>);
}
102
/// Create the output runtime arguments for a matmul kernel that works on concrete inputs and
/// output (not fused).
pub trait ConcreteOutputFactory<A: Routine<RuntimeArgs>>: LaunchArg {
    /// Builds the launch-time output argument for the matmul result tensor,
    /// using the blueprint's output layout configuration.
    fn create<R: Runtime>(
        out: TensorBinding<R>,
        blueprint: &A::Blueprint,
        problem: &ConvolutionProblem,
    ) -> Self::RuntimeArg<R>;
}
112
113impl<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive, A: Routine<RuntimeArgs>>
114    ConcreteInputsFactory<A> for TensorInputs<Lhs, Rhs, EO>
115{
116    fn create<R: Runtime>(
117        out_grad: InputBinding<R>,
118        weights: InputBinding<R>,
119        blueprint: &A::Blueprint,
120        problem: &ConvolutionProblem,
121        _dtypes: &MatmulElems,
122    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
123        type LhsLayout = Chain<NhwcLayout, Im2colLayout>;
124        type RhsLayout = Chain<NhwcLayout, WeightLayout>;
125
126        let padded_channels = problem.padded_channels as u32;
127        let params = ConvolutionParams::from_problem(problem);
128
129        let layout_lhs =
130            Im2colLayoutLaunch::from_args(problem, params, blueprint.lhs_global_layout_config());
131        let layout_rhs =
132            WeightLayoutLaunch::from_args(problem, blueprint.rhs_global_layout_config());
133
134        let layout_lhs = {
135            let mut checks = EnumSet::empty();
136            if problem.should_check_spatial_bounds() {
137                checks.insert(NhwcCheck::Spatial);
138            }
139            if problem.should_check_channel() {
140                checks.insert(NhwcCheck::Channel);
141            }
142            let global = NhwcLayoutLaunch::checked(checks);
143            ChainLaunch::new(global, layout_lhs)
144        };
145        let layout_rhs = {
146            let mut checks = EnumSet::empty();
147            if problem.should_check_channel() {
148                checks.insert(NhwcCheck::Batch);
149            }
150            let global = NhwcLayoutLaunch::checked(checks);
151            ChainLaunch::new(global, layout_rhs)
152        };
153
154        let inputs = TensorInputsLaunch::new(
155            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
156            ViewArg::new_tensor::<LhsLayout>(out_grad.into_data().into_tensor_arg(), layout_lhs),
157            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
158            ViewArg::new_tensor::<RhsLayout>(weights.into_data().into_tensor_arg(), layout_rhs),
159            ComptimeOptionArgs::None,
160            ComptimeOptionArgs::None,
161        );
162
163        let runtime_args = RuntimeArgsLaunch::new(
164            problem.k as u32,
165            problem.out_channels as u32,
166            padded_channels,
167            problem.operation,
168        );
169
170        (inputs, runtime_args)
171    }
172}
173
174impl<EG: CubePrimitive, A: Routine<RuntimeArgs>> ConcreteOutputFactory<A> for TensorOutput<EG> {
175    fn create<R: Runtime>(
176        out: TensorBinding<R>,
177        blueprint: &A::Blueprint,
178        problem: &ConvolutionProblem,
179    ) -> Self::RuntimeArg<R> {
180        type Layout = Chain<NhwcLayout, OutLayout>;
181
182        let global = NhwcLayoutLaunch::unchecked();
183        let layout = OutLayoutLaunch::from_args(problem, blueprint.out_global_layout_config());
184        let layout = ChainLaunch::new(global, layout);
185        let view = ViewArg::new_tensor::<Layout>(out.into_tensor_arg(), layout);
186        let batch = VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new());
187        TensorOutputLaunch::new(view, batch)
188    }
189}
190
impl<
    Lhs: CubePrimitive,
    Rhs: CubePrimitive,
    EO: CubePrimitive,
    A: Routine<RuntimeArgs, Blueprint = TilingBlueprint>,
> ConcreteInputsFactory<A> for TensorMapInputs<Lhs, Rhs, EO>
{
    /// Builds TMA (tensor-map) matmul inputs: an im2col descriptor over the
    /// out-grad tensor (lhs) and a tiled descriptor over the weight tensor
    /// (rhs), plus the kernel's runtime args.
    fn create<R: Runtime>(
        out_grad: InputBinding<R>,
        weights: InputBinding<R>,
        blueprint: &TilingBlueprint,
        problem: &ConvolutionProblem,
        dtypes: &MatmulElems,
    ) -> (Self::RuntimeArg<R>, RuntimeArgsLaunch<R>) {
        type LhsLayout = TmaIm2colLayout;
        type RhsLayout = WeightLayout;

        let tiling_scheme = blueprint.tiling_scheme;
        let stage_m = tiling_scheme.elements_per_stage_along_m();
        let stage_n = tiling_scheme.elements_per_stage_along_n();
        let stage_k = tiling_scheme.elements_per_stage_along_k();
        let tile_size_k = tiling_scheme.tile_size.k;

        // Rhs stage box: [stage_k, 1, ..., 1, stage_n] — k leads, n trails,
        // and every spatial dimension contributes a singleton extent.
        let mut stage_size_rhs = shape![1; problem.dimensionality.num_dims()];
        stage_size_rhs.insert(0, stage_k as usize);
        stage_size_rhs.push(stage_n as usize);

        // TMA storage may need a different element type than the stage dtype.
        let lhs_elem = remap_storage_for_tma(dtypes.lhs_stage);

        // Per-dimension element strides: leading (batch) and trailing
        // (channel) strides stay 1; spatial dims take the convolution stride.
        let mut elem_stride = strides![1; 2 + problem.stride.len()];

        for (i, stride) in problem.stride.iter().enumerate() {
            elem_stride[i + 1] = *stride as usize;
        }

        // Lhs descriptor: im2col pixel box derived from padding, dilation and
        // kernel size (see `calculate_lower_corner`/`calculate_upper_corner`).
        let lhs = TensorMapArg::new(
            Im2colArgs {
                pixel_box_lower_corner: calculate_lower_corner(problem),
                pixel_box_upper_corner: calculate_upper_corner(problem),
                channels_per_pixel: tile_size_k,
                pixels_per_column: stage_m,
            },
            out_grad.into_data().into_tensor_arg(),
            lhs_elem,
        )
        .with_elem_stride(elem_stride);

        // Rhs descriptor: plain tiled access over the weight tensor.
        let rhs = TensorMapArg::new(
            TiledArgs {
                tile_size: stage_size_rhs,
            },
            weights.into_data().into_tensor_arg(),
            dtypes.rhs_global,
        );

        let padded_channels = problem.padded_channels as u32;
        let shape_k = problem.k as u32;

        // Im2col needs extra checking because if `k` is OOB it wraps around the kernel and can load
        // in-bounds but not in-kernel elements. Other TMA layouts are always outside the shape if
        // any matrix dim is out of bounds.
        let stages_lhs = A::num_stages().lhs;
        let stages_size_k = blueprint.tiling_scheme.elements_per_stage_along_k() * stages_lhs;
        let check_kernel = !shape_k.is_multiple_of(stages_size_k);
        let lhs_layout = TmaIm2colLayoutLaunch::from_args(problem, check_kernel);
        // Rhs bounds checks are disabled: TMA handles out-of-shape accesses
        // itself (per the note above) — NOTE(review): confirm this intent.
        let rhs_layout = WeightLayoutLaunch::from_args(
            problem,
            GlobalLayoutConfig {
                check_row_bounds: false,
                check_col_bounds: false,
                matrix_layout: MatrixLayout::default(),
            },
        );

        let inputs = TensorMapInputsLaunch::new(
            ViewArg::new_tensor_map_im2col::<LhsLayout, _, _>(lhs, lhs_layout),
            ViewArg::new_tensor_map_tiled::<RhsLayout>(rhs, rhs_layout),
            ComptimeOptionArgs::None,
            ComptimeOptionArgs::None,
        );

        let runtime_args = RuntimeArgsLaunch::new(
            shape_k,
            problem.out_channels as u32,
            padded_channels,
            problem.operation,
        );

        (inputs, runtime_args)
    }
}
282
283#[allow(clippy::needless_range_loop)]
284fn calculate_lower_corner(problem: &ConvolutionProblem) -> Vec<i32> {
285    let mut out = vec![0; problem.padding.len()];
286    for i in 0..problem.padding.len() {
287        out[i] =
288            problem.padding[i] - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32;
289    }
290    out
291}
292
293#[allow(clippy::needless_range_loop)]
294fn calculate_upper_corner(problem: &ConvolutionProblem) -> Vec<i32> {
295    let mut out = vec![0; problem.padding.len()];
296    for i in 0..problem.padding.len() {
297        out[i] = problem.padding[i]
298            - (problem.kernel_size[i] as i32 - 1) * problem.dilation[i] as i32
299            + problem.in_shape[i] as i32
300            - problem.out_shape[i] as i32;
301    }
302    out
303}