cubek_convolution/components/global/
entry_point.rs1use cubecl;
2use cubecl::prelude::*;
3use cubecl::std::{CubeOption, CubeOptionExpand};
4use cubecl::{Runtime, client::ComputeClient};
5use cubek_matmul::components::batch::SliceIndex;
6use cubek_matmul::components::global::GlobalConfig;
7use cubek_matmul::components::stage::StageConfig as _;
8use cubek_matmul::definition::MatmulElems;
9use cubek_matmul::launch::{InputRuntimeArg, MatmulArgs, OutputRuntimeArg};
10
11use crate::components::global::{GlobalConvolution, GlobalConvolutionFamily};
12use crate::components::{
13 ConvGemmConfig,
14 global::args::{RuntimeArgs, RuntimeArgsLaunch},
15};
16
17type Input<Args, Lhs, Rhs, EO> = <Args as MatmulArgs>::Input<Lhs, Rhs, EO>;
18type Output<Args, EO> = <Args as MatmulArgs>::Output<EO>;
19
20pub trait ConvolutionLaunch<Config> {
22 #[allow(clippy::too_many_arguments)]
28 unsafe fn launch_unchecked<'a, MA: MatmulArgs, R: Runtime>(
29 client: &ComputeClient<R>,
30 cube_dim: CubeDim,
31 cube_count: CubeCount,
32 input: InputRuntimeArg<'a, MA, R>,
33 output: OutputRuntimeArg<'a, MA, R>,
34 runtime_args: RuntimeArgsLaunch<'a, R>,
35 config: Config,
36 dtypes: &MatmulElems,
37 ) -> Result<(), LaunchError>;
38}
39
40#[cube(launch_unchecked)]
41pub(crate) fn implicit_conv<
42 Args: MatmulArgs,
43 LhsG: Numeric,
44 RhsG: Numeric,
45 AccG: Numeric,
46 LhsS: Numeric,
47 RhsS: Numeric,
48 AccS: Numeric,
49 LhsR: Numeric,
50 RhsR: Numeric,
51 AccR: Numeric,
52 GMM: GlobalConvolutionFamily,
53>(
54 inputs: &Input<Args, LhsG, RhsG, AccG>,
55 output: &mut Output<Args, AccG>,
56 runtime_args: RuntimeArgs,
57 #[comptime] config: GMM::Config,
58 #[define(LhsG, RhsG, AccG)] _global: [StorageType; 3],
59 #[define(LhsS, RhsS, AccS)] _stage: [StorageType; 3],
60 #[define(LhsR, RhsR, AccR)] _register: [StorageType; 3],
61) {
62 let mut state = Args::init_state::<LhsG, RhsG, AccG>(
63 inputs,
64 output,
65 config
66 .matmul_config()
67 .lhs_reader_config()
68 .gmem_config
69 .as_global_layout_config(),
70 config
71 .matmul_config()
72 .rhs_reader_config()
73 .gmem_config
74 .as_global_layout_config(),
75 config
76 .matmul_config()
77 .writer_config()
78 .gmem_config
79 .as_global_layout_config(),
80 );
81
82 let lhs = Args::view_lhs(&state);
83 let rhs = Args::view_rhs(&state);
84 let bias = Args::view_acc(&state);
85 let out = Args::view_out(&mut state);
86
87 let stage_m = config
88 .matmul_config()
89 .stage_config()
90 .elements_in_stage_m()
91 .runtime();
92 let stage_n = config
93 .matmul_config()
94 .stage_config()
95 .elements_in_stage_n()
96 .runtime();
97
98 let m_offset = CUBE_POS_X * stage_m;
99 let n_offset = CUBE_POS_Y * stage_n;
100
101 let k_range = (0, runtime_args.shape_k);
102 let k_size = runtime_args.shape_k;
103
104 let lhs = lhs.view(SliceIndex::new(0, lhs.shape()));
105 let rhs = rhs.view(SliceIndex::new(0, rhs.shape()));
106 let bias = match bias {
107 CubeOption::Some(bias) => {
108 let view = bias.view(SliceIndex::new(0, bias.shape()));
109 CubeOption::new_Some(view.slice_unchecked((0, n_offset), (1, stage_n)))
110 }
111 CubeOption::None => CubeOption::new_None(),
112 };
113 let out = out.view_mut(SliceIndex::new(0, out.shape()));
114
115 GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::execute(
116 GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_lhs_global_reader(
117 lhs,
118 (m_offset, k_range.0),
119 (stage_m, k_size),
120 &runtime_args,
121 config,
122 ),
123 GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_rhs_global_reader(
124 rhs.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
125 &runtime_args,
126 config,
127 ),
128 GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_bias_global_reader(
129 bias, config,
130 ),
131 GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_global_writer(
132 out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
133 config,
134 ),
135 &mut GMM::Convolution::<((LhsG, LhsS, LhsR), (RhsG, RhsS, RhsR), (AccG, AccS, AccR))>::init_accumulator(
136 config,
137 ),
138 k_range,
139 config,
140 );
141}