// cubecl_matmul/components/global/read/strategy/sync_full_tilewise.rs

1use std::marker::PhantomData;
2
3use crate::components::global::read::SyncFullLoadingStrategy;
4use crate::components::global::{RoleRule, read::tiled::TiledLayout};
5use crate::components::{
6    FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme,
7};
8use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
9use crate::components::{
10    global::{GlobalConfig, memory::GlobalIterator},
11    stage::{ContiguousTilingLayout, StridedStage, TilingOrder},
12};
13use cubecl_core as cubecl;
14use cubecl_core::prelude::*;
15use cubecl_std::tensor::layout::Coords2d;
16
17use super::{LoadingJob, LoadingValidation};
18
19#[derive(CubeType, Clone, Copy)]
20/// Each tile is guaranteed to be loaded entirely by the same plane.
21/// Each plane can load multiple tiles, provided the number of planes evenly divides the number of tiles.
22/// In this case, a plane loads contiguous tiles following the TilingOrder.
23///
24/// If number of planes = number of rows of Lhs and TilingOrder is RowMajor,
25/// each plane loads its own row and a sync can be saved.
26/// In multi-row, number of planes must divide number of rows,
27/// and each plane loads a contiguous chunk of rows (e.g. plane 0 loads rows 0–1, plane 1 loads 2–3, etc.).
28pub struct SyncFullTilewiseLoading<T: TilingOrder> {
29    #[cube(comptime)]
30    tiling_order: PhantomData<T>,
31}
32
33impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncFullTilewiseLoading<TO> {
34    fn max_round_plane_count(
35        tiling_scheme: &TilingScheme,
36        ident: MatmulIdent,
37        _line_size: u8,
38        _plane_dim: u32,
39    ) -> u32 {
40        tiling_scheme.tiles_in_stage(ident)
41    }
42}
43
44impl<T: TilingOrder> LoadingValidation for SyncFullTilewiseLoading<T> {
45    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
46        let line_size = config.global_line_size(ident);
47        let num_planes = config.num_loading_planes(ident);
48        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
49
50        if !num_tiles.is_multiple_of(num_planes) {
51            return Err(FormattedConfigError::new(move || {
52                format!(
53                    "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading.",
54                )
55            }));
56        }
57
58        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
59        let num_lines_per_tile =
60            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
61        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
62        let plane_dim = config.plane_dim();
63
64        if num_lines_per_plane % plane_dim != 0 {
65            return Err(FormattedConfigError::new(move || {
66                format!(
67                    "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading.",
68                )
69            }));
70        }
71
72        ContiguousTilingLayout::<T>::check(config.global_memory_config(ident))?;
73
74        Ok(())
75    }
76}
77
78#[cube]
79impl<TO: TilingOrder> SyncFullLoadingStrategy for SyncFullTilewiseLoading<TO> {
80    type TilingLayout = ContiguousTilingLayout<TO>;
81    type Job<IP: MatrixPrecision> = SyncFullTilewiseJob;
82
83    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
84        #[comptime] ident: MatmulIdent,
85        #[comptime] line_size: u32,
86        #[comptime] config: G,
87    ) -> Self::Job<IP> {
88        let num_planes = config.num_loading_planes(ident);
89        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
90        let plane_dim = config.plane_dim();
91
92        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
93        let num_lines_per_tile =
94            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
95        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
96        let num_lines_per_unit = num_lines_per_plane / plane_dim;
97
98        let num_tiles_to_skip = RoleRule::new(config.role_rule_config())
99            .load_index(ident, config.specialized_loading_sides())
100            * num_tiles_per_plane;
101        let num_lines_to_skip = num_tiles_to_skip * num_lines_per_tile;
102
103        SyncFullTilewiseJob {
104            num_tiles_to_skip,
105            num_lines_to_skip,
106            num_lines_per_tile,
107            num_lines_per_unit,
108            plane_dim: config.plane_dim(),
109            line_size,
110            ident,
111        }
112    }
113}
114
115#[derive(CubeType, Clone, Copy)]
116pub struct SyncFullTilewiseJob {
117    pub num_tiles_to_skip: u32,
118    pub num_lines_to_skip: u32,
119
120    #[cube(comptime)]
121    pub num_lines_per_tile: u32,
122    #[cube(comptime)]
123    pub num_lines_per_unit: u32,
124    #[cube(comptime)]
125    pub plane_dim: u32,
126    #[cube(comptime)]
127    pub line_size: u32,
128    #[cube(comptime)]
129    pub ident: MatmulIdent,
130}
131
132#[cube]
133impl<IP: MatrixPrecision, TO: TilingOrder> LoadingJob<IP, ContiguousTilingLayout<TO>>
134    for SyncFullTilewiseJob
135{
136    fn execute_task<G: GlobalConfig>(
137        this: &mut Self,
138        #[comptime] task_id: u32,
139        global_iter: &GlobalIterator<Line<IP::Global>>,
140        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
141        #[comptime] config: G,
142    ) {
143        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
144        let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
145        let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;
146
147        let nth_tile_global = nth_tile_for_this_plane + this.num_tiles_to_skip;
148        let tile = ContiguousTilingLayout::<TO>::to_x_y(
149            nth_tile_global,
150            comptime!(config.stage_memory_config(this.ident)),
151        );
152
153        SyncFullTilewiseJob::load_and_store_line::<IP, TO, G>(
154            this,
155            tile,
156            line_index_within_tile,
157            nth_tile_for_this_plane * this.num_lines_per_tile,
158            global_iter,
159            stage,
160            config,
161        );
162    }
163
164    fn task_count(this: &Self) -> comptime_type!(u32) {
165        comptime!(this.num_lines_per_unit)
166    }
167}
168
169#[cube]
170impl SyncFullTilewiseJob {
171    #[allow(clippy::too_many_arguments)]
172    fn load_and_store_line<IP: MatrixPrecision, TO: TilingOrder, G: GlobalConfig>(
173        this: &Self,
174        tile: Coords2d,
175        line_index_within_tile: u32,
176        num_lines_to_skip_local: u32,
177        global_iter: &GlobalIterator<Line<IP::Global>>,
178        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
179        #[comptime] config: G,
180    ) {
181        let layout = TiledLayout::new(comptime!(config.global_memory_config(this.ident)));
182        let view = global_iter.view().view(layout);
183
184        let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));
185
186        let offset = this.num_lines_to_skip + line_index_within_tile + num_lines_to_skip_local;
187
188        stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
189    }
190}