cubecl_matmul/components/global/read/strategy/sync_partial_tilewise.rs

1use std::marker::PhantomData;
2
3use crate::components::global::read::SyncPartialLoadingStrategy;
4use crate::components::global::{RoleRule, read::tiled::TiledLayout};
5use crate::components::stage::TilingOrderEnum;
6use crate::components::{
7    FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme,
8};
9use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TilingValidation};
10use crate::components::{
11    global::{GlobalConfig, memory::GlobalIterator},
12    stage::{ContiguousTilingLayout, StridedStage, TilingOrder},
13};
14use cubecl_core as cubecl;
15use cubecl_core::prelude::*;
16use cubecl_std::tensor::layout::Coords2d;
17
18use super::{LoadingJob, LoadingValidation};
19
#[derive(CubeType, Clone, Copy)]
/// Each tile is guaranteed to be loaded entirely by the same plane.
/// Each plane can load multiple tiles, provided the number of planes evenly divides the number of tiles.
/// In this case, a plane loads contiguous tiles following the `TilingOrder`,
/// until it would otherwise write to the opposite stage. At that point, it continues on the next
/// row or column of the same stage, skipping over the memory region of the other stage.
///
/// Only supports RowMajorTilingOrder for Lhs and ColMajorTilingOrder for Rhs
pub struct SyncPartialTilewiseLoading<T: TilingOrder> {
    /// Zero-sized marker carrying the compile-time tiling order; no runtime state.
    #[cube(comptime)]
    tiling_order: PhantomData<T>,
}
32
impl<TO: TilingOrder> LoadMaxRoundPlaneCount for SyncPartialTilewiseLoading<TO> {
    /// Maximum number of planes that can usefully participate in one loading round.
    ///
    /// Each tile is loaded entirely by a single plane, so the plane count is
    /// capped at the number of tiles in the stage; line size and plane
    /// dimension impose no further limit here (hence the unused parameters).
    fn max_round_plane_count(
        tiling_scheme: &TilingScheme,
        ident: MatmulIdent,
        _line_size: u8,
        _plane_dim: u32,
    ) -> u32 {
        tiling_scheme.tiles_in_stage(ident)
    }
}
43
44impl<T: TilingOrder> LoadingValidation for SyncPartialTilewiseLoading<T> {
45    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
46        let line_size = config.global_line_size(ident);
47        let num_planes = config.num_loading_planes(ident);
48        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
49
50        if !num_tiles.is_multiple_of(num_planes) {
51            return Err(FormattedConfigError::new(move || {
52                "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for tilewise loading.".to_string()
53            }));
54        }
55
56        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
57        let num_lines_per_tile =
58            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
59        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
60        let num_planes = config.plane_dim();
61
62        if num_lines_per_plane % num_planes != 0 {
63            return Err(FormattedConfigError::new(move || {
64                "Number of planes {num_planes:?} must divide number of lines per plane {num_lines_per_plane:?} for tilewise loading.".to_string()
65            }));
66        }
67
68        match ident {
69            MatmulIdent::Lhs => {
70                if !matches!(T::to_enum(), TilingOrderEnum::RowMajor) {
71                    return Err(FormattedConfigError::new(move || {
72                        "Sync partial tilewise on Lhs is only supported with RowMajor tiling order"
73                            .to_string()
74                    }));
75                }
76            }
77            MatmulIdent::Rhs => {
78                if !matches!(T::to_enum(), TilingOrderEnum::ColMajor) {
79                    return Err(FormattedConfigError::new(move || {
80                        "Sync partial tilewise on Rhs is only supported with ColMajor tiling order"
81                            .to_string()
82                    }));
83                }
84            }
85            MatmulIdent::Out => unreachable!(),
86        }
87
88        ContiguousTilingLayout::<T>::check(config.global_memory_config(ident))?;
89
90        Ok(())
91    }
92}
93
#[cube]
impl<TO: TilingOrder> SyncPartialLoadingStrategy for SyncPartialTilewiseLoading<TO> {
    type TilingLayout = ContiguousTilingLayout<TO>;
    type Job<IP: MatrixPrecision> = SyncPartialTilewiseJob;

    /// Builds the per-plane loading job for the stage at `stage_index`.
    ///
    /// Computes, at compile time, how many lines each unit must load and the
    /// tile offset at which this plane starts, accounting for the interleaved
    /// memory layout of multiple stages.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> SyncPartialTilewiseJob {
        let num_planes = config.num_loading_planes(ident);
        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
        let plane_dim = config.plane_dim();

        // Each plane owns `num_tiles_per_plane` whole tiles; split the plane's
        // total line count evenly across its units (one task per line per unit).
        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile =
            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let num_lines_per_unit = num_lines_per_plane / plane_dim;

        let num_stages = config.num_stages(ident);
        // Width of one stage in tiles along the direction tiles are laid out
        // contiguously: columns for Lhs (RowMajor order), rows for Rhs (ColMajor).
        let stage_width = comptime!(match ident {
            MatmulIdent::Lhs => config.tiling_scheme().tiles_in_stage_col(ident),
            MatmulIdent::Rhs => config.tiling_scheme().tiles_in_stage_row(ident),
            MatmulIdent::Out => unreachable!(),
        });
        // Tiles to jump to advance one row/col in the combined multi-stage layout.
        let row_col_stride = num_stages * stage_width;
        // Tiles to skip to land in the requested stage rather than stage 0.
        let stage_offset = stage_width * stage_index;

        // This plane's first tile within its stage, from its role-based load index.
        let starting_tile_within_stage = RoleRule::new(config.role_rule_config())
            .load_index(ident, config.specialized_loading_sides())
            * num_tiles_per_plane;
        // Re-express that tile as (row/col, inner) and flatten into the
        // multi-stage layout, skipping over the other stages' memory regions.
        let row_col_index = starting_tile_within_stage / stage_width;
        let inner_offset = starting_tile_within_stage % stage_width;
        let num_tiles_to_skip = row_col_index * row_col_stride + inner_offset + stage_offset;

