// cubecl_matmul/components/global/read/strategy/sync_partial_tilewise.rs

use std::marker::PhantomData;

use crate::components::global::read::SyncPartialLoadingStrategy;
use crate::components::global::{RoleRule, read::tiled::TiledLayout};
use crate::components::stage::TilingOrderEnum;
use crate::components::{
    FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme,
};
use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
use crate::components::{
    global::{GlobalConfig, memory::GlobalIterator},
    stage::{ContiguousTilingLayout, StridedStage, TilingOrder},
};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::layout::Coords2d;

use super::{LoadingJob, LoadingValidation};

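// Tile-wise partial loading: every tile of the selected stage is written by exactly one
// plane, and each loading plane copies a contiguous run of `num_tiles / num_planes` whole
// tiles, visited in the order given by the `TilingOrder` parameter `T`.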
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialTilewiseLoading<T: TilingOrder> {
    #[cube(comptime)]
    tiling_order: PhantomData<T>,
}

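// At most one plane per tile in the stage can participate usefully, since each tile is
// loaded entirely by a single plane.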
impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialTilewiseLoading<TO> {
    fn max_round_plane_count(
        tiling_scheme: &TilingScheme,
        ident: MatmulIdent,
        _line_size: u8,
        _plane_dim: u32,
    ) -> u32 {
        tiling_scheme.tiles_in_stage(ident)
    }
}

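// Validation: the plane count must divide the tile count, the plane dimension must divide
// the number of lines each plane loads, and the tiling order must match the loaded side
// (RowMajor for Lhs, ColMajor for Rhs).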
impl<T: TilingOrder> LoadingValidation for SyncPartialTilewiseLoading<T> {
    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
        let line_size = config.global_line_size(ident);
        let num_planes = config.num_loading_planes(ident);
        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);

        if !num_tiles.is_multiple_of(num_planes) {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading."
                )
            }));
        }

        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile =
            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let plane_dim = config.plane_dim();

        if num_lines_per_plane % plane_dim != 0 {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading."
                )
            }));
        }

        match ident {
            MatmulIdent::Lhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::RowMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Lhs is only supported with RowMajor tiling order"
                            .to_string()
                    }));
                }
            }
            MatmulIdent::Rhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::ColMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Rhs is only supported with ColMajor tiling order"
                            .to_string()
                    }));
                }
            }
            MatmulIdent::Out => unreachable!(),
        }

        ContiguousTilingLayout::<T>::check(config.global_memory_config(ident))?;

        Ok(())
    }
}

#[cube]
impl<TO: TilingOrder> SyncPartialLoadingStrategy for SyncPartialTilewiseLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type Job<IP: MatrixPrecision> = SyncPartialTilewiseJob;

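    /// Builds the per-plane job: determines how many tiles this plane skips within the
    /// interleaved multi-stage tile grid and how many lines each unit will copy.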
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> SyncPartialTilewiseJob {
        let num_planes = config.num_loading_planes(ident);
        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
        let plane_dim = config.plane_dim();

        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile =
            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let num_lines_per_unit = num_lines_per_plane / plane_dim;

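        // Stage memory interleaves the tiles of all stages: `stage_width` is the number of
        // tiles one stage contributes to a row (Lhs) or column (Rhs), so stepping to the next
        // row/column of the full grid skips `num_stages * stage_width` tiles.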
        let num_stages = config.num_stages(ident);
        let stage_width = comptime!(match ident {
            MatmulIdent::Lhs => config.tiling_scheme().tiles_in_stage_col(ident),
            MatmulIdent::Rhs => config.tiling_scheme().tiles_in_stage_row(ident),
            MatmulIdent::Out => unreachable!(),
        });
        let row_col_stride = num_stages * stage_width;
        let stage_offset = stage_width * stage_index;

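        // Convert this plane's first tile from stage-local numbering to the interleaved grid:
        // whole rows/columns advance by `row_col_stride`, the remainder stays within the
        // current row/column, and `stage_offset` selects the stage being loaded.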
        let starting_tile_within_stage = RoleRule::new(config.role_rule_config())
            .load_index(ident, config.specialized_loading_sides())
            * num_tiles_per_plane;
        let row_col_index = starting_tile_within_stage / stage_width;
        let inner_offset = starting_tile_within_stage % stage_width;
        let num_tiles_to_skip = row_col_index * row_col_stride + inner_offset + stage_offset;

        SyncPartialTilewiseJob {
            num_tiles_to_skip,
            row_col_stride,
            stage_width,
            num_lines_per_tile,
            num_lines_per_unit,
            plane_dim: config.plane_dim(),
            line_size,
            ident,
        }
    }
}

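/// Per-plane loading job. `num_tiles_to_skip` is the only runtime value; the remaining
/// fields are comptime parameters baked into the generated kernel.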
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialTilewiseJob {
    num_tiles_to_skip: u32,

    #[cube(comptime)]
    row_col_stride: u32,
    #[cube(comptime)]
    stage_width: u32,
    #[cube(comptime)]
    num_lines_per_tile: u32,
    #[cube(comptime)]
    num_lines_per_unit: u32,
    #[cube(comptime)]
    plane_dim: u32,
    #[cube(comptime)]
    line_size: u32,
    #[cube(comptime)]
    ident: MatmulIdent,
}

#[cube]
impl<IP: MatrixPrecision, TO: TilingOrder> LoadingJob<IP, ContiguousTilingLayout<TO>>
    for SyncPartialTilewiseJob
{
    fn execute_task<G: GlobalConfig>(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
        #[comptime] config: G,
    ) {
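        // Each task copies one line per unit: task `task_id` covers this plane's line
        // positions [task_id * plane_dim, (task_id + 1) * plane_dim), split into a tile
        // index and a line index within that tile.
        //
        // Illustrative example (assumed numbers, not taken from a real config): with
        // num_lines_per_tile = 8 and plane_dim = 32, unit 5 of task 1 gets
        // pos_across_tiles = 37, i.e. tile 4 of this plane's share and line 5 within it
        // (37 / 8 = 4, 37 % 8 = 5).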
        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
        let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
        let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;

        let row_col_index_local = nth_tile_for_this_plane / this.stage_width;
        let inner_offset = nth_tile_for_this_plane % this.stage_width;
        let num_tiles_to_skip_local = row_col_index_local * this.row_col_stride + inner_offset;
        let nth_tile_global = this.num_tiles_to_skip + num_tiles_to_skip_local;

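        // The tile grid seen by the tiling order spans all stages: the k dimension is scaled
        // by the number of stages while the m/n dimension stays per-stage.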
        let (total_tile_count_row, total_tile_count_col) = match comptime!(this.ident) {
            MatmulIdent::Lhs => (
                comptime!(config.tiling_scheme().tiles_in_stage_m()),
                comptime!(
                    config.tiling_scheme().tiles_in_stage_k() * config.num_stages(MatmulIdent::Lhs)
                ),
            ),
            MatmulIdent::Rhs => (
                comptime!(
                    config.tiling_scheme().tiles_in_stage_k() * config.num_stages(MatmulIdent::Rhs)
                ),
                comptime!(config.tiling_scheme().tiles_in_stage_n()),
            ),
            MatmulIdent::Out => comptime!(unreachable!()),
        };

        let tile = TO::to_row_col(
            nth_tile_global,
            total_tile_count_row,
            total_tile_count_col,
            comptime!(config.stage_memory_config(this.ident)),
        );

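        // Offset, in lines, of the destination tile inside the contiguous stage buffer.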
        let num_lines_to_skip_global = nth_tile_global * this.num_lines_per_tile;

        SyncPartialTilewiseJob::load_and_store_line::<IP, TO, G>(
            this,
            tile,
            line_index_within_tile,
            num_lines_to_skip_global,
            global_iter,
            stage,
            config,
        );
    }

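    /// Each unit copies one line per task, so the task count is the number of lines per unit.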
    fn task_count(this: &Self) -> comptime_type!(u32) {
        comptime!(this.num_lines_per_unit)
    }
}

#[cube]
impl SyncPartialTilewiseJob {
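    /// Reads one bounds-checked line from the global tensor through the tiled view and
    /// writes it, cast to the stage precision, at its position in stage memory.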
    #[allow(clippy::too_many_arguments)]
    fn load_and_store_line<IP: MatrixPrecision, TO: TilingOrder, G: GlobalConfig>(
        this: &Self,
        tile: Coords2d,
        line_index_within_tile: u32,
        num_lines_to_skip_global: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
        #[comptime] config: G,
    ) {
        let layout = TiledLayout::new(comptime!(config.global_memory_config(this.ident)));
        let view = global_iter.view().view(layout);

        let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));

        let offset = line_index_within_tile + num_lines_to_skip_global;

        stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
    }
}