cubecl_matmul/components/global/read/strategy/
sync_full_ordered.rs

1use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
2use crate::components::global::read::SyncFullLoadingStrategy;
3use crate::components::stage::OrderedTilingOrder;
4use crate::components::{
5    FormattedConfigError, InvalidConfigError, MatmulIdent, MatrixPrecision, TilingScheme,
6};
7use crate::components::{global::GlobalConfig, stage::ContiguousTilingLayout};
8use crate::components::{global::RoleRule, stage::TilingValidation};
9use cubecl_core as cubecl;
10use cubecl_core::prelude::*;
11
12use super::{LoadingValidation, sync_full_tilewise};
13
14#[derive(CubeType, Clone, Copy)]
15/// Similar to `sync_full_tilewise`, but includes additional validation checks.
16///
17/// This function operates only on the LHS (left-hand side).
18///
19/// - In the single-row case, behavior is similar to `tilewise` with row-major tiling order.
20///   However, it will explicitly fail if any plane does not load its entire row.
21/// - In the multi-row case, it too will fail if a plane does not load all its rows.
22///   Within each plane, the local tiling order is column-major.
23pub struct SyncFullOrderedLoading {}
24
25impl LoadingValidation for SyncFullOrderedLoading {
26    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
27        if ident != MatmulIdent::Lhs {
28            return Err(FormattedConfigError::new(move || {
29                "Ordered loading only available on Lhs".to_string()
30            }));
31        }
32
33        let line_size = config.global_line_size(ident);
34        let num_planes = config.num_loading_planes(ident);
35        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
36
37        if !num_tiles.is_multiple_of(num_planes) {
38            return Err(FormattedConfigError::new(move || {
39                format!(
40                    "Number of planes {num_planes:?} must divide number of tiles {num_tiles:?} for ordered loading.",
41                )
42            }));
43        }
44
45        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
46        let num_lines_per_tile =
47            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
48        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
49        let num_planes = config.num_loading_planes(ident);
50        let plane_dim = config.plane_dim();
51        let rows_per_plane = config.tiling_scheme().tiles_in_stage_row(ident) / num_planes;
52
53        if num_lines_per_plane % plane_dim != 0 {
54            return Err(FormattedConfigError::new(move || {
55                format!(
56                    "Plane dimension {plane_dim:?} must divide number of lines per plane {num_lines_per_plane:?} for ordered loading.",
57                )
58            }));
59        }
60
61        let tile_count_col = config.tiling_scheme().tiles_in_stage_col(ident);
62        if num_tiles_per_plane != rows_per_plane * tile_count_col {
63            return Err(FormattedConfigError::new(move || {
64                format!(
65                    "Number of tiles per plane {num_tiles_per_plane:?} must equal rows_per_plane {rows_per_plane:?} times cols {tile_count_col:?} for ordered loading.",
66                )
67            }));
68        }
69
70        ContiguousTilingLayout::<OrderedTilingOrder>::check(config.global_memory_config(ident))?;
71
72        Ok(())
73    }
74}
75
76impl LoadMaxRoundPlaneCount for SyncFullOrderedLoading {
77    fn max_round_plane_count(
78        tiling_scheme: &TilingScheme,
79        ident: MatmulIdent,
80        _line_size: u8,
81        _plane_dim: u32,
82    ) -> u32 {
83        tiling_scheme.tiles_in_stage(ident)
84    }
85}
86
87#[cube]
88impl SyncFullLoadingStrategy for SyncFullOrderedLoading {
89    type TilingLayout = ContiguousTilingLayout<OrderedTilingOrder>;
90    type Job<IP: MatrixPrecision> = sync_full_tilewise::SyncFullTilewiseJob;
91
92    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
93        #[comptime] ident: MatmulIdent,
94        #[comptime] line_size: u32,
95        #[comptime] config: G,
96    ) -> Self::Job<IP> {
97        let num_planes = config.num_loading_planes(ident);
98        let num_tiles = config.tiling_scheme().tiles_in_stage(ident);
99        let plane_dim = config.plane_dim();
100
101        let num_tiles_per_plane = comptime!(num_tiles / num_planes);
102        let num_lines_per_tile =
103            comptime!(config.tiling_scheme().elements_in_tile(ident) / line_size);
104        let num_lines_per_plane = num_lines_per_tile * num_tiles_per_plane;
105        let num_lines_per_unit = num_lines_per_plane / plane_dim;
106
107        let num_tiles_to_skip = RoleRule::new(config.role_rule_config())
108            .load_index(ident, config.specialized_loading_sides())
109            * num_tiles_per_plane;
110        let num_lines_to_skip = num_tiles_to_skip * num_lines_per_tile;
111
112        // Ordered is just a tilewise reader using the ordered tiling order
113        sync_full_tilewise::SyncFullTilewiseJob {
114            num_tiles_to_skip,
115            num_lines_to_skip,
116            num_lines_per_tile,
117            num_lines_per_unit,
118            plane_dim,
119            line_size,
120            ident,
121        }
122    }
123}