cubecl_matmul/components/global/read/strategy/
sync_full_cyclic.rs1use std::marker::PhantomData;
2
3use crate::components::InvalidConfigError;
4use crate::components::MatmulElems;
5use crate::components::global::read::validate_swizzle_atom_size;
6use crate::components::global::read::{FullLoadingStrategy, tiled::TiledLayout};
7use crate::components::global::{GlobalReaderConfig, RoleRule};
8use crate::components::global::{multi_stage::LoadMaxRoundPlaneCount, read::sync::Synchronous};
9use crate::components::stage::StridedStageFamily;
10use crate::components::stage::{ContiguousTilingLayout, StridedStageMemory, TilingOrder};
11use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
12use cubecl_core as cubecl;
13use cubecl_core::prelude::*;
14
15use super::{LoadingJob, LoadingValidation, ReaderMode};
16
17#[derive(CubeType, Clone, Copy)]
18pub struct SyncFullCyclicLoading<T: TilingOrder> {
21 #[cube(comptime)]
22 _t: PhantomData<T>,
23}
24
25impl<TO: TilingOrder> LoadingValidation for SyncFullCyclicLoading<TO> {
26 fn check<R: Runtime>(
27 _client: &ComputeClient<R::Server>,
28 config: &GlobalReaderConfig,
29 dtypes: &MatmulElems,
30 ) -> Result<(), InvalidConfigError> {
31 if let ReaderMode::Strict = config.reader_mode {
32 let line_size = config.gmem_config.line_size;
33
34 let num_stage_lines = config.smem_config.elements_per_stage() / line_size;
35 let total_units = config.loading_units_count();
36
37 if !num_stage_lines.is_multiple_of(total_units) {
38 return Err(Box::new(
39 "Too many data will be loaded, resulting in out of bounds.
40 Try setting line size and number of planes so that total unit count {:?} divides number of lines in stage.",
41 ));
42 }
43 }
44
45 validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
46 ContiguousTilingLayout::<TO>::check(config.smem_config)?;
47
48 Ok(())
49 }
50}
51
52impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncFullCyclicLoading<TO> {
53 fn max_round_plane_count(
54 elements_per_tile: u32,
55 tiles_per_stage: u32,
56 line_size: u8,
57 plane_dim: u32,
58 ) -> u32 {
59 let elements_per_stage = elements_per_tile * tiles_per_stage;
60 let num_lines = elements_per_stage / line_size as u32;
61 num_lines.div_ceil(plane_dim)
62 }
63}
64
// Job construction for the synchronous full cyclic strategy: the stage's lines are
// dealt out cyclically, so task `t` of the unit with index `u` covers the line that
// starts at stage element `(u + t * total_units) * line_size`.
#[cube]
impl<TO: TilingOrder> FullLoadingStrategy for SyncFullCyclicLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = Synchronous;
    type Job<EG: Numeric, ES: Numeric> = SyncFullCyclicJob;

    /// Creates the loading job for the calling unit.
    ///
    /// `line_size` is the number of elements per line. Every size derived from
    /// `config` here is a comptime value; only the unit's starting position depends
    /// on runtime values (its role-based load index and `UNIT_POS_X`).
    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        let tile_num_elements = config.smem_config.elements_per_tile();
        let num_stage_elements = config.smem_config.elements_per_stage();

        // Rounded up so that a trailing partial line is still assigned to a unit.
        let num_stage_lines = num_stage_elements.div_ceil(line_size);
        let total_units = config.loading_units_count();
        // Upper bound on the lines each unit loads; when the division is uneven,
        // some units have one task that falls outside the stage.
        let num_tasks_per_unit = comptime!(num_stage_lines.div_ceil(total_units));
        // True when every unit loads exactly `num_tasks_per_unit` lines.
        let balanced_workload = comptime!(num_stage_lines.is_multiple_of(total_units));
        // Stage-element distance between two consecutive tasks of the same unit.
        let jump_length = comptime!(total_units * line_size);

        // Linear index of this unit among the loading units.
        // NOTE(review): assumes `load_index` is this plane's index among the loading
        // planes — confirm in `RoleRule`.
        let unit_id = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * config.plane_dim
            + UNIT_POS_X;
        // Stage-element offset of this unit's first line.
        let unit_position_base = unit_id * line_size;

        SyncFullCyclicJob {
            unit_position_base,
            num_tasks_per_unit,
            tile_num_elements,
            jump_length,
            line_size,
            balanced_workload,
            num_stage_elements,
            reader_mode: config.reader_mode,
        }
    }
}
102
103#[derive(CubeType, Clone, Copy)]
104pub struct SyncFullCyclicJob {
105 unit_position_base: u32,
106
107 #[cube(comptime)]
108 num_tasks_per_unit: u32,
109 #[cube(comptime)]
110 tile_num_elements: u32,
111 #[cube(comptime)]
112 jump_length: u32,
113 #[cube(comptime)]
114 line_size: u32,
115 #[cube(comptime)]
116 balanced_workload: bool,
117 #[cube(comptime)]
118 num_stage_elements: u32,
119 #[cube(comptime)]
120 reader_mode: ReaderMode,
121}
122
// Task execution for `SyncFullCyclicJob`: each task moves one line from global
// memory into the stage.
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncFullCyclicJob
{
    type Stage = StridedStageFamily;

    /// Loads the line assigned to this unit for task `task_id` into `stage`.
    ///
    /// A runtime bounds check against the stage size is emitted only when the reader
    /// is not in `ReaderMode::Strict` and the workload is unbalanced; otherwise the
    /// line is loaded unconditionally.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Stage-element offset of the line handled by this task.
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        // The outer condition is evaluated at comptime while the inner one is a runtime
        // check, so they are kept as separate `if`s instead of an `else if`.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            // Balanced, or Strict mode, where `check` requires the lines to divide
            // evenly among units: every task lands inside the stage.
            load_and_store_line::<EG, ES, TO>(this, unit_position, global_iter, stage, config);
        } else {
            // Unbalanced: the last task of some units falls past the end of the stage.
            if unit_position < this.num_stage_elements {
                load_and_store_line::<EG, ES, TO>(this, unit_position, global_iter, stage, config);
            }
        }
    }

    /// Number of tasks each unit executes (a comptime value).
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
153
#[cube]
/// Reads one line from global memory and writes it, cast to `ES`, into the stage.
///
/// `unit_position` is the stage-element offset of the line in contiguous tiling
/// order. It is split into a tile index and an offset within that tile to address
/// the global view. The value is stored in the stage at the swizzled offset of
/// `unit_position`; no explicit bounds check on the stage index is performed here.
pub(crate) fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
    job: &SyncFullCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<EG>>,
    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
    #[comptime] config: GlobalReaderConfig,
) {
    let line_size = job.line_size;
    // Index of the tile containing this line, and the element offset inside it.
    let nth_tile = unit_position / job.tile_num_elements;
    let pos_within_tile = unit_position % job.tile_num_elements;

    // Compose the global iterator's view with a `TiledLayout` so it can be indexed
    // by `(tile, pos_within_tile)`.
    let layout = TiledLayout::new(comptime![config.smem_config]);
    let view = global_iter.view().view(layout);

    // Tile coordinates of `nth_tile` under the tiling order `TO`.
    let tile = ContiguousTilingLayout::<TO>::to_x_y(nth_tile, comptime!(config.smem_config));

    let mut slice = stage.as_slice_mut(line_size);

    // NOTE(review): `read_checked` presumably guards out-of-bounds global positions
    // (e.g. by returning a default line) — confirm in the view implementation.
    let line_read = view.read_checked((tile, pos_within_tile));
    // Swizzled element offset in the stage; divided by the line size below to get
    // the line index in `slice`.
    let stage_offs = stage.swizzle.apply(unit_position, ES::type_size());

    slice[stage_offs / job.line_size] = Line::cast_from(line_read);
}