// cubecl_matmul/components/global/read/strategy/sync_partial_cyclic.rs

1use std::marker::PhantomData;
2
3use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
4use crate::components::global::read::{SyncPartialLoadingStrategy, tiled::TiledLayout};
5use crate::components::global::{GlobalConfig, RoleRule};
6use crate::components::stage::{ContiguousTilingLayout, StridedStage, TilingOrder};
7use crate::components::{InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme};
8use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11
12use super::{LoadingJob, LoadingValidation, ReaderMode};
13
14#[derive(CubeType, Clone, Copy)]
15/// Loads the content of all tiles in the stage using all planes.
16/// Unit with pos X loads lines with indices X, X + NUM_UNITS, X + 2 * NUM_UNITS, ...
17pub struct SyncPartialCyclicLoading<T: TilingOrder> {
18    #[cube(comptime)]
19    _phantom: PhantomData<T>,
20}
21
22impl<TO: TilingOrder> LoadingValidation for SyncPartialCyclicLoading<TO> {
23    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
24        if let ReaderMode::Strict = config.reader_mode() {
25            let line_size = config.global_line_size(ident);
26            let num_lines_per_tile = config.tiling_scheme().elements_in_tile(ident) / line_size;
27            let num_tiles_in_stage = config.tiling_scheme().tiles_in_stage(ident);
28            let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
29
30            let total_units = config.plane_dim() * config.num_loading_planes(ident);
31            let jump_length = total_units * line_size;
32            let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
33
34            let max_id = total_units - 1;
35            let max_task_id = num_tasks_per_unit - 1;
36            let max_position_base = max_id * line_size;
37            let max_position = max_position_base + max_task_id * jump_length;
38            let num_stage_elements = config.tiling_scheme().elements_in_stage(ident);
39
40            if max_position > num_stage_elements {
41                return Err(Box::new(
42                    "Too many data will be loaded, resulting in out-of-bounds",
43                ));
44            }
45        }
46
47        ContiguousTilingLayout::<TO>::check(config.global_memory_config(ident))?;
48
49        Ok(())
50    }
51}
52
53impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialCyclicLoading<TO> {
54    fn max_round_plane_count(
55        tiling_scheme: &TilingScheme,
56        ident: MatmulIdent,
57        line_size: u8,
58        plane_dim: u32,
59    ) -> u32 {
60        let num_lines_per_tile = tiling_scheme.elements_in_tile(ident) / line_size as u32;
61        let num_tiles_in_stage = tiling_scheme.tiles_in_stage(ident);
62        let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
63        total_num_lines.div_ceil(plane_dim)
64    }
65}
66
67#[cube]
68impl<TO: TilingOrder> SyncPartialLoadingStrategy for SyncPartialCyclicLoading<TO> {
69    type TilingLayout = ContiguousTilingLayout<TO>;
70    type Job<IP: MatrixPrecision> = SyncPartialCyclicJob;
71
72    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
73        #[comptime] stage_index: u32,
74        #[comptime] ident: MatmulIdent,
75        #[comptime] line_size: u32,
76        #[comptime] config: G,
77    ) -> SyncPartialCyclicJob {
78        let num_stage_elements = config.tiling_scheme().elements_in_stage(ident);
79
80        let tile_size = config.tiling_scheme().elements_in_tile(ident);
81        let tile_count_row = config.tiling_scheme().tiles_in_stage_row(ident);
82        let tile_count_col = config.tiling_scheme().tiles_in_stage_col(ident);
83
84        let num_lines_per_tile = tile_size / line_size;
85        let total_units = config.plane_dim() * config.num_loading_planes(ident);
86
87        let num_tiles_in_stage = tile_count_row * tile_count_col;
88        let total_num_lines = num_tiles_in_stage * num_lines_per_tile;
89        let balanced_workload = total_num_lines.is_multiple_of(total_units);
90        let num_tasks_per_unit = total_num_lines.div_ceil(total_units);
91        let jump_length = total_units * line_size;
92
93        let plane_id = RoleRule::new(config.role_rule_config())
94            .load_index(ident, config.specialized_loading_sides());
95        let unit_id = plane_id * config.plane_dim() + UNIT_POS_X;
96        let unit_position_base = unit_id * line_size;
97
98        SyncPartialCyclicJob {
99            unit_position_base,
100            num_tasks_per_unit,
101            stage_index,
102            jump_length,
103            num_lines_per_tile,
104            ident,
105            balanced_workload,
106            num_stage_elements,
107            reader_mode: comptime!(config.reader_mode()),
108        }
109    }
110}
111
/// Per-unit loading job produced by `SyncPartialCyclicLoading::new_job`.
///
/// Only `unit_position_base` is a runtime value, since it depends on the unit's position.
/// Every other field is comptime.
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialCyclicJob {
    // Element offset of this unit's first line (`unit_id * line_size` in `new_job`).
    unit_position_base: u32,

    // Number of lines each unit loads: `ceil(total lines / total units)`.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    // Index of the stage partition being loaded; used to offset the tile coordinates.
    #[cube(comptime)]
    stage_index: u32,
    // Element distance between two consecutive tasks of a unit (`total units * line_size`).
    #[cube(comptime)]
    jump_length: u32,
    // Number of lines in a single tile (`elements_in_tile / line_size`).
    #[cube(comptime)]
    num_lines_per_tile: u32,
    // Operand being loaded; `Out` is rejected as unreachable when loading.
    #[cube(comptime)]
    ident: MatmulIdent,
    // True when total lines is a multiple of total units. Every task position is then in
    // range, so no runtime bounds check is needed.
    #[cube(comptime)]
    balanced_workload: bool,
    // `tiling_scheme().elements_in_stage(ident)`: the runtime upper bound (exclusive) for
    // unit positions.
    #[cube(comptime)]
    num_stage_elements: u32,
    // Reader mode from the config; `Strict` also skips the runtime bounds check.
    #[cube(comptime)]
    reader_mode: ReaderMode,
}
133
#[cube]
impl<IP: MatrixPrecision, TO: TilingOrder> LoadingJob<IP, ContiguousTilingLayout<TO>>
    for SyncPartialCyclicJob
{
    /// Loads the line assigned to task `task_id` of this unit.
    ///
    /// That line starts at element `unit_position_base + task_id * jump_length`.
    ///
    /// The bounds check is decided at comptime and omitted when either condition holds:
    /// - the reader mode is `Strict`, which relies on `LoadingValidation::check` having
    ///   rejected configs whose last line falls outside the stage;
    /// - the workload is balanced, so every task position is in range.
    ///
    /// Otherwise, positions at or beyond `num_stage_elements` are skipped at runtime.
    fn execute_task<G: GlobalConfig>(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
        #[comptime] config: G,
    ) {
        // Runtime value: `unit_position_base` is the only non-comptime field of the job.
        let unit_position = this.unit_position_base + task_id * this.jump_length;

        // The nesting is intentional. The outer condition is evaluated at comptime
        // (`comptime!`), while the inner one is a runtime comparison on `unit_position`.
        #[allow(clippy::collapsible_else_if)]
        if comptime!(this.reader_mode == ReaderMode::Strict || this.balanced_workload) {
            load_and_store_line::<IP, TO, G>(this, unit_position, global_iter, stage, config);
        } else {
            if unit_position < this.num_stage_elements {
                load_and_store_line::<IP, TO, G>(this, unit_position, global_iter, stage, config);
            }
        }
    }

    /// Number of tasks each unit executes (comptime), as computed in `new_job`.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}
161
/// Loads one line from global memory and writes it into the stage memory.
///
/// `unit_position` is an element offset within the stage partition selected by
/// `job.stage_index`, as computed in `SyncPartialCyclicJob::execute_task`. The offset is split
/// into a tile index (ordered by `TO`) and an element offset inside that tile.
///
/// The tile coordinates are then shifted by the partition offset, so they address the full
/// stage memory. That memory holds `config.num_stages(ident)` partitions, placed along columns
/// for `Lhs` and along rows for `Rhs`.
#[cube]
pub(crate) fn load_and_store_line<IP: MatrixPrecision, TO: TilingOrder, G: GlobalConfig>(
    job: &SyncPartialCyclicJob,
    unit_position: u32,
    global_iter: &GlobalIterator<Line<IP::Global>>,
    stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
    #[comptime] config: G,
) {
    // View the global data through a tiled layout. It is then addressed as
    // `(tile coordinates, position within tile)`, as done by `read_checked` below.
    let layout = TiledLayout::new(comptime!(config.global_memory_config(job.ident)));
    let view = global_iter.view().view(layout);

    // Tile size (in elements) and tile counts of a single stage partition.
    let (tile_size, tile_count_row, tile_count_col) = comptime! {
        (
            config.tiling_scheme().elements_in_tile(job.ident),
            config.tiling_scheme().tiles_in_stage_row(job.ident),
            config.tiling_scheme().tiles_in_stage_col(job.ident),
        )
    };
    // NOTE(review): `job.num_lines_per_tile` was computed in `new_job` from its `line_size`
    // argument. This assumes that value equals the view's line size — confirm.
    let line_size = view.line_size();

    // Tile of the partition this position falls in, and the element offset inside that tile.
    let tile_index = unit_position / tile_size;
    let pos_within_tile = unit_position % tile_size;

    // Tile counts spanning all `num_stages` partitions. Lhs partitions extend the column
    // count; Rhs partitions extend the row count. `Out` is never loaded here.
    let (total_tile_count_row, total_tile_count_col) = match comptime!(job.ident) {
        MatmulIdent::Lhs => (
            comptime!(tile_count_row),
            comptime!(tile_count_col * config.num_stages(MatmulIdent::Lhs)),
        ),
        MatmulIdent::Rhs => (
            comptime!(tile_count_row * config.num_stages(MatmulIdent::Rhs)),
            comptime!(tile_count_col),
        ),
        MatmulIdent::Out => comptime!(unreachable!()),
    };

    // (row, col) of the tile inside one partition, following the tiling order `TO`.
    let (tile_x_within_stage, tile_y_within_stage) = TO::to_row_col(
        tile_index,
        tile_count_row,
        tile_count_col,
        comptime!(config.stage_memory_config(job.ident)),
    );

    // Shift by the partition offset (`stage_index` partitions) along the stacking dimension.
    let tile = match comptime!(job.ident) {
        MatmulIdent::Lhs => (
            tile_x_within_stage,
            job.stage_index * tile_count_col + tile_y_within_stage,
        ),
        MatmulIdent::Rhs => (
            job.stage_index * tile_count_row + tile_x_within_stage,
            tile_y_within_stage,
        ),
        MatmulIdent::Out => comptime!(unreachable!()),
    };

    // NOTE(review): presumably `read_checked` guards against out-of-range global coordinates
    // (e.g. at matrix edges). Confirm in the view implementation.
    let line_read = view.read_checked((tile, pos_within_tile));

    // Linear index of the destination tile within the full stage memory, per `TO`.
    let nth_tile_in_stage = TO::to_nth_tile(
        tile,
        total_tile_count_row,
        total_tile_count_col,
        comptime!(config.stage_memory_config(job.ident)),
    );

    // Destination tile as a slice of lines of the stage memory.
    let tile_start = nth_tile_in_stage * job.num_lines_per_tile;
    let tile_end = tile_start + job.num_lines_per_tile;
    let mut tile_slice = stage
        .as_slice_mut(line_size)
        .slice_mut(tile_start, tile_end);

    // Convert the element offset to a line index, then cast from the global element type to
    // the stage element type.
    tile_slice[pos_within_tile / line_size] = Line::cast_from(line_read);
}