cubek_convolution/components/global/read/strategy/
async_full_cyclic.rs1use std::marker::PhantomData;
2
3use cubecl::{
4 prelude::*,
5 std::tensor::layout::{Layout, LayoutExpand},
6 {ir::DeviceProperties, prelude::barrier::Barrier},
7};
8use cubek_matmul::components::{
9 global::{
10 GlobalReaderConfig, PlaneFlowPartition,
11 memory::GlobalIterator,
12 multi_stage::LoadMaxRoundPlaneCount,
13 read::{
14 FullLoadingStrategy, LoadingJob, LoadingValidation, ReaderMode,
15 async_barrier::AsyncCopy,
16 async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
17 },
18 },
19 stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
20};
21use cubek_matmul::definition::{MatmulElems, MatmulProblem};
22use cubek_std::{InvalidConfigError, StageIdent, tile::Strided};
23
24use crate::components::global::{
25 args::RuntimeArgs,
26 read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
27};
28
29#[derive(CubeType, Clone, Copy)]
30pub struct AsyncFullCyclicLoading<T: TilingOrder> {
33 #[cube(comptime)]
34 _t: PhantomData<T>,
35}
36
37impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
38 fn validate_with_config(
39 device_props: &DeviceProperties,
40 config: &GlobalReaderConfig,
41 ) -> Result<(), InvalidConfigError> {
42 MatmulCyclicLoading::<TO>::validate_with_config(device_props, config)
43 }
44
45 fn validate_with_problem(
46 problem: &MatmulProblem,
47 dtypes: &MatmulElems,
48 ident: StageIdent,
49 ) -> Result<(), InvalidConfigError> {
50 MatmulCyclicLoading::<TO>::validate_with_problem(problem, dtypes, ident)
51 }
52}
53
54impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
55 fn max_round_plane_count(
56 elements_per_tile: u32,
57 tiles_per_stage: u32,
58 vector_size: VectorSize,
59 plane_dim: u32,
60 dtype: StorageType,
61 ) -> u32 {
62 MatmulCyclicLoading::<TO>::max_round_plane_count(
63 elements_per_tile,
64 tiles_per_stage,
65 vector_size,
66 plane_dim,
67 dtype,
68 )
69 }
70}
71
72#[cube]
73impl<TO: TilingOrder> FullLoadingStrategy<RuntimeArgs> for AsyncFullCyclicLoading<TO> {
74 type TilingLayout = ContiguousTilingLayout<TO>;
75 type SyncStrategy = AsyncCopy;
76 type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size> = AsyncFullCyclicJob;
77 type Stage = StridedStageFamily;
78 type TileKind = Strided;
79
80 fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
81 runtime_args: RuntimeArgs,
82 #[comptime] config: GlobalReaderConfig,
83 ) -> Self::Job<EG, NG, ES, NS> {
84 let type_size = ES::type_size_bits().comptime();
85 let vector_size = ASYNC_COPY_WIDTH / type_size as u32;
86 let tile_num_elements = config.smem_config.elements_per_tile();
87 let num_stage_elements = config.smem_config.elements_per_stage();
88
89 let num_stage_vectors = num_stage_elements.div_ceil(vector_size);
90 let total_units = config.loading_units_count();
91 let num_tasks_per_unit = num_stage_vectors.div_ceil(total_units);
92 let balanced_workload = num_stage_vectors.is_multiple_of(total_units);
93 let jump_length = total_units * vector_size;
94
95 let unit_id = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
96 .load_index(config.input_load_flow)
97 * config.plane_dim
98 + UNIT_POS_X;
99 let unit_position_base = unit_id * vector_size;
100
101 AsyncFullCyclicJob {
102 unit_position_base,
103 runtime_args,
104 num_tasks_per_unit,
105 tile_num_elements,
106 jump_length,
107 copy_vector_size: vector_size,
108 balanced_workload,
109 num_stage_elements,
110 reader_mode: config.reader_mode,
111 }
112 }
113}
114
115#[derive(CubeType, Clone)]
116pub struct AsyncFullCyclicJob {
117 unit_position_base: u32,
118 runtime_args: RuntimeArgs,
119
120 #[cube(comptime)]
121 num_tasks_per_unit: u32,
122 #[cube(comptime)]
123 tile_num_elements: u32,
124 #[cube(comptime)]
125 jump_length: u32,
126 #[cube(comptime)]
127 copy_vector_size: u32,
128 #[cube(comptime)]
129 balanced_workload: bool,
130 #[cube(comptime)]
131 num_stage_elements: u32,
132 #[cube(comptime)]
133 reader_mode: ReaderMode,
134}
135
136#[cube]
137impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>
138 LoadingJob<EG, NG, ES, NS, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
139{
140 type Stage = StridedStageFamily;
141
142 fn execute_task(
143 this: &mut Self,
144 #[comptime] task_id: u32,
145 global_iter: &GlobalIterator<Vector<EG, NG>>,
146 stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
147 _barrier: &mut Shared<Barrier>,
148 #[comptime] config: GlobalReaderConfig,
149 ) {
150 let unit_position = this.unit_position_base + task_id * this.jump_length;
151
152 #[allow(clippy::collapsible_else_if)]
153 if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
154 copy_vector::<EG, NG, ES, NS, TO>(
155 this,
156 unit_position,
157 global_iter,
158 stage,
159 &this.runtime_args,
160 config,
161 );
162 } else {
163 if unit_position < this.num_stage_elements {
164 copy_vector::<EG, NG, ES, NS, TO>(
165 this,
166 unit_position,
167 global_iter,
168 stage,
169 &this.runtime_args,
170 config,
171 );
172 }
173 }
174 }
175
176 fn task_count(this: &Self) -> comptime_type!(u32) {
177 this.num_tasks_per_unit
178 }
179}
180
181#[cube]
182pub(crate) fn copy_vector<EG: Numeric, NG: Size, ES: Numeric, NS: Size, TO: TilingOrder>(
183 job: &AsyncFullCyclicJob,
184 unit_position: u32,
185 global_iter: &GlobalIterator<Vector<EG, NG>>,
186 stage: &mut StridedStageMemory<ES, NS, ContiguousTilingLayout<TO>>,
187 runtime_args: &RuntimeArgs,
188 #[comptime] config: GlobalReaderConfig,
189) {
190 let nth_tile = unit_position / job.tile_num_elements;
191 let pos_within_tile = unit_position % job.tile_num_elements;
192
193 let layout = TiledLayout::new(config.stage_ident, config.smem_config);
194 let view = global_iter.view();
195
196 let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);
197
198 let pos = layout.to_source_pos((tile, pos_within_tile));
199 let stage_offset = unit_position / stage.smem.vector_size() as u32;
200
201 async_copy_from(
202 view,
203 pos,
204 stage,
205 stage_offset,
206 runtime_args,
207 global_iter.offset(),
208 config,
209 job.copy_vector_size,
210 );
211}