// cubecl_matmul/components/global/read/strategy/sync_partial_cyclic.rs
use std::marker::PhantomData;
2
3use crate::components::MatmulElems;
4use crate::components::global::read::validate_swizzle_atom_size;
5use crate::components::global::read::{PartialLoadingStrategy, tiled::TiledLayout};
6use crate::components::global::{GlobalReaderConfig, RoleRule};
7use crate::components::global::{multi_stage::LoadMaxRoundPlaneCount, read::sync::Synchronous};
8use crate::components::stage::StridedStageFamily;
9use crate::components::stage::StridedStageMemory;
10use crate::components::stage::{ContiguousTilingLayout, TilingOrder};
11use crate::components::{InvalidConfigError, StageIdent};
12use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
13use cubecl_core as cubecl;
14use cubecl_core::prelude::*;
15use cubecl_std::type_size;
16
17use super::{LoadingJob, LoadingValidation, ReaderMode};
18
/// Synchronous partial-stage loading strategy that distributes stage lines
/// cyclically over all loading units: a unit starts at `unit_id * line_size`
/// and advances by `total_units * line_size` per task (see `new_job`).
/// `T` selects the tiling order used to map flat tile indices to (row, col).
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialCyclicLoading<T: TilingOrder> {
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
26
impl<TO: TilingOrder> LoadingValidation for SyncPartialCyclicLoading<TO> {
    /// Validates the reader configuration for this strategy.
    ///
    /// In `ReaderMode::Strict`, units store without a runtime bounds guard
    /// (see `execute_task`), so the furthest write position any unit can
    /// reach — the last unit on its last task — must stay inside the stage.
    /// Swizzle and tiling-layout constraints are checked in all modes.
    fn check<R: Runtime>(
        _client: &ComputeClient<R::Server>,
        config: &GlobalReaderConfig,
        dtypes: &MatmulElems,
    ) -> Result<(), InvalidConfigError> {
        if let ReaderMode::Strict = config.reader_mode {
            let line_size = config.gmem_config.line_size;
            let num_lines_per_tile = config.smem_config.elements_per_tile() / line_size;
            let num_tiles_in_stage = config.smem_config.tiles_per_stage();
            let total_num_lines = num_tiles_in_stage * num_lines_per_tile;

            let total_units = config.loading_units_count();
            // Cyclic stride: each task moves a unit forward by one full round
            // of all units, measured in elements.
            let jump_length = total_units * line_size;
            let num_tasks_per_unit = total_num_lines.div_ceil(total_units);

            // Furthest element position (line start) touched by any unit.
            let max_id = total_units - 1;
            let max_task_id = num_tasks_per_unit - 1;
            let max_position_base = max_id * line_size;
            let max_position = max_position_base + max_task_id * jump_length;
            let num_stage_elements = config.smem_config.elements_per_stage();

            // NOTE(review): `max_position` is the START of the last line; this
            // accepts `max_position == num_stage_elements - line_size` (last
            // valid line) and rejects anything strictly past the stage end —
            // relies on positions and stage size being multiples of
            // `line_size`. TODO confirm that invariant holds in smem_config.
            if max_position > num_stage_elements {
                return Err(Box::new(
                    "Too many data will be loaded, resulting in out-of-bounds",
                ));
            }
        }

        validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
        ContiguousTilingLayout::<TO>::check(config.smem_config)?;

        Ok(())
    }
}
62
63impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialCyclicLoading<TO> {
64 fn max_round_plane_count(
65 elements_per_tile: u32,
66 tiles_per_stage: u32,
67 line_size: u8,
68 plane_dim: u32,
69 ) -> u32 {
70 let num_lines_per_tile = elements_per_tile / line_size as u32;
71 let total_num_lines = tiles_per_stage * num_lines_per_tile;
72 total_num_lines.div_ceil(plane_dim)
73 }
74}
75
#[cube]
impl<TO: TilingOrder> PartialLoadingStrategy for SyncPartialCyclicLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = Synchronous;
    type Stage = StridedStageFamily;

    type Job<EG: Numeric, ES: Numeric> = SyncPartialCyclicJob;

    /// Builds the per-unit job description. Everything except the unit's
    /// starting position is comptime; the start position depends on the
    /// runtime unit id (`UNIT_POS_X` within the plane).
    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] stage_index: u32,
        #[comptime] line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> SyncPartialCyclicJob {
        let num_stage_elements = config.smem_config.elements_per_stage();

        let tile_size = config.smem_config.elements_per_tile();
        let tile_count_row = config.smem_config.tiles_per_stage_along_row();
        let tile_count_col = config.smem_config.tiles_per_stage_along_col();

        let num_lines_per_tile = tile_size / line_size;
        let total_units = config.loading_units_count();

        let num_tiles_in_stage = tile_count_row * tile_count_col;
        let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
        // When the line count divides evenly over the units, every unit does
        // exactly `num_tasks_per_unit` useful tasks and no bounds guard is
        // needed at store time (see `execute_task`).
        let balanced_workload = total_num_lines.is_multiple_of(total_units);
        let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
        // Distance (in elements) a unit jumps between consecutive tasks.
        let jump_length = total_units * line_size;

        // Flat unit id = plane's load index within its role, scaled by plane
        // width, plus the lane id within the plane.
        let plane_id = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config);
        let unit_id = plane_id * config.plane_dim + UNIT_POS_X;
        let unit_position_base = unit_id * line_size;

        SyncPartialCyclicJob {
            unit_position_base,
            num_tasks_per_unit,
            stage_index,
            jump_length,
            num_lines_per_tile,
            balanced_workload,
            num_stage_elements,
            reader_mode: config.reader_mode,
        }
    }
}
121
/// Per-unit loading job for [`SyncPartialCyclicLoading`]. All fields except
/// `unit_position_base` are comptime constants baked in by `new_job`.
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialCyclicJob {
    /// Runtime start position (in elements) for this unit: `unit_id * line_size`.
    unit_position_base,

    /// Number of tasks each unit executes (total lines / units, rounded up).
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Which stage buffer this job fills (offsets tiles along k in the store).
    #[cube(comptime)]
    stage_index: u32,
    /// Element stride between consecutive tasks of the same unit.
    #[cube(comptime)]
    jump_length: u32,
    /// Lines per tile; converts a tile index to its first line slot.
    #[cube(comptime)]
    num_lines_per_tile: u32,
    /// True when lines divide evenly over units — skip the bounds guard.
    #[cube(comptime)]
    balanced_workload: bool,
    /// Total elements in the stage; bound used by the relaxed-mode guard.
    #[cube(comptime)]
    num_stage_elements: u32,
    /// Strict (validated at config time) vs relaxed (runtime guard) mode.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
141
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncPartialCyclicJob
{
    type Stage = StridedStageFamily;

    /// Executes one cyclic task: loads the line at this unit's position for
    /// `task_id` and stores it into the selected stage buffer.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        let unit_position = this.unit_position_base + task_id * this.jump_length;
        let mut stage = stage.with_buffer_index(this.stage_index);

        // Strict mode was validated up front, and a balanced workload never
        // overruns — either way the guard can be elided at comptime.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            load_and_store_line::<EG, ES, TO>(this, unit_position, global_iter, &mut stage, config);
        } else {
            // Relaxed + unbalanced: trailing units may fall past the stage on
            // their last task, so guard at runtime.
            if unit_position < this.num_stage_elements {
                load_and_store_line::<EG, ES, TO>(
                    this,
                    unit_position,
                    global_iter,
                    &mut stage,
                    config,
                );
            }
        }
    }

    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
