cubecl_convolution/components/global/read/reader/strategy/
async_full_cyclic.rs1use std::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_matmul::components::{
6 InvalidConfigError, MatmulElems, MatmulProblem,
7 global::{
8 GlobalReaderConfig, RoleRule,
9 memory::GlobalIterator,
10 multi_stage::LoadMaxRoundPlaneCount,
11 read::{
12 LoadingJob, LoadingValidation, ReaderMode, async_barrier::AsyncCopy,
13 async_full_cyclic::AsyncFullCyclicLoading as MatmulCyclicLoading, tiled::TiledLayout,
14 },
15 },
16 stage::{ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder},
17};
18use cubecl_std::tensor::layout::{Layout, LayoutExpand};
19
20use crate::components::global::{
21 args::RuntimeArgs,
22 read::{
23 full_reader::FullLoadingStrategy,
24 strategy::async_copy::{ASYNC_COPY_WIDTH, async_copy_from},
25 },
26};
27
/// Full-stage asynchronous loading strategy that distributes stage lines
/// over loading units in a cyclic (round-robin) pattern. Validation and
/// plane-count limits are delegated to the matmul implementation of the
/// same strategy (`MatmulCyclicLoading`).
#[derive(CubeType, Clone, Copy)]
pub struct AsyncFullCyclicLoading<T: TilingOrder> {
    // Zero-sized marker carrying the tiling order; no runtime data.
    #[cube(comptime)]
    _t: PhantomData<T>,
}
35
36impl<TO: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<TO> {
37 fn check<R: Runtime>(
38 client: &ComputeClient<R>,
39 problem: &MatmulProblem,
40 config: &GlobalReaderConfig,
41 dtypes: &MatmulElems,
42 ) -> Result<(), InvalidConfigError> {
43 MatmulCyclicLoading::<TO>::check(client, problem, config, dtypes)
44 }
45}
46
47impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
48 fn max_round_plane_count(
49 elements_per_tile: u32,
50 tiles_per_stage: u32,
51 line_size: u8,
52 plane_dim: u32,
53 dtype: StorageType,
54 ) -> u32 {
55 MatmulCyclicLoading::<TO>::max_round_plane_count(
56 elements_per_tile,
57 tiles_per_stage,
58 line_size,
59 plane_dim,
60 dtype,
61 )
62 }
63}
64
#[cube]
impl<TO: TilingOrder> FullLoadingStrategy for AsyncFullCyclicLoading<TO> {
    // Stage memory is tiled contiguously following the `TO` tiling order.
    type TilingLayout = ContiguousTilingLayout<TO>;
    // Copies are issued through the asynchronous copy mechanism.
    type SyncStrategy = AsyncCopy;
    type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;

    /// Builds the per-unit job for loading a full stage: each loading unit
    /// starts at a distinct line and strides by `total_units * line_size`
    /// elements across the stage (cyclic assignment of lines to units).
    fn new_job<EG: Numeric, ES: Numeric>(
        runtime_args: RuntimeArgs,
        #[comptime] _line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        // One async copy moves ASYNC_COPY_WIDTH bits, i.e. this many ES
        // elements per copy; used as the effective line size below (the
        // tensor's own `_line_size` is ignored for the copy granularity).
        let type_size = ES::type_size_bits();
        let line_size = comptime![ASYNC_COPY_WIDTH / type_size];
        let tile_num_elements = config.smem_config.elements_per_tile();
        let num_stage_elements = config.smem_config.elements_per_stage();

        // Spread all stage lines over the available loading units.
        let num_stage_lines = num_stage_elements.div_ceil(line_size);
        let total_units = config.loading_units_count();
        let num_tasks_per_unit = comptime!(num_stage_lines.div_ceil(total_units));
        // When lines divide evenly over units, execute_task can skip the
        // per-task bounds check (see `execute_task`).
        let balanced_workload = comptime!(num_stage_lines.is_multiple_of(total_units));
        let jump_length = comptime!(total_units * line_size);

        // Flat index of this unit among the loading units: the plane's load
        // index (from the role rule) times plane width, plus lane position.
        let unit_id = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * config.plane_dim
            + UNIT_POS_X;
        let unit_position_base = unit_id * line_size;

        AsyncFullCyclicJob {
            unit_position_base,
            runtime_args,
            num_tasks_per_unit,
            tile_num_elements,
            jump_length,
            copy_line_size: line_size,
            balanced_workload,
            num_stage_elements,
            reader_mode: config.reader_mode,
        }
    }
}
106
/// Per-unit state for one cyclic full-stage asynchronous load.
#[derive(CubeType, Clone)]
pub struct AsyncFullCyclicJob {
    // First stage element this unit is responsible for (unit_id * line_size).
    unit_position_base: u32,
    runtime_args: RuntimeArgs,

