cubecl_convolution/components/global/entry_point.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_core::{Runtime, client::ComputeClient};
4use cubecl_matmul::components::MatmulElems;
5use cubecl_matmul::components::global::GlobalConfig;
6use cubecl_matmul::components::stage::StageConfig as _;
7use cubecl_matmul::components::{
8    InputRuntimeArg, OutputRuntimeArg, batch::SliceIndex, global::args::MatmulArgs,
9};
10use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};
11
12use crate::components::ConvGemmConfig;
13use crate::{
14    components::{
15        ConvolutionProblem,
16        global::{GlobalConvolution, GlobalConvolutionFamily},
17    },
18    kernels::layered::selector::RuntimeArgs,
19};
20
21type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
22type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
23
24/// Provides launch entry point to solve a matmul
25pub trait ConvolutionLaunch<Config> {
26    /// Entry point
27    ///
28    /// # Safety
29    ///
30    /// Out-of-bounds can happen
31    #[allow(clippy::too_many_arguments)]
32    unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
33        client: &ComputeClient<<R as Runtime>::Server>,
34        cube_dim: CubeDim,
35        cube_count: CubeCount,
36        input: InputRuntimeArg<'a, MA, R>,
37        output: OutputRuntimeArg<'a, MA, R>,
38        problem: &ConvolutionProblem,
39        config: Config,
40        dtypes: &MatmulElems,
41    );
42}
43
/// Kernel entry point that runs a convolution through a global matmul-style pipeline.
/// The name suggests an implicit-GEMM convolution; NOTE(review): confirm against
/// `GlobalConvolution`.
///
/// Each cube computes one `stage_m x stage_n` output tile whose top-left corner is
/// `(CUBE_POS_X * stage_m, CUBE_POS_Y * stage_n)`. Each cube reduces over the full
/// `[0, runtime_args.shape_k)` range.
///
/// Generic parameters:
/// - `Args`: argument layout used to build the lhs / rhs / acc (bias) / out views.
/// - `LhsG`, `RhsG`, `AccG`: global element types; `AccG` is also the output element type.
/// - `LhsS`, `RhsS`, `AccS`: stage element types.
/// - `GMM`: family providing the `Convolution` implementation and its comptime `Config`.
///
/// The `_lhs_global` … `_acc_stage` arguments are not read in the body. Their `#[define(..)]`
/// attributes pair each generic element type with a `StorageType` argument.
/// NOTE(review): exact `#[define]` semantics come from the `cube` macro — confirm.
///
/// The rhs, bias and output views are sliced with `slice_unchecked` /
/// `slice_mut_unchecked`, presumably without bounds checks. Out-of-bounds accesses are
/// therefore possible when the launch geometry does not match the problem.
#[cube(launch_unchecked)]
pub(crate) fn implicit_conv<
    Args: MatmulArgs,
    LhsG: Numeric,
    RhsG: Numeric,
    AccG: Numeric,
    LhsS: Numeric,
    RhsS: Numeric,
    AccS: Numeric,
    GMM: GlobalConvolutionFamily,
>(
    inputs: &Input<Args, LhsG, RhsG, AccG>,
    output: &mut Output<Args, AccG>,
    runtime_args: RuntimeArgs,
    #[comptime] config: GMM::Config,
    #[define(LhsG)] _lhs_global: StorageType,
    #[define(RhsG)] _rhs_global: StorageType,
    #[define(AccG)] _acc_global: StorageType,
    #[define(LhsS)] _lhs_stage: StorageType,
    #[define(RhsS)] _rhs_stage: StorageType,
    #[define(AccS)] _acc_stage: StorageType,
) {
    // Build the argument state from the raw inputs/output, parameterized by the
    // global matmul config nested inside the convolution config.
    let mut state = Args::init_state::<
        LhsG,
        RhsG,
        AccG,
        <GMM::Config as ConvGemmConfig>::GlobalMatmulConfig,
    >(inputs, output, config.matmul_config());

    // Views over lhs, rhs, the optional bias (accumulator input) and the output.
    let lhs = Args::view_lhs(&state);
    let rhs = Args::view_rhs(&state);
    let bias = Args::view_acc(&state);
    let out = Args::view_out(&mut state);

    // Tile extents handled by one cube, read from the comptime stage config and
    // lifted to runtime values.
    let stage_m = config
        .matmul_config()
        .stage_config()
        .elements_in_stage_m()
        .runtime();
    let stage_n = config
        .matmul_config()
        .stage_config()
        .elements_in_stage_n()
        .runtime();

    // Top-left corner of this cube's output tile: M follows CUBE_POS_X, N follows CUBE_POS_Y.
    let m_offset = CUBE_POS_X * stage_m;
    let n_offset = CUBE_POS_Y * stage_n;

    // The cube reduces over the entire K dimension. Because the range starts at 0,
    // `k_size` equals `k_range.1 - k_range.0`.
    let k_range = (0, runtime_args.shape_k);
    let k_size = runtime_args.shape_k;

    // Whole-tensor views at slice index 0.
    // NOTE(review): presumably the first/only batch — confirm `SliceIndex` semantics.
    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
    // When present, the bias is narrowed to row 0 and this cube's `stage_n` columns.
    let bias = match bias {
        CubeOption::Some(bias) => {
            let view = bias.view(SliceIndex::new(0, bias.shape()));
            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
        }
        CubeOption::None => CubeOption::new_None(),
    };
    let out = out.view_mut(SliceIndex::new(0, out.shape()));

    // Run the global convolution on this cube's tile with fresh readers, writer and
    // accumulator.
    GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::execute(
        // Unlike rhs, lhs is passed unsliced with an explicit offset/size plus
        // `runtime_args`. Presumably the reader computes the implicit (im2col-style)
        // addressing itself — confirm in the reader implementation.
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_lhs_global_reader(
            lhs,
            (m_offset, k_range.0),
            (stage_m, k_size),
            &runtime_args,
            config,
        ),
        // rhs is pre-sliced to the full K extent and this cube's N columns.
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_rhs_global_reader(
            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
            config,
        ),
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_bias_global_reader(
            bias, config,
        ),
        // The writer only sees this cube's `stage_m x stage_n` output tile.
        GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_global_writer(
            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
            config,
        ),
        &mut GMM::Convolution::<(LhsG, RhsG, AccG, LhsS, RhsS, AccS)>::init_accumulator(config),
        k_range,
        config,
    );
}
130
131pub(crate) fn shape_divmod<'a, R: Runtime>(
132    client: &ComputeClient<R::Server>,
133    shape: &[usize],
134) -> SequenceArg<'a, R, FastDivmod> {
135    shape
136        .iter()
137        .map(|s| FastDivmodArgs::new(client, *s as u32))
138        .collect()
139}