cubecl_matmul/components/global/read/strategy/
sync_full_tilewise.rs1use std::marker::PhantomData;
2
3use crate::components::MatmulElems;
4use crate::components::global::GlobalReaderConfig;
5use crate::components::global::read::validate_swizzle_atom_size;
6use crate::components::global::read::{FullLoadingStrategy, sync::Synchronous};
7use crate::components::global::{RoleRule, read::tiled::TiledLayout};
8use crate::components::stage::StridedStageFamily;
9use crate::components::stage::{StridedStageMemory, TilingOrder};
10use crate::components::{FormattedConfigError, InvalidConfigError};
11use crate::components::{global::memory::GlobalIterator, stage::ContiguousTilingLayout};
12use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
13use cubecl_core as cubecl;
14use cubecl_core::prelude::*;
15use cubecl_std::{tensor::layout::Coords2d, type_size};
16
17use super::{LoadingJob, LoadingValidation};
18
19#[derive(CubeType, Clone, Copy)]
20pub struct SyncFullTilewiseLoading<T: TilingOrder> {
29 #[cube(comptime)]
30 tiling_order: PhantomData<T>,
31}
32
33impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncFullTilewiseLoading<TO> {
34 fn max_round_plane_count(
35 _elements_per_tile: u32,
36 tiles_per_stage: u32,
37 _line_size: u8,
38 _plane_dim: u32,
39 ) -> u32 {
40 tiles_per_stage
41 }
42}
43
44impl<T: TilingOrder> LoadingValidation for SyncFullTilewiseLoading<T> {
45 fn check<R: Runtime>(
46 _client: &ComputeClient<R::Server>,
47 config: &GlobalReaderConfig,
48
49 dtypes: &MatmulElems,
50 ) -> Result<(), InvalidConfigError> {
51 let line_size = config.gmem_config.line_size;
52 let num_planes = config.loading_planes_count();
53 let num_tiles = config.smem_config.tiles_per_stage();
54
55 if !num_tiles.is_multiple_of(num_planes) {
56 return Err(FormattedConfigError::new(move || {
57 format!(
58 "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading.",
59 )
60 }));
61 }
62
63 let num_tiles_per_plane = comptime!(num_tiles / num_planes);
64 let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
65 let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
66 let plane_dim = config.plane_dim;
67
68 if num_lines_per_plane % plane_dim != 0 {
69 return Err(FormattedConfigError::new(move || {
70 format!(
71 "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading.",
72 )
73 }));
74 }
75
76 validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
77 ContiguousTilingLayout::<T>::check(config.smem_config)?;
78
79 Ok(())
80 }
81}
82
83#[cube]
84impl<TO: TilingOrder> FullLoadingStrategy for SyncFullTilewiseLoading<TO> {
85 type TilingLayout = ContiguousTilingLayout<TO>;
86 type SyncStrategy = Synchronous;
87 type Job<EG: Numeric, ES: Numeric> = SyncFullTilewiseJob;
88
89 fn new_job<EG: Numeric, ES: Numeric>(
90 #[comptime] line_size: u32,
91 #[comptime] config: GlobalReaderConfig,
92 ) -> Self::Job<EG, ES> {
93 let num_planes = config.loading_planes_count();
94 let num_tiles = config.smem_config.tiles_per_stage();
95
96 let num_tiles_per_plane = comptime!(num_tiles / num_planes);
97 let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
98 let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
99 let num_lines_per_unit = num_lines_per_plane / config.plane_dim;
100
101 let num_tiles_to_skip = RoleRule::new(config.plane_role_config.rule)
102 .load_index(config.specialization_tensor_config)
103 * num_tiles_per_plane;
104 let num_lines_to_skip = num_tiles_to_skip * num_lines_per_tile;
105
106 SyncFullTilewiseJob {
107 num_tiles_to_skip,
108 num_lines_to_skip,
109 num_lines_per_tile,
110 num_lines_per_unit,
111 plane_dim: config.plane_dim,
112 line_size,
113 }
114 }
115}
116
117#[derive(CubeType, Clone, Copy)]
118pub struct SyncFullTilewiseJob {
119 pub num_tiles_to_skip: u32,
120 pub num_lines_to_skip: u32,
121
122 #[cube(comptime)]
123 pub num_lines_per_tile: u32,
124 #[cube(comptime)]
125 pub num_lines_per_unit: u32,
126 #[cube(comptime)]
127 pub plane_dim: u32,
128 #[cube(comptime)]
129 pub line_size: u32,
130}
131
132#[cube]
133impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
134 LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncFullTilewiseJob
135{
136 type Stage = StridedStageFamily;
137
138 fn execute_task(
139 this: &mut Self,
140 #[comptime] task_id: u32,
141 global_iter: &GlobalIterator<Line<EG>>,
142 stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
143 _barrier: &mut (),
144 #[comptime] config: GlobalReaderConfig,
145 ) {
146 let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
147 let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
148 let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;
149
150 let nth_tile_global = nth_tile_for_this_plane + this.num_tiles_to_skip;
151 let tile =
152 ContiguousTilingLayout::<TO>::to_x_y(nth_tile_global, comptime!(config.smem_config));
153
154 SyncFullTilewiseJob::load_and_store_line::<EG, ES, TO>(
155 this,
156 tile,
157 line_index_within_tile,
158 nth_tile_for_this_plane * this.num_lines_per_tile,
159 global_iter,
160 stage,
161 config,
162 );
163 }
164
165 fn task_count(this: &Self) -> comptime_type!(u32) {
166 comptime!(this.num_lines_per_unit)
167 }
168}
169
170#[cube]
171impl SyncFullTilewiseJob {
172 #[allow(clippy::too_many_arguments)]
173 fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
174 this: &Self,
175 tile: Coords2d,
176 line_index_within_tile: u32,
177 num_lines_to_skip_local: u32,
178 global_iter: &GlobalIterator<Line<EG>>,
179 stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
180 #[comptime] config: GlobalReaderConfig,
181 ) {
182 let layout = TiledLayout::new(comptime!(config.smem_config));
183 let view = global_iter.view().view(layout);
184
185 let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));
186
187 let offset = this.num_lines_to_skip + line_index_within_tile + num_lines_to_skip_local;
188 let type_size = type_size::<ES>(this.line_size);
189 let offset = stage.swizzle.apply(offset, type_size);
190
191 stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
192 }
193}