        SyncPartialTilewiseJob {
            num_tiles_to_skip,
            row_col_stride,
            stage_width,
            num_lines_per_tile,
            num_lines_per_unit,
            plane_dim: config.plane_dim(),
            line_size,
            ident,
        }
    }
}
143
#[derive(CubeType, Clone, Copy)]
/// State for one plane's tilewise loading of a partial (multi-stage) stage.
pub struct SyncPartialTilewiseJob {
    /// Flattened tile offset of this plane's first tile in the multi-stage layout.
    num_tiles_to_skip: u32,

    /// Tiles to jump to advance one row/col across all stages.
    #[cube(comptime)]
    row_col_stride: u32,
    /// Width in tiles of a single stage along the contiguous tiling direction.
    #[cube(comptime)]
    stage_width: u32,
    /// Number of lines (of `line_size`) in one tile.
    #[cube(comptime)]
    num_lines_per_tile: u32,
    /// Lines loaded by each unit; also the number of tasks per unit.
    #[cube(comptime)]
    num_lines_per_unit: u32,
    /// Number of units in a plane.
    #[cube(comptime)]
    plane_dim: u32,
    /// Vectorization width of each load.
    #[cube(comptime)]
    line_size: u32,
    /// Which matmul input (Lhs or Rhs) this job loads.
    #[cube(comptime)]
    ident: MatmulIdent,
}
163
#[cube]
impl<IP: MatrixPrecision, TO: TilingOrder> LoadingJob<IP, ContiguousTilingLayout<TO>>
    for SyncPartialTilewiseJob
{
    /// Loads one line per unit for task `task_id`: locates the unit's tile and
    /// line within this plane's contiguous run, maps it into the multi-stage
    /// layout, and copies the line from global memory into the stage.
    fn execute_task<G: GlobalConfig>(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
        #[comptime] config: G,
    ) {
        // Position of this unit's line within the plane's run of lines,
        // then split into (tile index within the plane, line within the tile).
        let pos_across_tiles = task_id * this.plane_dim + UNIT_POS_X;
        let nth_tile_for_this_plane = pos_across_tiles / this.num_lines_per_tile;
        let line_index_within_tile = pos_across_tiles % this.num_lines_per_tile;

        // Flatten the local tile index into the multi-stage layout, skipping
        // the other stages' memory regions (same scheme as in `new_job`).
        let row_col_index_local = nth_tile_for_this_plane / this.stage_width;
        let inner_offset = nth_tile_for_this_plane % this.stage_width;
        let num_tiles_to_skip_local = row_col_index_local * this.row_col_stride + inner_offset;
        let nth_tile_global = this.num_tiles_to_skip + num_tiles_to_skip_local;

        // Tile grid of the full multi-stage region: the k direction is
        // multiplied by the number of stages (cols for Lhs, rows for Rhs).
        let (total_tile_count_row, total_tile_count_col) = match comptime!(this.ident) {
            MatmulIdent::Lhs => (
                comptime!(config.tiling_scheme().tiles_in_stage_m()),
                comptime!(
                    config.tiling_scheme().tiles_in_stage_k() * config.num_stages(MatmulIdent::Lhs)
                ),
            ),
            MatmulIdent::Rhs => (
                comptime!(
                    config.tiling_scheme().tiles_in_stage_k() * config.num_stages(MatmulIdent::Rhs)
                ),
                comptime!(config.tiling_scheme().tiles_in_stage_n()),
            ),
            MatmulIdent::Out => comptime!(unreachable!()),
        };

        // Convert the flattened tile index to (row, col) per the tiling order.
        let tile = TO::to_row_col(
            nth_tile_global,
            total_tile_count_row,
            total_tile_count_col,
            comptime!(config.stage_memory_config(this.ident)),
        );

        // Line offset of the tile's first line in the stage buffer.
        let num_lines_to_skip_global = nth_tile_global * this.num_lines_per_tile;

        SyncPartialTilewiseJob::load_and_store_line::<IP, TO, G>(
            this,
            tile,
            line_index_within_tile,
            num_lines_to_skip_global,
            global_iter,
            stage,
            config,
        );
    }

    /// One task per line each unit must load.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        comptime!(this.num_lines_per_unit)
    }
}
224
#[cube]
impl SyncPartialTilewiseJob {
    /// Copies a single line from global memory into the stage buffer.
    ///
    /// Reads the line at (`tile`, element offset) through a tiled view of the
    /// global iterator, then writes it (cast to the stage precision) at the
    /// tile's line range plus `line_index_within_tile`.
    #[allow(clippy::too_many_arguments)]
    fn load_and_store_line<IP: MatrixPrecision, TO: TilingOrder, G: GlobalConfig>(
        this: &Self,
        tile: Coords2d,
        line_index_within_tile: u32,
        num_lines_to_skip_global: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, ContiguousTilingLayout<TO>>,
        #[comptime] config: G,
    ) {
        // View global memory through the tile layout so the read is addressed
        // by (tile coords, element offset within the tile).
        let layout = TiledLayout::new(comptime!(config.global_memory_config(this.ident)));
        let view = global_iter.view().view(layout);

        // Checked read — presumably guards against out-of-bounds global
        // accesses at the tensor edges; confirm semantics in `read_checked`.
        let line_read = view.read_checked((tile, line_index_within_tile * this.line_size));

        let offset = line_index_within_tile + num_lines_to_skip_global;

        stage.as_slice_mut(this.line_size)[offset] = Line::cast_from(line_read);
    }
}
246}