cubek_convolution/components/global/
entry_point.rs

1use cubecl;
2use cubecl::prelude::*;
3use cubecl::std::{CubeOption, CubeOptionExpand};
4use cubecl::{Runtime, client::ComputeClient};
5use cubek_matmul::components::batch::SliceIndex;
6use cubek_matmul::components::global::GlobalConfig;
7use cubek_matmul::components::stage::StageConfig as _;
8use cubek_matmul::definition::MatmulElems;
9use cubek_matmul::launch::{InputRuntimeArg, MatmulArgs, OutputRuntimeArg};
10
11use crate::components::global::{GlobalConvolution, GlobalConvolutionFamily};
12use crate::components::{
13    ConvGemmConfig,
14    global::args::{RuntimeArgs, RuntimeArgsLaunch},
15};
16
17type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
18type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
19
20/// Provides launch entry point to solve a matmul
21pub trait ConvolutionLaunch<Config> {
22    /// Entry point
23    ///
24    /// # Safety
25    ///
26    /// Out-of-bounds can happen
27    #[allow(clippy::too_many_arguments)]
28    unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
29        client: &ComputeClient<R>,
30        cube_dim: CubeDim,
31        cube_count: CubeCount,
32        input: InputRuntimeArg<'a, MA, R>,
33        output: OutputRuntimeArg<'a, MA, R>,
34        runtime_args: RuntimeArgsLaunch<'a, R>,
35        config: Config,
36        dtypes: &MatmulElems,
37    ) -> Result<(), LaunchError>;
38}
39
40#[cube(launch_unchecked)]
41pub(crate) fn implicit_conv<
42    Args: MatmulArgs,
43    LhsG: Numeric,
44    RhsG: Numeric,
45    AccG: Numeric,
46    LhsS: Numeric,
47    RhsS: Numeric,
48    AccS: Numeric,
49    LhsR: Numeric,
50    RhsR: Numeric,
51    AccR: Numeric,
52    GMM: GlobalConvolutionFamily,
53>(
54    inputs: &Input<Args, LhsG, RhsG, AccG>,
55    output: &mut Output<Args, AccG>,
56    runtime_args: RuntimeArgs,
57    #[comptime] config: GMM::Config,
58    #[define(LhsG, RhsG, AccG)] _global: [StorageType; 3],
59    #[define(LhsS, RhsS, AccS)] _stage: [StorageType; 3],
60    #[define(LhsR, RhsR, AccR)] _register: [StorageType; 3],
61) {
62    let mut state = Args::init_state::<LhsG, RhsG, AccG>(
63        inputs,
64        output,
65        config
66            .matmul_config()
67            .lhs_reader_config()
68            .gmem_config
69            .as_global_layout_config(),
70        config
71            .matmul_config()
72            .rhs_reader_config()
73            .gmem_config
74            .as_global_layout_config(),
75        config
76            .matmul_config()
77            .writer_config()
78            .gmem_config
79            .as_global_layout_config(),
80    );
81
82    let lhs = Args::view_lhs(&state);
83    let rhs = Args::view_rhs(&state);
84    let bias = Args::view_acc(&state);
85    let out = Args::view_out(&mut state);
86
87    let stage_m = config
88        .matmul_config()
89        .stage_config()
90        .elements_in_stage_m()
91        .runtime();
92    let stage_n = config
93        .matmul_config()
94        .stage_config()
95        .elements_in_stage_n()
96        .runtime();
97
98    let m_offset = CUBE_POS_X * stage_m;
99    let n_offset = CUBE_POS_Y * stage_n;
100
101    let k_range = (0, runtime_args.shape_k);
102    let k_size = runtime_args.shape_k;
103
104    let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
105    let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
106    let bias = match bias {
107        CubeOption::Some(bias) => {
108            let view = bias.view(SliceIndex::new(0, bias.shape()));
109            CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
110        }
111        CubeOption::None => CubeOption::new_None(),
112    };
113    let out = out.view_mut(SliceIndex::new(0, out.shape()));
114
115    GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::execute(
116        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_lhs_global_reader(
117            lhs,
118            (m_offset, k_range.0),
119            (stage_m, k_size),
120            &runtime_args,
121            config,
122        ),
123        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_rhs_global_reader(
124            rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
125            &runtime_args,
126            config,
127        ),
128        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_bias_global_reader(
129            bias, config,
130        ),
131        GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_global_writer(
132            out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
133            config,
134        ),
135        &mut GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_accumulator(
136            config,
137        ),
138        k_range,
139        config,
140    );
141}