cubecl_matmul/components/global/read/strategy/sync_partial_tilewise.rs

use std::marker::PhantomData;

use crate::components::MatmulElems;
use crate::components::global::GlobalReaderConfig;
use crate::components::global::read::validate_swizzle_atom_size;
use crate::components::global::read::{PartialLoadingStrategy, sync::Synchronous};
use crate::components::global::{RoleRule, read::tiled::TiledLayout};
use crate::components::stage::StridedStageFamily;
use crate::components::stage::StridedStageMemory;
use crate::components::stage::TilingOrderEnum;
use crate::components::{FormattedConfigError, InvalidConfigError, StageIdent};
use crate::components::{
    global::memory::GlobalIterator,
    stage::{ContiguousTilingLayout, TilingOrder},
};
use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{tensor::layout::Coords2d, type_size};

use super::{LoadingJob, LoadingValidation};

#[derive(CubeType, Clone, Copy)]
/// Each tile is guaranteed to be loaded entirely by the same plane.
/// Each plane can load multiple tiles, provided the number of planes evenly divides the number of tiles.
/// In this case, a plane loads contiguous tiles following the `TilingOrder`,
/// until it would otherwise write to the opposite stage. At that point, it continues on the next
/// row or column of the same stage, skipping over the memory region of the other stage.
///
/// Only supports `RowMajorTilingOrder` for Lhs and `ColMajorTilingOrder` for Rhs.
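///
/// For example, with 8 tiles per stage and 2 loading planes, each plane owns
/// 8 / 2 = 4 contiguous tiles in tiling order: plane 0 loads tiles 0..4 and
/// plane 1 loads tiles 4..8.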
pub struct SyncPartialTilewiseLoading<T: TilingOrder> {
    #[cube(comptime)]
    tiling_order: PhantomData<T>,
}

impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialTilewiseLoading<TO> {
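    /// Each tile is loaded entirely by a single plane, so at most one plane
    /// per tile can do useful work.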
    fn max_round_plane_count(
        _elements_per_tile: u32,
        tiles_per_stage: u32,
        _line_size: u8,
        _plane_dim: u32,
    ) -> u32 {
        tiles_per_stage
    }
}

impl<T: TilingOrder> LoadingValidation for SyncPartialTilewiseLoading<T> {
    fn check<R: Runtime>(
        _client: &ComputeClient<R::Server>,
        config: &GlobalReaderConfig,
        dtypes: &MatmulElems,
    ) -> Result<(), InvalidConfigError> {
        let line_size = config.gmem_config.line_size;
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();

        if !num_tiles.is_multiple_of(num_planes) {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading."
                )
            }));
        }

        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let plane_dim = config.plane_dim;

        if !num_lines_per_plane.is_multiple_of(plane_dim) {
            return Err(FormattedConfigError::new(move || {
                format!(
                    "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading."
                )
            }));
        }

        match config.stage_ident {
            StageIdent::Lhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::RowMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Lhs is only supported with RowMajor tiling order"
                            .to_string()
                    }));
                }
            }
            StageIdent::Rhs => {
                if !matches!(T::to_enum(), TilingOrderEnum::ColMajor) {
                    return Err(FormattedConfigError::new(move || {
                        "Sync partial tilewise on Rhs is only supported with ColMajor tiling order"
                            .to_string()
                    }));
                }
            }
            _ => unreachable!(),
        }

        validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
        ContiguousTilingLayout::<T>::check(config.smem_config)?;

        Ok(())
    }
}

#[cube]
impl<TO: TilingOrder> PartialLoadingStrategy for SyncPartialTilewiseLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type SyncStrategy = Synchronous;
    type Stage = StridedStageFamily;

    type Job<EG: Numeric, ES: Numeric> = SyncPartialTilewiseJob;

    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] stage_index: u32,
        #[comptime] line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> SyncPartialTilewiseJob {
        let num_planes = config.loading_planes_count();
        let num_tiles = config.smem_config.tiles_per_stage();
        let plane_dim = config.plane_dim;

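        // Each plane owns `num_tiles / num_planes` contiguous tiles; spreading
        // their lines across the plane's `plane_dim` units gives the number of
        // lines each unit must load.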
        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile = comptime!(config.smem_config.elements_per_tile() / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let num_lines_per_unit = num_lines_per_plane / plane_dim;

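        // Width in tiles of one stage along the axis where stages are adjacent:
        // columns for Lhs, rows for Rhs.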
        let stage_width = comptime!(match config.stage_ident {
            StageIdent::Lhs => config.smem_config.tiles_per_stage_along_col(),
            StageIdent::Rhs => config.smem_config.tiles_per_stage_along_row(),
            _ => unreachable!(),
        });

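        // Skip the tiles already owned by loading planes with a lower load index.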
        let num_tiles_to_skip = RoleRule::new(config.plane_role_config.rule)
            .load_index(config.specialization_tensor_config)
            * num_tiles_per_plane;

        SyncPartialTilewiseJob {
            stage_index,
            num_tiles_to_skip,
            stage_width,
            num_lines_per_tile,
            num_lines_per_unit,
            plane_dim,
            line_size,
        }
    }
}

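/// Loading job for [`SyncPartialTilewiseLoading`]: runtime tile offsets plus
/// comptime tiling constants shared by every task of the plane.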
#[derive(CubeType, Clone, Copy)]
pub struct SyncPartialTilewiseJob {
    num_tiles_to_skip: u32,
    stage_index: u32,

    #[cube(comptime)]
    stage_width: u32,
    #[cube(comptime)]
    num_lines_per_tile: u32,
    #[cube(comptime)]
    num_lines_per_unit: u32,
    #[cube(comptime)]
    plane_dim: u32,
    #[cube(comptime)]
    line_size: u32,
}

#[cube]
impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
    LoadingJob<EG, ES, ContiguousTilingLayout<TO>, Synchronous> for SyncPartialTilewiseJob
{
    type Stage = StridedStageFamily;

    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        _barrier: &mut (),
        #[comptime] config: GlobalReaderConfig,
    ) {
        let mut stage = stage.with_buffer_index(this.stage_index);
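        // Flatten (task, unit) into a single line index across this plane's
        // tiles, then split it into a tile index and a line within that tile.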
        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
        let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
        let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;

        let nth_tile_global = this.num_tiles_to_skip + nth_tile_for_this_plane;

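        // Map the linear tile index to a (row, col) coordinate following the tiling order.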
        let tile = TO::to_row_col(
            nth_tile_global,
            config.smem_config.tiles_per_stage_along_row(),
            config.smem_config.tiles_per_stage_along_col(),
            config.smem_config,
        );

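        // Offset the coordinate into the selected stage: stages are laid out
        // side by side along columns for Lhs and along rows for Rhs.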
        let tile = match comptime![config.stage_ident] {
            StageIdent::Lhs => (tile.0, tile.1 + this.stage_index * this.stage_width),
            StageIdent::Rhs => (tile.0 + this.stage_index * this.stage_width, tile.1),
            _ => tile,
        };

        let num_lines_to_skip_global = nth_tile_global * this.num_lines_per_tile;

        SyncPartialTilewiseJob::load_and_store_line::<EG, ES, TO>(
            this,
            tile,
            line_index_within_tile,
            num_lines_to_skip_global,
            global_iter,
            &mut stage,
            config,
        );
    }

    fn task_count(this: &Self) -> comptime_type!(u32) {
        comptime!(this.num_lines_per_unit)
    }
}

#[cube]
impl SyncPartialTilewiseJob {
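    /// Loads one line from the global view and stores it at its swizzled
    /// position in the stage.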
    #[allow(clippy::too_many_arguments)]
    fn load_and_store_line<EG: Numeric, ES: Numeric, TO: TilingOrder>(
        this: &Self,
        tile: Coords2d,
        line_index_within_tile: u32,
        num_lines_to_skip_global: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let layout = TiledLayout::new(comptime!(config.smem_config));
        let view = global_iter.view().view(layout);

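        // Checked read: guards the global access against out-of-bounds coordinates.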
        let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));

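        // Destination line offset within the stage, passed through the stage's
        // swizzle together with the line's byte size.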
        let offset = line_index_within_tile + num_lines_to_skip_global;
        let type_size = type_size::<ES>(this.line_size);
        let offset = stage.swizzle.apply(offset, type_size);

        stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
    }
}