cubecl_convolution/components/global/entry_point.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_core::{Runtime, client::ComputeClient};
4use cubecl_matmul::components::MatmulElems;
5use cubecl_matmul::components::global::GlobalConfig;
6use cubecl_matmul::components::stage::StageConfig as _;
7use cubecl_matmul::components::{
8    InputRuntimeArg, OutputRuntimeArg, batch::SliceIndex, global::args::MatmulArgs,
9};
10use cubecl_std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};
11
12use crate::components::{ConvGemmConfig, global::args::RuntimeArgs};
13use crate::components::{
14    ConvolutionProblem,
15    global::{GlobalConvolution, GlobalConvolutionFamily},
16};
17
18type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
19type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
20
21/// Provides launch entry point to solve a matmul
22pub trait ConvolutionLaunch<Config> {
23    /// Entry point
24    ///
25    /// # Safety
26    ///
27    /// Out-of-bounds can happen
28    #[allow(clippy::too_many_arguments)]
29    unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
30        client: &ComputeClient<R>,
31        cube_dim: CubeDim,
32        cube_count: CubeCount,
33        input: InputRuntimeArg<'a, MA, R>,
34        output: OutputRuntimeArg<'a, MA, R>,
35        problem: &ConvolutionProblem,
36        config: Config,
37        dtypes: &MatmulElems,
38    ) -> Result<(), LaunchError>;
39}
40
41#[cube(launch_unchecked)]
42pub(crate) fn implicit_conv<
43    Args: MatmulArgs,
44    LhsG: Numeric,
45    RhsG: Numeric,
46    AccG: Numeric,
47    LhsS: Numeric,
48    RhsS: Numeric,
49    AccS: Numeric,
50    LhsR: Numeric,
51    RhsR: Numeric,
52    AccR: Numeric,
53    GMM: GlobalConvolutionFamily,
54>(
55    inputs: &Input<Args, LhsG, RhsG, AccG>,
56    output: &mut Output<Args, AccG>,
57    runtime_args: RuntimeArgs,
58    #[comptime] config: GMM::Config,
59    #[define(LhsG, RhsG, AccG)] _global: [StorageType; 3],
60    #[define(LhsS, RhsS, AccS)] _stage: [StorageType; 3],
61    #[define(LhsR, RhsR, AccR)] _register: [StorageType; 3],
62) {
63    let mut state = Args::init_state::<
64        LhsG,
65        RhsG,
66        AccG,
67        <GMM::Config as ConvGemmConfig>::GlobalMatmulConfig,
68    >(inputs, output, config.matmul_config());
69
70    let lhs = Args::view_lhs(&state);
71    let rhs = Args::view_rhs(&state);
72    let bias = Args::view_acc(&state);
73    let out = Args::view_out(&mut state);
74
75    let stage_m = config
76        .matmul_config()
77        .stage_config()
78        .elements_in_stage_m()
79        .runtime();
80    let stage_n = config
81        .matmul_config()
82        .stage_config()
83        .elements_in_stage_n()
84        .runtime();
85
86    let m_offset = CUBE_POS_X * stage_m;
87    let n_offset = CUBE_POS_Y * stage_n;
88
89    let k_range = (0, runtime_args.shape_k);
90    let k_size = runtime_args.shape_k;
91
92    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
93    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
94    let bias = match bias {
95        CubeOption::Some(bias) => {
96            let view = bias.view(SliceIndex::new(0, bias.shape()));
97            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
98        }
99        CubeOption::None => CubeOption::new_None(),
100    };
101    let out = out.view_mut(SliceIndex::new(0, out.shape()));
102
103    GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::execute(
104        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_lhs_global_reader(
105            lhs,
106            (m_offset, k_range.0),
107            (stage_m, k_size),
108            &runtime_args,
109            config,
110        ),
111        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_rhs_global_reader(
112            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
113            &runtime_args,
114            config,
115        ),
116        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_bias_global_reader(
117            bias, config,
118        ),
119        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_global_writer(
120            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
121            config,
122        ),
123        &mut GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_accumulator(
124            config,
125        ),
126        k_range,
127        config,
128    );
129}
130
131pub(crate) fn shape_divmod<'a, R: Runtime>(
132    client: &ComputeClient<R>,
133    shape: &[usize],
134) -> SequenceArg<'a, R, FastDivmod> {
135    shape
136        .iter()
137        .map(|s| FastDivmodArgs::new(client, *s as u32))
138        .collect()
139}