// cubecl_matmul/components/global/read/strategy/sync_full_tilewise.rs

1use std::marker::PhantomData;
2
3use crate::components::MatmulElems;
4use crate::components::global::GlobalReaderConfig;
5use crate::components::global::read::validate_swizzle_atom_size;
6use crate::components::global::read::{FullLoadingStrategy, sync::Synchronous};
7use crate::components::global::{RoleRule, read::tiled::TiledLayout};
8use crate::components::stage::StridedStageFamily;
9use crate::components::stage::{StridedStageMemory, TilingOrder};
10use crate::components::{FormattedConfigError, InvalidConfigError};
11use crate::components::{global::memory::GlobalIterator, stage::ContiguousTilingLayout};
12use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
13use cubecl_core as cubecl;
14use cubecl_core::prelude::*;
15use cubecl_std::{tensor::layout::Coords2d, type_size};
16
17use super::{LoadingJob, LoadingValidation};
18
#[derive(CubeType, Clone, Copy)]
/// Each tile is guaranteed to be loaded entirely by the same plane.
/// Each plane can load multiple tiles, provided the number of planes evenly divides the number of tiles.
/// In this case, a plane loads contiguous tiles following the [`TilingOrder`].
///
/// If number of planes = number of rows of Lhs and the `TilingOrder` is row-major,
/// each plane loads its own row and a sync can be saved.
/// In multi-row, the number of planes must divide the number of rows,
/// and each plane loads a contiguous chunk of rows (e.g. plane 0 loads rows 0–1, plane 1 loads 2–3, etc.).
pub struct SyncFullTilewiseLoading<T: TilingOrder> {
    // Zero-sized marker: the tiling order is a purely compile-time choice,
    // no runtime data is stored for it.
    #[cube(comptime)]
    tiling_order: PhantomData<T>,
}
32
33impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncFullTilewiseLoading<TO> {
34    fn max_round_plane_count(
35        _elements_per_tile: u32,
36        tiles_per_stage: u32,
37        _line_size: u8,
38        _plane_dim: u32,
39    ) -> u32 {
40        tiles_per_stage
41    }
42}
43
44impl<T: TilingOrder> LoadingValidation for SyncFullTilewiseLoading<T> {
45    fn check<R: Runtime>(
46        _client: &ComputeClient<R::Server>,
47        config: &GlobalReaderConfig,
48
49        dtypes: &MatmulElems,
50    ) -> Result<(), InvalidConfigError> {
51        let line_size = config.gmem_config.line_size;
52        let num_planes = config.loading_planes_count();
53        let num_tiles = config.smem_config.tiles_per_stage();
54
55        if !num_tiles.is_multiple_of(num_planes) {
56            return Err(FormattedConfigError::new(move || {
57                format!(
58                    "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading.",
59                )
60            }));
61        }
62
63        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
64        let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
65        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
66        let plane_dim = config.plane_dim;
67
68        if num_lines_per_plane % plane_dim != 0 {
69            return Err(FormattedConfigError::new(move || {
70                format!(
71                    "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading.",
72                )
73            }));
74        }
75
76        validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
77        ContiguousTilingLayout::<T>::check(config.smem_config)?;
78
79        Ok(())
80    }
81}
82
#[cube]
impl<TO: TilingOrder> FullLoadingStrategy for SyncFullTilewiseLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = Synchronous;
    type Job<EG: Numeric, ES: Numeric> = SyncFullTilewiseJob;

    /// Builds the per-plane loading job: computes how many lines each unit
    /// loads and how many tiles/lines this plane must skip to reach its own
    /// contiguous chunk of tiles.
    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();

        // Exact divisions: both divisibility conditions are enforced by
        // `LoadingValidation::check` on the same expressions.
        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let num_lines_per_unit = num_lines_per_plane / config.plane_dim;

        // Offset of this plane's first tile: `load_index` gives the plane's
        // position among the loading planes (per the role rule), and each such
        // plane owns `num_tiles_per_plane` contiguous tiles.
        let num_tiles_to_skip = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * num_tiles_per_plane;
        let num_lines_to_skip = num_tiles_to_skip * num_lines_per_tile;

        SyncFullTilewiseJob {
            num_tiles_to_skip,
            num_lines_to_skip,
            num_lines_per_tile,
            num_lines_per_unit,
            plane_dim: config.plane_dim,
            line_size,
        }
    }
}
116
#[derive(CubeType, Clone, Copy)]
/// Per-plane state for a tilewise loading job, built by
/// `SyncFullTilewiseLoading::new_job`.
pub struct SyncFullTilewiseJob {
    /// Number of tiles owned by lower-indexed loading planes (runtime value).
    pub num_tiles_to_skip: u32,
    /// Stage-line offset corresponding to `num_tiles_to_skip` tiles (runtime value).
    pub num_lines_to_skip: u32,

    /// Lines contained in one tile (elements per tile / line size).
    #[cube(comptime)]
    pub num_lines_per_tile: u32,
    /// Lines each unit loads in total; also the number of tasks per unit.
    #[cube(comptime)]
    pub num_lines_per_unit: u32,
    /// Number of units in a plane.
    #[cube(comptime)]
    pub plane_dim: u32,
    /// Vectorization width of each global-memory line.
    #[cube(comptime)]
    pub line_size: u32,
}
131
#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncFullTilewiseJob
{
    type Stage = StridedStageFamily;

    /// Loads one line for this unit: task `task_id` maps to a flattened line
    /// position within the plane's chunk of tiles, which is split into a
    /// plane-local tile index and a line index inside that tile.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Flattened position of this unit's line across the plane's tiles:
        // units of a plane cover `plane_dim` consecutive lines per task.
        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
        let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
        let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;

        // Convert the plane-local tile index to a stage-global one, then to
        // (row, col) coordinates following the tiling order.
        let nth_tile_global = nth_tile_for_this_plane + this.num_tiles_to_skip;
        let tile =
            ContiguousTilingLayout::<TO>::to_x_y(nth_tile_global, comptime!(config.smem_config));

        SyncFullTilewiseJob::load_and_store_line::<EG, ES, TO>(
            this,
            tile,
            line_index_within_tile,
            // Line offset of the plane-local tile within this plane's chunk.
            nth_tile_for_this_plane * this.num_lines_per_tile,
            global_iter,
            stage,
            config,
        );
    }

    /// Total tasks per unit = lines per unit (one line loaded per task).
    fn task_count(this: &Self) -> comptime_type!(u32) {
        comptime!(this.num_lines_per_unit)
    }
}
169
#[cube]
impl SyncFullTilewiseJob {
    /// Reads one line from the global view (bounds-checked) and stores it into
    /// the stage's shared memory at the swizzled destination offset.
    #[allow(clippy::too_many_arguments)]
    fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
        this: &Self,
        tile: Coords2d,
        line_index_within_tile: u32,
        num_lines_to_skip_local: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // View global memory through the tiled layout so reads are addressed
        // as (tile coords, element offset within the tile).
        let layout = TiledLayout::new(comptime!(config.smem_config));
        let view = global_iter.view().view(layout);

        // NOTE(review): `read_checked` presumably yields a neutral value when
        // the coordinate is out of bounds — confirm against the view's docs.
        let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));

        // Destination line offset = plane's chunk offset + tile offset within
        // the chunk + line offset within the tile, then swizzled.
        let offset = this.num_lines_to_skip + line_index_within_tile + num_lines_to_skip_local;
        let type_size = type_size::<ES>(this.line_size);
        let offset = stage.swizzle.apply(offset, type_size);

        // Cast EG -> ES on store into the stage.
        stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
    }
}
193}