// cubecl_matmul/components/global/read/strategy/sync_partial_tilewise.rs

use std::marker::PhantomData;

use crate::components::MatmulElems;
use crate::components::global::GlobalReaderConfig;
use crate::components::global::read::validate_swizzle_atom_size;
use crate::components::global::read::{PartialLoadingStrategy, sync::Synchronous};
use crate::components::global::{RoleRule, read::tiled::TiledLayout};
use crate::components::stage::StridedStageFamily;
use crate::components::stage::StridedStageMemory;
use crate::components::stage::TilingOrderEnum;
use crate::components::{FormattedConfigError, InvalidConfigError, StageIdent};
use crate::components::{
    global::memory::GlobalIterator,
    stage::{ContiguousTilingLayout, TilingOrder},
};
use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{tensor::layout::Coords2d, type_size};

use super::{LoadingJob, LoadingValidation};

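/// Tilewise partial loading: each loading plane owns a contiguous run of whole
/// tiles, taken in the stage's tiling order `T`. Within its run, a plane's
/// units cover the tiles' lines round-robin with stride `plane_dim`.
///
/// Worked example (numbers are illustrative, not defaults): with 8 tiles per
/// stage, 2 loading planes, 32 lines per tile and `plane_dim = 32`, plane 0
/// copies tiles 0..4 and plane 1 copies tiles 4..8. Each plane covers
/// `4 * 32 = 128` lines, so each unit loads `128 / 32 = 4` lines, at positions
/// `task_id * plane_dim + UNIT_POS_X` within the plane's run.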
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialTilewiseLoading<T: TilingOrder> {
    #[cube(comptime)]
    tiling_order: PhantomData<T>,
}

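/// Each plane must own at least one whole tile, so at most `tiles_per_stage`
/// planes can usefully take part in a loading round.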
impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialTilewiseLoading<TO> {
    fn max_round_plane_count(
        _elements_per_tile: u32,
        tiles_per_stage: u32,
        _line_size: u8,
        _plane_dim: u32,
    ) -> u32 {
        tiles_per_stage
    }
}

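// Validates the divisibility constraints assumed by `new_job` and
// `execute_task`, and restricts the tiling order to RowMajor for Lhs and
// ColMajor for Rhs.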
impl<T: TilingOrder> LoadingValidation for SyncPartialTilewiseLoading<T> {
    fn check<R: Runtime>(
        _client: &ComputeClient<R::Server>,
        config: &GlobalReaderConfig,
        dtypes: &MatmulElems,
    ) -> Result<(), InvalidConfigError> {
        let line_size = config.gmem_config.line_size;
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();

        if !num_tiles.is_multiple_of(num_planes) {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading."
                )
            }));
        }

        let num_tiles_per_plane = num_tiles / num_planes;
        let num_lines_per_tile = config.smem_config.elements_per_tile() / line_size;
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let plane_dim = config.plane_dim;

        if !num_lines_per_plane.is_multiple_of(plane_dim) {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading."
                )
            }));
        }

        match config.stage_ident {
            StageIdent::Lhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::RowMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Lhs is only supported with RowMajor tiling order"
                            .to_string()
                    }));
                }
            }
            StageIdent::Rhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::ColMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Rhs is only supported with ColMajor tiling order"
                            .to_string()
                    }));
                }
            }
            _ => unreachable!(),
        }

        validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
        ContiguousTilingLayout::<T>::check(config.smem_config)?;

        Ok(())
    }
}

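// Job construction: every quantity below is comptime except
// `num_tiles_to_skip`, which depends on the plane's runtime load index.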
#[cube]
impl<TO: TilingOrder> PartialLoadingStrategy for SyncPartialTilewiseLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = Synchronous;
    type Stage = StridedStageFamily;

    type Job<EG: Numeric, ES: Numeric> = SyncPartialTilewiseJob;

    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] stage_index: u32,
        #[comptime] line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> SyncPartialTilewiseJob {
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();
        let plane_dim = config.plane_dim;

        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let num_lines_per_unit = num_lines_per_plane / plane_dim;

        // Partial stages are offset along the k dimension: columns for Lhs,
        // rows for Rhs.
        let stage_width = comptime!(match config.stage_ident {
            StageIdent::Lhs => config.smem_config.tiles_per_stage_along_col(),
            StageIdent::Rhs => config.smem_config.tiles_per_stage_along_row(),
            _ => unreachable!(),
        });

        // Skip the tiles owned by planes with a lower load index.
        let num_tiles_to_skip = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * num_tiles_per_plane;

        SyncPartialTilewiseJob {
            stage_index,
            num_tiles_to_skip,
            stage_width,
            num_lines_per_tile,
            num_lines_per_unit,
            plane_dim,
            line_size,
        }
    }
}

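/// A loading job for one unit: `num_lines_per_unit` tasks, each copying one
/// line from global memory into the stage at a position derived from the
/// plane's tile range and the unit's position within the plane.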
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialTilewiseJob {
    num_tiles_to_skip: u32,
    stage_index: u32,

    #[cube(comptime)]
    stage_width: u32,
    #[cube(comptime)]
    num_lines_per_tile: u32,
    #[cube(comptime)]
    num_lines_per_unit: u32,
    #[cube(comptime)]
    plane_dim: u32,
    #[cube(comptime)]
    line_size: u32,
}

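// Each task copies a single line; `task_id` advances the unit's position
// through its plane's run of tiles with stride `plane_dim`.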
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncPartialTilewiseJob
{
    type Stage = StridedStageFamily;

    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        let mut stage = stage.with_buffer_index(this.stage_index);

        // This unit's position among all lines owned by its plane.
        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
        let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
        let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;

        let nth_tile_global = this.num_tiles_to_skip + nth_tile_for_this_plane;

        let tile = TO::to_row_col(
            nth_tile_global,
            config.smem_config.tiles_per_stage_along_row(),
            config.smem_config.tiles_per_stage_along_col(),
            config.smem_config,
        );

        // Offset the tile coordinates to the current stage: along columns for
        // Lhs, along rows for Rhs.
        let tile = match comptime![config.stage_ident] {
            StageIdent::Lhs => (tile.0, tile.1 + this.stage_index * this.stage_width),
            StageIdent::Rhs => (tile.0 + this.stage_index * this.stage_width, tile.1),
            _ => tile,
        };

        let num_lines_to_skip_global = nth_tile_global * this.num_lines_per_tile;

        SyncPartialTilewiseJob::load_and_store_line::<EG, ES, TO>(
            this,
            tile,
            line_index_within_tile,
            num_lines_to_skip_global,
            global_iter,
            &mut stage,
            config,
        );
    }

    fn task_count(this: &Self) -> comptime_type!(u32) {
        comptime!(this.num_lines_per_unit)
    }
}

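// Helper: reads one line from the global view (bounds-checked) and writes it,
// cast to the stage element type, at the swizzled stage offset.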
#[cube]
impl SyncPartialTilewiseJob {
    #[allow(clippy::too_many_arguments)]
    fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
        this: &Self,
        tile: Coords2d,
        line_index_within_tile: u32,
        num_lines_to_skip_global: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let layout = TiledLayout::new(comptime!(config.smem_config));
        let view = global_iter.view().view(layout);

        // Bounds-checked read of one line from global memory.
        let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));

        // Destination line offset in the stage, adjusted by its swizzle.
        let offset = line_index_within_tile + num_lines_to_skip_global;
        let type_size = type_size::<ES>(this.line_size);
        let offset = stage.swizzle.apply(offset, type_size);

        stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
    }
}