    // Number of copy tasks each unit executes for one stage.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    // Elements in a single tile; used to split a linear stage position
    // into (tile index, position within tile).
    #[cube(comptime)]
    tile_num_elements: u32,
    // Element stride between two consecutive tasks of the same unit
    // (total loading units * copy line size).
    #[cube(comptime)]
    jump_length: u32,
    // Elements moved per async copy (ASYNC_COPY_WIDTH / element bit size).
    #[cube(comptime)]
    copy_line_size: u32,
    // True when stage lines divide evenly over units; lets execute_task
    // skip the out-of-bounds check.
    #[cube(comptime)]
    balanced_workload: bool,
    // Total elements in the stage; bounds limit for the unbalanced case.
    #[cube(comptime)]
    num_stage_elements: u32,
    // Strict mode also copies without a bounds check; see `execute_task`.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
127
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncCopy> for AsyncFullCyclicJob
{
    type Stage = StridedStageFamily;

    /// Executes one copy task: derives this unit's stage position for
    /// `task_id` and issues a single async line copy. The bounds check is
    /// skipped when the workload is balanced or the reader runs in strict
    /// mode (both resolved at compile time, so the branch is free).
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut Shared<Barrier>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Cyclic assignment: each task advances by the full-unit stride.
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            // Every task is guaranteed in bounds; copy unconditionally.
            copy_line::<EG, ES, TO>(
                this,
                unit_position,
                global_iter,
                stage,
                &this.runtime_args,
                config,
            );
        } else {
            // Unbalanced relaxed mode: the last round may overshoot the
            // stage, so guard against out-of-range positions.
            if unit_position < this.num_stage_elements {
                copy_line::<EG, ES, TO>(
                    this,
                    unit_position,
                    global_iter,
                    stage,
                    &this.runtime_args,
                    config,
                );
            }
        }
    }

    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
172
/// Copies one line-sized chunk from the global view into stage memory.
///
/// `unit_position` is an element offset within the stage; it is split into
/// a tile index plus a position inside that tile, mapped to a source
/// position through the tiled layout, and copied asynchronously.
#[cube]
pub(crate) fn copy_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
    job: &AsyncFullCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<EG>>,
    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
    runtime_args: &RuntimeArgs,
    #[comptime] config: GlobalReaderConfig,
) {
    // Which tile this element falls in, and where inside that tile.
    let nth_tile = unit_position / job.tile_num_elements;
    let pos_within_tile = unit_position % job.tile_num_elements;

    let layout = TiledLayout::new(config.stage_ident, config.smem_config);
    let view = global_iter.view();

    // Convert the linear tile index to (x, y) per the tiling order.
    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, config.smem_config);

    // Source position in the global view for this (tile, offset) pair.
    let pos = layout.to_source_pos((tile, pos_within_tile));
    // Destination offset in stage lines (element index / stage line size).
    // NOTE(review): assumes the stage's line size matches the granularity
    // `async_copy_from` writes at — confirm against that helper.
    let stage_offset = unit_position / stage.smem.line_size();

    async_copy_from(
        view,
        pos,
        stage,
        stage_offset,
        runtime_args,
        global_iter.offset(),
        config,
        job.copy_line_size,
    );
}