cubecl_matmul/components/global/read/strategy/
sync_full_ordered.rs1use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
2use crate::components::global::read::SyncFullLoadingStrategy;
3use crate::components::stage::OrderedTilingOrder;
4use crate::components::{
5 FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme,
6};
7use crate::components::{global::GlobalConfig, stage::ContiguousTilingLayout};
8use crate::components::{global::RoleRule, stage::TilingValidation};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11
12use super::{LoadingValidation, sync_full_tilewise};
13
14#[derive(CubeType, Clone, Copy)]
15pub struct SyncFullOrderedLoading {}
24
25impl LoadingValidation for SyncFullOrderedLoading {
26 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
27 if ident != MatmulIdent::Lhs {
28 return Err(FormattedConfigError::new(move || {
29 "Ordered loading only available on Lhs".to_string()
30 }));
31 }
32
33 let line_size = config.global_line_size(ident);
34 let num_planes = config.num_loading_planes(ident);
35 let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
36
37 if !num_tiles.is_multiple_of(num_planes) {
38 return Err(FormattedConfigError::new(move || {
39 format!(
40 "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for ordered loading.",
41 )
42 }));
43 }
44
45 let num_tiles_per_plane = comptime!(num_tiles / num_planes);
46 let num_lines_per_tile =
47 comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
48 let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
49 let num_planes = config.num_loading_planes(ident);
50 let plane_dim = config.plane_dim();
51 let rows_per_plane = config.tiling_scheme().tiles_in_stage_row(ident) / num_planes;
52
53 if num_lines_per_plane % plane_dim != 0 {
54 return Err(FormattedConfigError::new(move || {
55 format!(
56 "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for ordered loading.",
57 )
58 }));
59 }
60
61 let tile_count_col = config.tiling_scheme().tiles_in_stage_col(ident);
62 if num_tiles_per_plane != rows_per_plane * tile_count_col {
63 return Err(FormattedConfigError::new(move || {
64 format!(
65 "Number of tiles per plane {num_tiles_per_plane:?} must equal rows_per_plane {rows_per_plane:?} times cols {tile_count_col:?} for ordered loading.",
66 )
67 }));
68 }
69
70 ContiguousTilingLayout::<OrderedTilingOrder>::check(config.global_memory_config(ident))?;
71
72 Ok(())
73 }
74}
75
76impl LoadMaxRoundPlaneCount for SyncFullOrderedLoading {
77 fn max_round_plane_count(
78 tiling_scheme: &TilingScheme,
79 ident: MatmulIdent,
80 _line_size: u8,
81 _plane_dim: u32,
82 ) -> u32 {
83 tiling_scheme.tiles_in_stage(ident)
84 }
85}
86
#[cube]
impl SyncFullLoadingStrategy for SyncFullOrderedLoading {
    // Ordered variant of the contiguous layout; the loading job itself is the
    // tilewise one.
    type TilingLayout = ContiguousTilingLayout<OrderedTilingOrder>;
    type Job<IP: MatrixPrecision> = sync_full_tilewise::SyncFullTilewiseJob;

    /// Builds the tilewise job describing this plane's share of the stage.
    ///
    /// Each loading plane is assigned `num_tiles / num_planes` tiles. Plane with
    /// load index `k` (from `RoleRule`) starts after `k * num_tiles_per_plane`
    /// tiles. The lines of a plane's tiles are split across its `plane_dim` units.
    ///
    /// `ident`, `line_size` and `config` are comptime parameters. The integer
    /// divisions below rely on the divisibility conditions verified by this
    /// strategy's `LoadingValidation::check`.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> Self::Job<IP> {
        let num_planes = config.num_loading_planes(ident);
        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
        let plane_dim = config.plane_dim();

        // Work sizes per plane and per unit, all derived from the comptime config.
        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
        let num_lines_per_tile =
            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
        let num_lines_per_unit = num_lines_per_plane / plane_dim;

        // Offset past the tiles (and their lines) assigned to lower load indices;
        // presumably consumed by the tilewise job as this plane's start — verify there.
        let num_tiles_to_skip = RoleRule::new(config.role_rule_config())
            .load_index(ident, config.specialized_loading_sides())
            * num_tiles_per_plane;
        let num_lines_to_skip = num_tiles_to_skip * num_lines_per_tile;

        sync_full_tilewise::SyncFullTilewiseJob {
            // Loading is delegated unchanged to the tilewise job; this strategy only
            // contributes its own tiling layout and validation.
            num_tiles_to_skip,
            num_lines_to_skip,
            num_lines_per_tile,
            num_lines_per_unit,
            plane_dim,
            line_size,
            ident,
        }
    }
}