cubecl_matmul/components/global/read/strategy/
sync_partial_cyclic.rs1use std::marker::PhantomData;
2
3use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
4use crate::components::global::read::{SyncPartialLoadingStrategy, tiled::TiledLayout};
5use crate::components::global::{GlobalConfig, RoleRule};
6use crate::components::stage::{ContiguousTilingLayout, StridedStage, TilingOrder};
7use crate::components::{InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme};
8use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11
12use super::{LoadingJob, LoadingValidation, ReaderMode};
13
14#[derive(CubeType, Clone, Copy)]
15pub struct SyncPartialCyclicLoading<T: TilingOrder> {
18 #[cube(comptime)]
19 _phantom: PhantomData<T>,
20}
21
22impl<TO: TilingOrder> LoadingValidation for SyncPartialCyclicLoading<TO> {
23 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
24 if let ReaderMode::Strict = config.reader_mode() {
25 let line_size = config.global_line_size(ident);
26 let num_lines_per_tile = config.tiling_scheme().elements_in_tile(ident) / line_size;
27 let num_tiles_in_stage = config.tiling_scheme().tiles_in_stage(ident);
28 let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
29
30 let total_units = config.plane_dim() * config.num_loading_planes(ident);
31 let jump_length = total_units * line_size;
32 let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
33
34 let max_id = total_units - 1;
35 let max_task_id = num_tasks_per_unit - 1;
36 let max_position_base = max_id * line_size;
37 let max_position = max_position_base + max_task_id * jump_length;
38 let num_stage_elements = config.tiling_scheme().elements_in_stage(ident);
39
40 if max_position > num_stage_elements {
41 return Err(Box::new(
42 "Too many data will be loaded, resulting in out-of-bounds",
43 ));
44 }
45 }
46
47 ContiguousTilingLayout::<TO>::check(config.global_memory_config(ident))?;
48
49 Ok(())
50 }
51}
52
53impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialCyclicLoading<TO> {
54 fn max_round_plane_count(
55 tiling_scheme: &TilingScheme,
56 ident: MatmulIdent,
57 line_size: u8,
58 plane_dim: u32,
59 ) -> u32 {
60 let num_lines_per_tile = tiling_scheme.elements_in_tile(ident) / line_size as u32;
61 let num_tiles_in_stage = tiling_scheme.tiles_in_stage(ident);
62 let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
63 total_num_lines.div_ceil(plane_dim)
64 }
65}
66
#[cube]
impl<TO: TilingOrder> SyncPartialLoadingStrategy for SyncPartialCyclicLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type Job<IP: MatrixPrecision> = SyncPartialCyclicJob;

    /// Builds the job that loads stage `stage_index` of the `ident` operand for the
    /// calling unit.
    ///
    /// Lines of the stage are assigned round-robin across all loading units. The unit
    /// with linear id `unit_id` starts at element `unit_id * line_size`, and each
    /// subsequent task advances by `total_units * line_size` elements.
    ///
    /// Every size is derived from the comptime `config`. Only `unit_position_base`
    /// depends on the unit's position.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> SyncPartialCyclicJob {
        let num_stage_elements = config.tiling_scheme().elements_in_stage(ident);

        let tile_size = config.tiling_scheme().elements_in_tile(ident);
        let tile_count_row = config.tiling_scheme().tiles_in_stage_row(ident);
        let tile_count_col = config.tiling_scheme().tiles_in_stage_col(ident);

        let num_lines_per_tile = tile_size / line_size;
        // Every unit of every loading plane participates in loading this stage.
        let total_units = config.plane_dim() * config.num_loading_planes(ident);

        let num_tiles_in_stage = tile_count_row * tile_count_col;
        let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
        // When the lines do not divide evenly, some units' last task lands past the end
        // of the stage; `execute_task` then guards the load (unless the reader is strict).
        let balanced_workload = total_num_lines.is_multiple_of(total_units);
        let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
        // Element stride between two consecutive tasks of the same unit.
        let jump_length = total_units * line_size;

        // Plane index used for loading, as computed by `RoleRule::load_index`.
        let plane_id = RoleRule::new(config.role_rule_config())
            .load_index(ident, config.specialized_loading_sides());
        // Linear unit id across all loading planes.
        let unit_id = plane_id * config.plane_dim() + UNIT_POS_X;
        let unit_position_base = unit_id * line_size;

        SyncPartialCyclicJob {
            unit_position_base,
            num_tasks_per_unit,
            stage_index,
            jump_length,
            num_lines_per_tile,
            ident,
            balanced_workload,
            num_stage_elements,
            reader_mode: comptime!(config.reader_mode()),
        }
    }
}
111
/// Per-unit state for loading one stage with the cyclic partial loading strategy.
///
/// Only `unit_position_base` is a runtime value; every other field is comptime.
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialCyclicJob {
    /// Element offset, within the stage, of this unit's first line (`unit_id * line_size`).
    unit_position_base: u32,

    /// Number of tasks (one line each) every unit runs: `ceil(total_lines / total_units)`.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Index of the stage, among the configured stages, that this job fills.
    #[cube(comptime)]
    stage_index: u32,
    /// Element stride between two consecutive tasks of one unit (`total_units * line_size`).
    #[cube(comptime)]
    jump_length: u32,
    /// Number of lines in a single tile.
    #[cube(comptime)]
    num_lines_per_tile: u32,
    /// Operand being loaded (`Lhs` or `Rhs`; `Out` is rejected when loading).
    #[cube(comptime)]
    ident: MatmulIdent,
    /// True when the lines divide evenly among the units, so no run-time bounds check is needed.
    #[cube(comptime)]
    balanced_workload: bool,
    /// Number of elements in one stage; upper bound checked on guarded loads.
    #[cube(comptime)]
    num_stage_elements: u32,
    /// Reader mode from the config; `ReaderMode::Strict` loads without the bounds check.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
133
#[cube]
impl<IP: MatrixPrecision, TO: TilingOrder> LoadingJob<IP, ContiguousTilingLayout<TO>>
    for SyncPartialCyclicJob
{
    /// Loads the line assigned to this unit for task `task_id` into `stage`.
    ///
    /// The target element offset is `unit_position_base + task_id * jump_length`.
    /// A run-time check against the stage size is emitted only when the workload is
    /// unbalanced and the reader mode is not strict; otherwise the load is unconditional.
    fn execute_task<G: GlobalConfig>(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
        #[comptime] config: G,
    ) {
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        // The outer condition is resolved at comptime (`comptime!`), while the inner one
        // is a run-time branch, so they are kept as separate `if`s rather than an
        // `else if` chain.
        // NOTE(review): presumably the `#[cube]` macro cannot mix the two kinds of
        // condition in one chain — confirm before collapsing.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            load_and_store_line::<IP, TO, G>(this, unit_position, global_iter, stage, config);
        } else {
            if unit_position < this.num_stage_elements {
                load_and_store_line::<IP, TO, G>(this, unit_position, global_iter, stage, config);
            }
        }
    }

    /// Number of tasks each unit executes for this job.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
161
/// Reads one line from global memory and writes it into its slot in the stage.
///
/// `unit_position` is an element offset within the current stage. It is processed as
/// follows:
///
/// 1. It is split into a tile index and an offset within that tile.
/// 2. The tile index is mapped to (row, col) within one stage using the tiling order `TO`.
/// 3. The tile coordinate is shifted by `job.stage_index` stages: along columns for
///    `Lhs`, along rows for `Rhs`.
/// 4. The global line is read at that tile coordinate through a `TiledLayout` view.
/// 5. The line is cast to the stage element type and stored into the destination tile.
///
/// The destination tile is located with `TO::to_nth_tile` over the full multi-stage
/// tile grid, where the tile count along the staged dimension is multiplied by
/// `num_stages`.
///
/// # Panics
///
/// Hits `unreachable!()` at comptime if `job.ident` is `MatmulIdent::Out`.
#[cube]
pub(crate) fn load_and_store_line<IP: MatrixPrecision, TO: TilingOrder, G: GlobalConfig>(
    job: &SyncPartialCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<IP::Global>>,
    stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
    #[comptime] config: G,
) {
    let layout = TiledLayout::new(comptime!(config.global_memory_config(job.ident)));
    let view = global_iter.view().view(layout);

    let (tile_size, tile_count_row, tile_count_col) = comptime! {
        (
            config.tiling_scheme().elements_in_tile(job.ident),
            config.tiling_scheme().tiles_in_stage_row(job.ident),
            config.tiling_scheme().tiles_in_stage_col(job.ident),
        )
    };
    let line_size = view.line_size();

    // Which tile of the stage this line belongs to, and the element offset inside it.
    let tile_index = unit_position / tile_size;
    let pos_within_tile = unit_position % tile_size;

    // Tile grid spanning all stages: the staged dimension is `num_stages` times larger
    // (columns for Lhs, rows for Rhs).
    let (total_tile_count_row, total_tile_count_col) = match comptime!(job.ident) {
        MatmulIdent::Lhs => (
            comptime!(tile_count_row),
            comptime!(tile_count_col * config.num_stages(MatmulIdent::Lhs)),
        ),
        MatmulIdent::Rhs => (
            comptime!(tile_count_row * config.num_stages(MatmulIdent::Rhs)),
            comptime!(tile_count_col),
        ),
        MatmulIdent::Out => comptime!(unreachable!()),
    };

    // (row, col) of the tile within a single stage.
    let (tile_x_within_stage, tile_y_within_stage) = TO::to_row_col(
        tile_index,
        tile_count_row,
        tile_count_col,
        comptime!(config.stage_memory_config(job.ident)),
    );

    // Offset the tile coordinate to the stage this job fills.
    let tile = match comptime!(job.ident) {
        MatmulIdent::Lhs => (
            tile_x_within_stage,
            job.stage_index * tile_count_col + tile_y_within_stage,
        ),
        MatmulIdent::Rhs => (
            job.stage_index * tile_count_row + tile_x_within_stage,
            tile_y_within_stage,
        ),
        MatmulIdent::Out => comptime!(unreachable!()),
    };

    // NOTE(review): `read_checked` is presumably bounds-checked against the global view — confirm.
    let line_read = view.read_checked((tile, pos_within_tile));

    // Linear index of the destination tile within the full multi-stage tile grid.
    let nth_tile_in_stage = TO::to_nth_tile(
        tile,
        total_tile_count_row,
        total_tile_count_col,
        comptime!(config.stage_memory_config(job.ident)),
    );

    // The destination tile occupies a contiguous run of `num_lines_per_tile` lines.
    let tile_start = nth_tile_in_stage * job.num_lines_per_tile;
    let tile_end = tile_start + job.num_lines_per_tile;
    let mut tile_slice = stage
        .as_slice_mut(line_size)
        .slice_mut(tile_start, tile_end);

    // `pos_within_tile` is in elements; convert it to a line index within the tile.
    tile_slice[pos_within_tile / line_size] = Line::cast_from(line_read);
}