// cubecl_convolution/components/global/entry_point.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_core::{Runtime, client::ComputeClient};
4use cubecl_matmul::components::MatmulElems;
5use cubecl_matmul::components::{
6    InputRuntimeArg, OutputRuntimeArg,
7    batch::SliceIndex,
8    global::{GlobalConfig as _, args::MatmulArgs},
9};
10use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};
11
12use crate::{
13    components::{
14        ConvolutionProblem,
15        global::{GlobalConvolution, GlobalConvolutionFamily},
16    },
17    kernels::layered::selector::RuntimeArgs,
18};
19
20type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
21type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
22
23/// Provides launch entry point to solve a matmul
24pub trait ConvolutionLaunch<Config> {
25    /// Entry point
26    ///
27    /// # Safety
28    ///
29    /// Out-of-bounds can happen
30    #[allow(clippy::too_many_arguments)]
31    unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
32        client: &ComputeClient<<R as Runtime>::Server>,
33        cube_dim: CubeDim,
34        cube_count: CubeCount,
35        input: InputRuntimeArg<'a, MA, R>,
36        output: OutputRuntimeArg<'a, MA, R>,
37        problem: &ConvolutionProblem,
38        config: Config,
39        dtypes: &MatmulElems,
40    );
41}
42
43#[cube(launch_unchecked)]
44pub(crate) fn implicit_conv<
45    Args: MatmulArgs,
46    LhsG: Numeric,
47    RhsG: Numeric,
48    AccG: Numeric,
49    LhsS: Numeric,
50    RhsS: Numeric,
51    AccS: Numeric,
52    GMM: GlobalConvolutionFamily,
53>(
54    inputs: &Input<Args, LhsG, RhsG, AccG>,
55    output: &mut Output<Args, AccG>,
56    runtime_args: RuntimeArgs,
57    #[comptime] config: GMM::Config,
58    #[define(LhsG)] _lhs_global: StorageType,
59    #[define(RhsG)] _rhs_global: StorageType,
60    #[define(AccG)] _acc_global: StorageType,
61    #[define(LhsS)] _lhs_stage: StorageType,
62    #[define(RhsS)] _rhs_stage: StorageType,
63    #[define(AccS)] _acc_stage: StorageType,
64) {
65    let mut state = Args::init_state::<LhsG, RhsG, AccG, GMM::Config>(inputs, output, config);
66
67    let lhs = Args::view_lhs(&state);
68    let rhs = Args::view_rhs(&state);
69    let bias = Args::view_acc(&state);
70    let out = Args::view_out(&mut state);
71
72    let stage_m = config.tiling_scheme().elements_in_stage_m().runtime();
73    let stage_n = config.tiling_scheme().elements_in_stage_n().runtime();
74
75    let m_offset = CUBE_POS_X * stage_m;
76    let n_offset = CUBE_POS_Y * stage_n;
77
78    let k_range = (0, runtime_args.shape_k);
79    let k_size = runtime_args.shape_k;
80
81    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
82    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
83    let bias = match bias {
84        CubeOption::Some(bias) => {
85            let view = bias.view(SliceIndex::new(0, bias.shape()));
86            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
87        }
88        CubeOption::None => CubeOption::new_None(),
89    };
90    let out = out.view_mut(SliceIndex::new(0, out.shape()));
91
92    GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute(
93        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader(
94            lhs,
95            (m_offset, k_range.0),
96            (stage_m, k_size),
97            &runtime_args,
98            config,
99        ),
100        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader(
101            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
102            config,
103        ),
104        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader(
105            bias, config,
106        ),
107        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer(
108            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
109            config,
110        ),
111        &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config),
112        k_range,
113        config,
114    );
115}
116
117pub(crate) fn shape_divmod<'a, R: Runtime>(
118    client: &ComputeClient<R::Server>,
119    shape: &[usize],
120) -> SequenceArg<'a, R, FastDivmod> {
121    shape
122        .iter()
123        .map(|s| FastDivmodArgs::new(client, *s as u32))
124        .collect()
125}