// cubecl_matmul/components/global/read/strategy/sync_partial_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::components::MatmulElems;
4use crate::components::global::read::validate_swizzle_atom_size;
5use crate::components::global::read::{PartialLoadingStrategy, tiled::TiledLayout};
6use crate::components::global::{GlobalReaderConfig, RoleRule};
7use crate::components::global::{multi_stage::LoadMaxRoundPlaneCount, read::sync::Synchronous};
8use crate::components::stage::StridedStageFamily;
9use crate::components::stage::StridedStageMemory;
10use crate::components::stage::{ContiguousTilingLayout, TilingOrder};
11use crate::components::{InvalidConfigError, StageIdent};
12use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
13use cubecl_core as cubecl;
14use cubecl_core::prelude::*;
15use cubecl_std::type_size;
16
17use super::{LoadingJob, LoadingValidation, ReaderMode};
18
#[derive(CubeType, Clone, Copy)]
/// Synchronous loading strategy for a partial stage where lines are handed out
/// cyclically to the loading units.
///
/// Loads the content of all tiles in the stage using all planes.
/// Unit with pos X loads lines with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
/// where NUM_UNITS is the value returned by `config.loading_units_count()`.
///
/// The type parameter `T` is the [`TilingOrder`] used for the stage's
/// [`ContiguousTilingLayout`].
pub struct SyncPartialCyclicLoading<T: TilingOrder> {
    // Zero-sized marker that only carries the tiling order type; it holds no data.
    #[cube(comptime)]
    _phantom: PhantomData<T>,
}
26
27impl<TO: TilingOrder> LoadingValidation for SyncPartialCyclicLoading<TO> {
28    fn check<R: Runtime>(
29        _client: &ComputeClient<R::Server>,
30        config: &GlobalReaderConfig,
31        dtypes: &MatmulElems,
32    ) -> Result<(), InvalidConfigError> {
33        if let ReaderMode::Strict = config.reader_mode {
34            let line_size = config.gmem_config.line_size;
35            let num_lines_per_tile = config.smem_config.elements_per_tile() / line_size;
36            let num_tiles_in_stage = config.smem_config.tiles_per_stage();
37            let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
38
39            let total_units = config.loading_units_count();
40            let jump_length = total_units * line_size;
41            let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
42
43            let max_id = total_units - 1;
44            let max_task_id = num_tasks_per_unit - 1;
45            let max_position_base = max_id * line_size;
46            let max_position = max_position_base + max_task_id * jump_length;
47            let num_stage_elements = config.smem_config.elements_per_stage();
48
49            if max_position > num_stage_elements {
50                return Err(Box::new(
51                    "Too many data will be loaded, resulting in out-of-bounds",
52                ));
53            }
54        }
55
56        validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
57        ContiguousTilingLayout::<TO>::check(config.smem_config)?;
58
59        Ok(())
60    }
61}
62
63impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialCyclicLoading<TO> {
64    fn max_round_plane_count(
65        elements_per_tile: u32,
66        tiles_per_stage: u32,
67        line_size: u8,
68        plane_dim: u32,
69    ) -> u32 {
70        let num_lines_per_tile = elements_per_tile / line_size as u32;
71        let total_num_lines = tiles_per_stage * num_lines_per_tile;
72        total_num_lines.div_ceil(plane_dim)
73    }
74}
75
76#[cube]
77impl<TO: TilingOrder> PartialLoadingStrategy for SyncPartialCyclicLoading<TO> {
78    type TilingLayout = ContiguousTilingLayout<TO>;
79    type SyncStrategy = Synchronous;
80    type Stage = StridedStageFamily;
81
82    type Job<EG: Numeric, ES: Numeric> = SyncPartialCyclicJob;
83
84    fn new_job<EG: Numeric, ES: Numeric>(
85        #[comptime] stage_index: u32,
86        #[comptime] line_size: u32,
87        #[comptime] config: GlobalReaderConfig,
88    ) -> SyncPartialCyclicJob {
89        let num_stage_elements = config.smem_config.elements_per_stage();
90
91        let tile_size = config.smem_config.elements_per_tile();
92        let tile_count_row = config.smem_config.tiles_per_stage_along_row();
93        let tile_count_col = config.smem_config.tiles_per_stage_along_col();
94
95        let num_lines_per_tile = tile_size / line_size;
96        let total_units = config.loading_units_count();
97
98        let num_tiles_in_stage = tile_count_row * tile_count_col;
99        let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
100        let balanced_workload = total_num_lines.is_multiple_of(total_units);
101        let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
102        let jump_length = total_units * line_size;
103
104        let plane_id = RoleRule::new(config.plane_role_config.rule)
105            .load_index(config.specialization_tensor_config);
106        let unit_id = plane_id * config.plane_dim + UNIT_POS_X;
107        let unit_position_base = unit_id * line_size;
108
109        SyncPartialCyclicJob {
110            unit_position_base,
111            num_tasks_per_unit,
112            stage_index,
113            jump_length,
114            num_lines_per_tile,
115            balanced_workload,
116            num_stage_elements,
117            reader_mode: config.reader_mode,
118        }
119    }
120}
121
#[derive(CubeType, Clone, Copy)]
/// State of a cyclic loading job for one stage buffer, as built by
/// `SyncPartialCyclicLoading::new_job`.
///
/// Only `unit_position_base` is a runtime value; every other field is comptime.
pub struct SyncPartialCyclicJob {
    /// Position of the first line handled by this unit (`unit_id * line_size`).
    unit_position_base: u32,

