// cubecl_convolution/components/global/entry_point.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_core::{Runtime, client::ComputeClient};
4use cubecl_matmul::components::{
5    InputRuntimeArg, MatmulSpec, OutputRuntimeArg,
6    batch::SliceIndex,
7    global::{GlobalConfig as _, args::MatmulArgs},
8};
9use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};
10
11use crate::{
12    components::{
13        ConvolutionProblem,
14        global::{GlobalConvolution, GlobalConvolutionFamily},
15    },
16    kernels::layered::selector::RuntimeArgs,
17};
18
19type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
20type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
21
22/// Provides launch entry point to solve a matmul
23pub trait ConvolutionLaunch<Config> {
24    /// Entry point
25    ///
26    /// # Safety
27    ///
28    /// Out-of-bounds can happen
29    #[allow(clippy::too_many_arguments)]
30    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
31        client: &ComputeClient<<R as Runtime>::Server>,
32        cube_dim: CubeDim,
33        cube_count: CubeCount,
34        input: InputRuntimeArg<'a, MS, R>,
35        output: OutputRuntimeArg<'a, MS, R>,
36        problem: &ConvolutionProblem,
37        config: Config,
38    );
39}
40
41#[cube(launch_unchecked)]
42pub(crate) fn implicit_conv<
43    Args: MatmulArgs,
44    LhsG: Numeric,
45    RhsG: Numeric,
46    AccG: Numeric,
47    LhsS: Numeric,
48    RhsS: Numeric,
49    AccS: Numeric,
50    GMM: GlobalConvolutionFamily,
51>(
52    inputs: &Input<Args, LhsG, RhsG, AccG>,
53    output: &mut Output<Args, AccG>,
54    runtime_args: RuntimeArgs,
55    #[comptime] config: GMM::Config,
56) {
57    let mut state = Args::init_state::<LhsG, RhsG, AccG, GMM::Config>(inputs, output, config);
58
59    let lhs = Args::view_lhs(&state);
60    let rhs = Args::view_rhs(&state);
61    let bias = Args::view_acc(&state);
62    let out = Args::view_out(&mut state);
63
64    let stage_m = config.tiling_scheme().elements_in_stage_m().runtime();
65    let stage_n = config.tiling_scheme().elements_in_stage_n().runtime();
66
67    let m_offset = CUBE_POS_X * stage_m;
68    let n_offset = CUBE_POS_Y * stage_n;
69
70    let k_range = (0, runtime_args.shape_k);
71    let k_size = runtime_args.shape_k;
72
73    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
74    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
75    let bias = match bias {
76        CubeOption::Some(bias) => {
77            let view = bias.view(SliceIndex::new(0, bias.shape()));
78            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
79        }
80        CubeOption::None => CubeOption::new_None(),
81    };
82    let out = out.view_mut(SliceIndex::new(0, out.shape()));
83
84    GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute(
85        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader(
86            lhs,
87            (m_offset, k_range.0),
88            (stage_m, k_size),
89            &runtime_args,
90            config,
91        ),
92        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader(
93            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
94            config,
95        ),
96        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader(
97            bias, config,
98        ),
99        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer(
100            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
101            config,
102        ),
103        &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config),
104        k_range,
105        config,
106    );
107}
108
109pub(crate) fn shape_divmod<'a, R: Runtime>(
110    client: &ComputeClient<R::Server>,
111    shape: &[usize],
112) -> SequenceArg<'a, R, FastDivmod> {
113    shape
114        .iter()
115        .map(|s| FastDivmodArgs::new(client, *s as u32))
116        .collect()
117}