cubek_convolution/components/global/read/strategy/
async_full_cyclic.rs1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl::std::tensor::layout::{Layout, LayoutExpand};
5use cubecl::{ir::DeviceProperties, prelude::barrier::Barrier};
6use cubek_matmul::components::{
7 global::{
8 GlobalReaderConfig, PlaneFlowPartition,
9 memory::GlobalIterator,
10 multi_stage::LoadMaxRoundPlaneCount,
11 read::{
12 FullLoadingStrategy, LoadingJob, LoadingValidation, ReaderMode,
13 async_barrier::AsyncCopy,
14 async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
15 },
16 },
17 stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
18 tile::io::Strided,
19};
20use cubek_matmul::definition::{InvalidConfigError, MatmulElems, MatmulProblem, StageIdent};
21
22use crate::components::global::{
23 args::RuntimeArgs,
24 read::strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25};
26
27#[derive(CubeType, Clone, Copy)]
28pub struct AsyncFullCyclicLoading<T: TilingOrder> {
31 #[cube(comptime)]
32 _t: PhantomData<T>,
33}
34
35impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
36 fn validate_with_config(
37 device_props: &DeviceProperties,
38 config: &GlobalReaderConfig,
39 ) -> Result<(), InvalidConfigError> {
40 MatmulCyclicLoading::<TO>::validate_with_config(device_props, config)
41 }
42
43 fn validate_with_problem(
44 problem: &MatmulProblem,
45 dtypes: &MatmulElems,
46 ident: StageIdent,
47 ) -> Result<(), InvalidConfigError> {
48 MatmulCyclicLoading::<TO>::validate_with_problem(problem, dtypes, ident)
49 }
50}
51
52impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
53 fn max_round_plane_count(
54 elements_per_tile: u32,
55 tiles_per_stage: u32,
56 line_size: LineSize,
57 plane_dim: u32,
58 dtype: StorageType,
59 ) -> u32 {
60 MatmulCyclicLoading::<TO>::max_round_plane_count(
61 elements_per_tile,
62 tiles_per_stage,
63 line_size,
64 plane_dim,
65 dtype,
66 )
67 }
68}
69
70#[cube]
71impl<TO: TilingOrder> FullLoadingStrategy<RuntimeArgs> for AsyncFullCyclicLoading<TO> {
72 type TilingLayout = ContiguousTilingLayout<TO>;
73 type SyncStrategy = AsyncCopy;
74 type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;
75 type Stage = StridedStageFamily;
76 type TileKind = Strided;
77
78 fn new_job<EG: Numeric, ES: Numeric>(
79 runtime_args: RuntimeArgs,
80 #[comptime] _line_size: LineSize,
81 #[comptime] config: GlobalReaderConfig,
82 ) -> Self::Job<EG, ES> {
83 let type_size = ES::type_size_bits().comptime();
84 let line_size = ASYNC_COPY_WIDTH / type_size as u32;
85 let tile_num_elements = config.smem_config.elements_per_tile();
86 let num_stage_elements = config.smem_config.elements_per_stage();
87
88 let num_stage_lines = num_stage_elements.div_ceil(line_size);
89 let total_units = config.loading_units_count();
90 let num_tasks_per_unit = num_stage_lines.div_ceil(total_units);
91 let balanced_workload = num_stage_lines.is_multiple_of(total_units);
92 let jump_length = total_units * line_size;
93
94 let unit_id = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
95 .load_index(config.input_load_flow)
96 * config.plane_dim
97 + UNIT_POS_X;
98 let unit_position_base = unit_id * line_size;
99
100 AsyncFullCyclicJob {
101 unit_position_base,
102 runtime_args,
103 num_tasks_per_unit,
104 tile_num_elements,
105 jump_length,
106 copy_line_size: line_size,
107 balanced_workload,
108 num_stage_elements,
109 reader_mode: config.reader_mode,
110 }
111 }
112}
113
114#[derive(CubeType, Clone)]
115pub struct AsyncFullCyclicJob {
116 unit_position_base: u32,
117 runtime_args: RuntimeArgs,
118
119 #[cube(comptime)]
120 num_tasks_per_unit: u32,
121 #[cube(comptime)]
122 tile_num_elements: u32,
123 #[cube(comptime)]
124 jump_length: u32,
125 #[cube(comptime)]
126 copy_line_size: u32,
127 #[cube(comptime)]
128 balanced_workload: bool,
129 #[cube(comptime)]
130 num_stage_elements: u32,
131 #[cube(comptime)]
132 reader_mode: ReaderMode,
133}
134
135#[cube]
136impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
137 LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
138{
139 type Stage = StridedStageFamily;
140
141 fn execute_task(
142 this: &mut Self,
143 #[comptime] task_id: u32,
144 global_iter: &GlobalIterator<Line<EG>>,
145 stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
146 _barrier: &mut Shared<Barrier>,
147 #[comptime] config: GlobalReaderConfig,
148 ) {
149 let unit_position = this.unit_position_base + task_id * this.jump_length;
150
151 #[allow(clippy::collapsible_else_if)]
152 if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
153 copy_line::<EG, ES, TO>(
154 this,
155 unit_position,
156 global_iter,
157 stage,
158 &this.runtime_args,
159 config,
160 );
161 } else {
162 if unit_position < this.num_stage_elements {
163 copy_line::<EG, ES, TO>(
164 this,
165 unit_position,
166 global_iter,
167 stage,
168 &this.runtime_args,
169 config,
170 );
171 }
172 }
173 }
174
175 fn task_count(this: &Self) -> comptime_type!(u32) {
176 this.num_tasks_per_unit
177 }
178}
179
180#[cube]
181pub(crate) fn copy_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
182 job: &AsyncFullCyclicJob,
183 unit_position: u32,
184 global_iter: &GlobalIterator<Line<EG>>,
185 stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
186 runtime_args: &RuntimeArgs,
187 #[comptime] config: GlobalReaderConfig,
188) {
189 let nth_tile = unit_position / job.tile_num_elements;
190 let pos_within_tile = unit_position % job.tile_num_elements;
191
192 let layout = TiledLayout::new(config.stage_ident, config.smem_config);
193 let view = global_iter.view();
194
195 let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);
196
197 let pos = layout.to_source_pos((tile, pos_within_tile));
198 let stage_offset = unit_position / stage.smem.line_size() as u32;
199
200 async_copy_from(
201 view,
202 pos,
203 stage,
204 stage_offset,
205 runtime_args,
206 global_iter.offset(),
207 config,
208 job.copy_line_size,
209 );
210}