    /// Number of tasks each unit executes: total lines in the stage divided by
    /// the number of loading units, rounded up.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Index of the stage buffer this job writes to.
    #[cube(comptime)]
    stage_index: u32,
    /// Distance between the positions of two consecutive tasks of the same
    /// unit (`total_units * line_size`).
    #[cube(comptime)]
    jump_length: u32,
    /// Number of lines in one tile (`elements_per_tile / line_size`).
    #[cube(comptime)]
    num_lines_per_tile: u32,
    /// Whether the total number of lines is a multiple of the number of
    /// loading units.
    #[cube(comptime)]
    balanced_workload: bool,
    /// Number of elements in the stage; upper bound for a valid position.
    #[cube(comptime)]
    num_stage_elements: u32,
    /// Reader mode taken from the config; selects whether tasks are bounds-guarded.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
141
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncPartialCyclicJob
{
    type Stage = StridedStageFamily;

    /// Executes task `task_id` of this job for the current unit.
    ///
    /// The position handled is `unit_position_base + task_id * jump_length`,
    /// and the line is written to the buffer selected by `stage_index`.
    /// The barrier argument is unused.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        let unit_position = this.unit_position_base + task_id * this.jump_length;
        let mut stage = stage.with_buffer_index(this.stage_index);

        // The outer condition is resolved at comptime: in strict mode or with a
        // balanced workload the load is emitted without a guard. Otherwise a
        // runtime check skips positions at or past the end of the stage.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            load_and_store_line::<EG, ES, TO>(this, unit_position, global_iter, &mut stage, config);
        } else {
            if unit_position < this.num_stage_elements {
                load_and_store_line::<EG, ES, TO>(
                    this,
                    unit_position,
                    global_iter,
                    &mut stage,
                    config,
                );
            }
        }
    }

    /// Returns the comptime number of tasks each unit has to execute.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
179
180#[cube]
181pub(crate) fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
182    job: &SyncPartialCyclicJob,
183    unit_position: u32,
184    global_iter: &GlobalIterator<Line<EG>>,
185    stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
186    #[comptime] config: GlobalReaderConfig,
187) {
188    let layout = TiledLayout::new(comptime!(config.smem_config));
189    let view = global_iter.view().view(layout);
190
191    let (tile_size, tile_count_row, tile_count_col) = comptime! {
192        (
193            config.smem_config.elements_per_tile(),
194            config.smem_config.tiles_per_stage_along_row(),
195            config.smem_config.tiles_per_stage_along_col(),
196        )
197    };
198    let line_size = view.line_size();
199
200    let tile_index = unit_position / tile_size;
201    let pos_within_tile = unit_position % tile_size;
202
203    let (tile_x_within_stage, tile_y_within_stage) = TO::to_row_col(
204        tile_index,
205        tile_count_row,
206        tile_count_col,
207        comptime!(config.smem_config),
208    );
209
210    let tile = match comptime!(config.stage_ident) {
211        StageIdent::Lhs => (
212            tile_x_within_stage,
213            job.stage_index * tile_count_col + tile_y_within_stage,
214        ),
215        StageIdent::Rhs => (
216            job.stage_index * tile_count_row + tile_x_within_stage,
217            tile_y_within_stage,
218        ),
219        _ => comptime!(unreachable!()),
220    };
221
222    let line_read = view.read_checked((tile, pos_within_tile));
223
224    let tile_start = tile_index * job.num_lines_per_tile;
225    let mut tile_slice = stage.as_slice_mut(line_size);
226    let offset = tile_start + pos_within_tile / line_size;
227    let type_size = type_size::<ES>(line_size);
228    let offset = stage.swizzle.apply(offset, type_size);
229
230    tile_slice[offset] = Line::cast_from(line_read);
231}