179
/// Loads one line from global memory (bounds-checked read) and stores it into
/// the stage at the swizzled slot corresponding to `unit_position`.
///
/// `unit_position` is a flat element offset into the stage; it is split into
/// a tile index and an offset within that tile, and the tile index is mapped
/// to (row, col) by the tiling order `TO`.
#[cube]
pub(crate) fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
    job: &SyncPartialCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<EG>>,
    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
    #[comptime] config: GlobalReaderConfig,
) {
    let layout = TiledLayout::new(comptime!(config.smem_config));
    let view = global_iter.view().view(layout);

    let (tile_size, tile_count_row, tile_count_col) = comptime! {
        (
            config.smem_config.elements_per_tile(),
            config.smem_config.tiles_per_stage_along_row(),
            config.smem_config.tiles_per_stage_along_col(),
        )
    };
    let line_size = view.line_size();

    // Split the flat position into (which tile, where inside it).
    let tile_index = unit_position / tile_size;
    let pos_within_tile = unit_position % tile_size;

    let (tile_x_within_stage, tile_y_within_stage) = TO::to_row_col(
        tile_index,
        tile_count_row,
        tile_count_col,
        comptime!(config.smem_config),
    );

    // Partial loading fills one buffer of a multi-buffer stage: offset the
    // tile coordinate by `stage_index` along the buffered axis — columns for
    // Lhs, rows for Rhs. Other stage idents never use this strategy.
    let tile = match comptime!(config.stage_ident) {
        StageIdent::Lhs => (
            tile_x_within_stage,
            job.stage_index * tile_count_col + tile_y_within_stage,
        ),
        StageIdent::Rhs => (
            job.stage_index * tile_count_row + tile_x_within_stage,
            tile_y_within_stage,
        ),
        _ => comptime!(unreachable!()),
    };

    // Checked read: out-of-bounds global positions yield a safe default.
    let line_read = view.read_checked((tile, pos_within_tile));

    // Destination slot in line units, then swizzled to avoid bank conflicts
    // (exact swizzle semantics live in `stage.swizzle`).
    let tile_start = tile_index * job.num_lines_per_tile;
    let mut tile_slice = stage.as_slice_mut(line_size);
    let offset = tile_start + pos_within_tile / line_size;
    let type_size = type_size::<ES>(line_size);
    let offset = stage.swizzle.apply(offset, type_size);

    tile_slice[offset] = Line::cast_from(line_read);
}