cubecl_matmul/components/global/read/strategy/
sync_full_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
4use crate::components::global::read::{SyncFullLoadingStrategy, tiled::TiledLayout};
5use crate::components::global::{GlobalConfig, RoleRule};
6use crate::components::stage::{ContiguousTilingLayout, StridedStage, TilingOrder};
7use crate::components::{InvalidConfigError, MatmulIdent};
8use crate::components::{MatrixPrecision, TilingScheme};
9use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
10use cubecl_core as cubecl;
11use cubecl_core::prelude::*;
12
13use super::{LoadingJob, LoadingValidation, ReaderMode};
14
15#[derive(CubeType, Clone, Copy)]
16/// Loads the content of all tiles in the stage using all planes.
17/// Unit with pos X loads lines with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
18pub struct SyncFullCyclicLoading<T: TilingOrder> {
19    #[cube(comptime)]
20    _t: PhantomData<T>,
21}
22
23impl<TO: TilingOrder> LoadingValidation for SyncFullCyclicLoading<TO> {
24    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
25        if let ReaderMode::Strict = config.reader_mode() {
26            let line_size = config.global_line_size(ident);
27
28            let num_stage_lines = config.tiling_scheme().elements_in_stage(ident) / line_size;
29            let total_units = config.num_loading_planes(ident) * config.plane_dim();
30
31            if !num_stage_lines.is_multiple_of(total_units) {
32                return Err(Box::new(
33                "Too many data will be loaded, resulting in out of bounds.
34        Try setting line size and number of planes so that total unit count {:?} divides number of lines in stage.",
35            ));
36            }
37        }
38
39        ContiguousTilingLayout::<TO>::check(config.global_memory_config(ident))?;
40
41        Ok(())
42    }
43}
44
45impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncFullCyclicLoading<TO> {
46    fn max_round_plane_count(
47        tiling_scheme: &TilingScheme,
48        ident: MatmulIdent,
49        line_size: u8,
50        plane_dim: u32,
51    ) -> u32 {
52        let num_lines = tiling_scheme.elements_in_stage(ident) / line_size as u32;
53        num_lines.div_ceil(plane_dim)
54    }
55}
56
57#[cube]
58impl<TO: TilingOrder> SyncFullLoadingStrategy for SyncFullCyclicLoading<TO> {
59    type TilingLayout = ContiguousTilingLayout<TO>;
60    type Job<IP: MatrixPrecision> = SyncFullCyclicJob;
61
62    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
63        #[comptime] ident: MatmulIdent,
64        #[comptime] line_size: u32,
65        #[comptime] config: G,
66    ) -> Self::Job<IP> {
67        let tile_num_elements = config.tiling_scheme().elements_in_tile(ident);
68        let num_stage_elements = config.tiling_scheme().elements_in_stage(ident);
69
70        let num_stage_lines = num_stage_elements.div_ceil(line_size);
71        let total_units = comptime!(config.num_loading_planes(ident) * config.plane_dim());
72        let num_tasks_per_unit = comptime!(num_stage_lines.div_ceil(total_units));
73        let balanced_workload = comptime!(num_stage_lines.is_multiple_of(total_units));
74        let jump_length = comptime!(total_units * line_size);
75
76        let unit_id = RoleRule::new(config.role_rule_config())
77            .load_index(ident, config.specialized_loading_sides())
78            * config.plane_dim()
79            + UNIT_POS_X;
80        let unit_position_base = unit_id * line_size;
81
82        SyncFullCyclicJob {
83            unit_position_base,
84            num_tasks_per_unit,
85            tile_num_elements,
86            jump_length,
87            line_size,
88            ident,
89            balanced_workload,
90            num_stage_elements,
91            reader_mode: comptime!(config.reader_mode()),
92        }
93    }
94}
95
96#[derive(CubeType, Clone, Copy)]
97pub struct SyncFullCyclicJob {
98    unit_position_base: u32,
99
100    #[cube(comptime)]
101    num_tasks_per_unit: u32,
102    #[cube(comptime)]
103    tile_num_elements: u32,
104    #[cube(comptime)]
105    jump_length: u32,
106    #[cube(comptime)]
107    line_size: u32,
108    #[cube(comptime)]
109    ident: MatmulIdent,
110    #[cube(comptime)]
111    balanced_workload: bool,
112    #[cube(comptime)]
113    num_stage_elements: u32,
114    #[cube(comptime)]
115    reader_mode: ReaderMode,
116}
117
118#[cube]
119impl<IP: MatrixPrecision, TO: TilingOrder> LoadingJob<IP, ContiguousTilingLayout<TO>>
120    for SyncFullCyclicJob
121{
122    fn execute_task<G: GlobalConfig>(
123        this: &mut Self,
124        #[comptime] task_id: u32,
125        global_iter: &GlobalIterator<Line<IP::Global>>,
126        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
127        #[comptime] config: G,
128    ) {
129        let unit_position = this.unit_position_base + task_id * this.jump_length;
130
131        #[allow(clippy::collapsible_else_if)]
132        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
133            load_and_store_line::<IP, TO, G>(this, unit_position, global_iter, stage, config);
134        } else {
135            if unit_position < this.num_stage_elements {
136                load_and_store_line::<IP, TO, G>(this, unit_position, global_iter, stage, config);
137            }
138        }
139    }
140
141    fn task_count(this: &Self) -> comptime_type!(u32) {
142        this.num_tasks_per_unit
143    }
144}
145
146#[cube]
147pub(crate) fn load_and_store_line<IP: MatrixPrecision, TO: TilingOrder, G: GlobalConfig>(
148    job: &SyncFullCyclicJob,
149    unit_position: u32,
150    global_iter: &GlobalIterator<Line<IP::Global>>,
151    stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
152    #[comptime] config: G,
153) {
154    let nth_tile = unit_position / job.tile_num_elements;
155    let pos_within_tile = unit_position % job.tile_num_elements;
156
157    let layout = TiledLayout::new(comptime![config.global_memory_config(job.ident)]);
158    let view = global_iter.view().view(layout);
159
160    let tile = ContiguousTilingLayout::<TO>::to_x_y(
161        nth_tile,
162        comptime!(config.stage_memory_config(job.ident)),
163    );
164
165    let line_read = view.read_checked((tile, pos_within_tile));
166
167    stage.as_slice_mut(job.line_size)[unit_position / job.line_size] = Line::cast_from(line_read);
168}