// cubek_convolution/components/global/read/reader/strategy/async_full_cyclic.rs

use std::marker::PhantomData;

use cubecl::prelude::*;
use cubecl::std::tensor::layout::{Layout, LayoutExpand};
use cubecl::{ir::DeviceProperties, prelude::barrier::Barrier};
use cubek_matmul::components::{
    global::{
        GlobalReaderConfig, PlaneFlowPartition,
        memory::GlobalIterator,
        multi_stage::LoadMaxRoundPlaneCount,
        read::{
            LoadingJob, LoadingValidation, ReaderMode, async_barrier::AsyncCopy,
            async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
        },
    },
    stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
};
use cubek_matmul::definition::{InvalidConfigError, MatmulElems, MatmulProblem, StageIdent};

use crate::components::global::{
    args::RuntimeArgs,
    read::{
        full_reader::FullLoadingStrategy,
        strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
    },
};
27
/// Marker type for the asynchronous full-stage cyclic loading strategy.
///
/// Carries no runtime data: the tiling order `T` is tracked purely at the
/// type level via `PhantomData`. Behavior lives in the trait impls below,
/// several of which delegate to the matmul `AsyncFullCyclicLoading`
/// (aliased here as `MatmulCyclicLoading`).
#[derive(CubeType, Clone, Copy)]
pub struct AsyncFullCyclicLoading<T: TilingOrder> {
    // Zero-sized, comptime-only handle on the tiling order type.
    #[cube(comptime)]
    _t: PhantomData<T>,
}
35
36impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
37 fn validate_with_config(
38 device_props: &DeviceProperties,
39 config: &GlobalReaderConfig,
40 ) -> Result<(), InvalidConfigError> {
41 MatmulCyclicLoading::<TO>::validate_with_config(device_props, config)
42 }
43
44 fn validate_with_problem(
45 _problem: &MatmulProblem,
46 _dtypes: &MatmulElems,
47 _ident: StageIdent,
48 ) -> Result<(), InvalidConfigError> {
49 Ok(())
50 }
51}
52
53impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
54 fn max_round_plane_count(
55 elements_per_tile: u32,
56 tiles_per_stage: u32,
57 line_size: LineSize,
58 plane_dim: u32,
59 dtype: StorageType,
60 ) -> u32 {
61 MatmulCyclicLoading::<TO>::max_round_plane_count(
62 elements_per_tile,
63 tiles_per_stage,
64 line_size,
65 plane_dim,
66 dtype,
67 )
68 }
69}
70
#[cube]
impl<TO: TilingOrder> FullLoadingStrategy for AsyncFullCyclicLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = AsyncCopy;
    type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;

    /// Precomputes the (mostly comptime) work distribution for one loading
    /// unit and packages it into an [`AsyncFullCyclicJob`].
    ///
    /// Note: the caller-provided `_line_size` is ignored; the copy line size
    /// is instead derived from `ASYNC_COPY_WIDTH` and the stage element type
    /// (presumably both measured in bits — TODO confirm against the
    /// `async_copy` module).
    fn new_job<EG: Numeric, ES: Numeric>(
        runtime_args: RuntimeArgs,
        #[comptime] _line_size: LineSize,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        // Elements moved per async copy = copy width / element width (bits).
        let type_size = ES::type_size_bits().comptime();
        let line_size = ASYNC_COPY_WIDTH / type_size as u32;
        let tile_num_elements = config.smem_config.elements_per_tile();
        let num_stage_elements = config.smem_config.elements_per_stage();

        // Stage lines are dealt out cyclically across all loading units.
        let num_stage_lines = num_stage_elements.div_ceil(line_size);
        let total_units = config.loading_units_count();
        let num_tasks_per_unit = num_stage_lines.div_ceil(total_units);
        // If lines divide evenly among units, no runtime tail guard is needed
        // (see `execute_task`).
        let balanced_workload = num_stage_lines.is_multiple_of(total_units);
        // Stride in elements between two consecutive tasks of the same unit.
        let jump_length = total_units * line_size;

        // Linear id of this unit among the planes assigned to this input's
        // load flow, then scaled to an element position.
        let unit_id = PlaneFlowPartition::new(config.plane_flow_config.partition_rule)
            .load_index(config.input_load_flow)
            * config.plane_dim
            + UNIT_POS_X;
        let unit_position_base = unit_id * line_size;

        AsyncFullCyclicJob {
            unit_position_base,
            runtime_args,
            num_tasks_per_unit,
            tile_num_elements,
            jump_length,
            copy_line_size: line_size,
            balanced_workload,
            num_stage_elements,
            reader_mode: config.reader_mode,
        }
    }
}
112
/// Per-unit state for the async full cyclic load, built once by
/// `AsyncFullCyclicLoading::new_job` and consumed task-by-task in
/// `execute_task`. All fields except the two runtime ones are comptime.
#[derive(CubeType, Clone)]
pub struct AsyncFullCyclicJob {
    /// First stage element position handled by this unit (unit id × line size).
    unit_position_base: u32,
    runtime_args: RuntimeArgs,

    /// How many copy tasks each unit performs (ceil of lines / units).
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Elements per tile, used to split a linear position into (tile, offset).
    #[cube(comptime)]
    tile_num_elements: u32,
    /// Element stride between consecutive tasks of the same unit.
    #[cube(comptime)]
    jump_length: u32,
    /// Elements moved per async copy (derived from ASYNC_COPY_WIDTH).
    #[cube(comptime)]
    copy_line_size: u32,
    /// True when stage lines divide evenly among units; enables skipping the
    /// runtime bounds guard on the last task.
    #[cube(comptime)]
    balanced_workload: bool,
    /// Total elements in the stage; upper bound for the tail-task guard.
    #[cube(comptime)]
    num_stage_elements: u32,
    /// Strict vs. relaxed reader mode, chosen by the global reader config.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
133
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
{
    type Stage = StridedStageFamily;

    /// Executes one cyclic copy task: task `task_id` copies the line starting
    /// at `unit_position_base + task_id * jump_length`.
    ///
    /// The barrier argument is unused here; synchronization is presumably
    /// handled by the `AsyncCopy` sync strategy — TODO confirm.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut Shared<Barrier>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        // In Strict mode, or when the workload is balanced, the bounds check
        // can be resolved at comptime and elided entirely; otherwise the last
        // (partial) round of tasks needs a runtime guard against reading past
        // the stage.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            copy_line::<EG, ES, TO>(
                this,
                unit_position,
                global_iter,
                stage,
                &this.runtime_args,
                config,
            );
        } else {
            if unit_position < this.num_stage_elements {
                copy_line::<EG, ES, TO>(
                    this,
                    unit_position,
                    global_iter,
                    stage,
                    &this.runtime_args,
                    config,
                );
            }
        }
    }

    /// Comptime task count, fixed at job creation.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
178
/// Copies a single line from the global tensor view into the shared-memory
/// stage at the position corresponding to `unit_position`.
#[cube]
pub(crate) fn copy_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
    job: &AsyncFullCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<EG>>,
    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
    runtime_args: &RuntimeArgs,
    #[comptime] config: GlobalReaderConfig,
) {
    // Split the linear stage position into a tile index and an offset
    // within that tile.
    let nth_tile = unit_position / job.tile_num_elements;
    let pos_within_tile = unit_position % job.tile_num_elements;

    let layout = TiledLayout::new(config.stage_ident, config.smem_config);
    let view = global_iter.view();

    // Map the linear tile index to (x, y) coordinates according to the
    // tiling order `TO`.
    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);

    // Source position in the global view; destination offset expressed in
    // stage memory lines (element position / stage line size).
    let pos = layout.to_source_pos((tile, pos_within_tile));
    let stage_offset = unit_position / stage.smem.line_size() as u32;

    async_copy_from(
        view,
        pos,
        stage,
        stage_offset,
        runtime_args,
        global_iter.offset(),
        config,
        job.copy_line_size,
    );
}