cubecl_matmul/components/global/read/strategy/
sync_full_strided.rs

1use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
2use crate::components::global::read::{SyncFullLoadingStrategy, stage::FullStageLayout};
3use crate::components::global::{GlobalConfig, RoleRule};
4use crate::components::stage::{StridedStage, StridedTilingLayout};
5use crate::components::{InvalidConfigError, MatmulIdent};
6use crate::components::{MatrixPrecision, TilingScheme};
7use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
8use cubecl_core as cubecl;
9use cubecl_core::prelude::*;
10
11use super::{LoadingJob, LoadingValidation};
12
13#[derive(CubeType, Clone, Copy)]
14/// Loads the content of all the stage using all planes,
15/// keeping the original layout, making each tile strided
16pub struct SyncFullStridedLoading {}
17
18impl LoadingValidation for SyncFullStridedLoading {
19    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
20        let line_size = config.global_line_size(ident);
21
22        let num_stage_lines = config.tiling_scheme().elements_in_stage(ident) / line_size;
23        let total_units = config.num_loading_planes(ident) * config.plane_dim();
24
25        if !num_stage_lines.is_multiple_of(total_units) {
26            return Err(Box::new(
27                "Too many data will be loaded, resulting in out of bounds.
28        Try setting line size and number of planes so that total unit count {:?} divides number of lines in stage.",
29            ));
30        }
31
32        StridedTilingLayout::check(config.global_memory_config(ident))?;
33
34        Ok(())
35    }
36}
37
38impl LoadMaxRoundPlaneCount for SyncFullStridedLoading {
39    fn max_round_plane_count(
40        tiling_scheme: &TilingScheme,
41        ident: MatmulIdent,
42        line_size: u8,
43        plane_dim: u32,
44    ) -> u32 {
45        let num_lines = tiling_scheme.elements_in_stage(ident) / line_size as u32;
46        num_lines.div_ceil(plane_dim)
47    }
48}
49
50#[cube]
51impl SyncFullLoadingStrategy for SyncFullStridedLoading {
52    type TilingLayout = StridedTilingLayout;
53    type Job<IP: MatrixPrecision> = SyncFullStridedJob;
54
55    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
56        #[comptime] ident: MatmulIdent,
57        #[comptime] line_size: u32,
58        #[comptime] config: G,
59    ) -> Self::Job<IP> {
60        let num_stage_lines = config.tiling_scheme().elements_in_stage(ident) / line_size;
61        let unit_count = config.num_loading_planes(ident) * config.plane_dim();
62        let num_tasks_per_unit = comptime!(num_stage_lines / unit_count);
63
64        let unit_position_base = RoleRule::new(config.role_rule_config())
65            .load_index(ident, config.specialized_loading_sides())
66            * config.plane_dim()
67            + UNIT_POS_X;
68
69        SyncFullStridedJob {
70            unit_position_base,
71            num_tasks_per_unit,
72            unit_count,
73            line_size,
74            ident,
75        }
76    }
77}
78
/// Per-unit state for loading a full strided stage.
///
/// Built by `SyncFullStridedLoading::new_job`. Each task copies one line from global
/// memory into the stage.
#[derive(CubeType, Clone, Copy)]
pub struct SyncFullStridedJob {
    /// Line index of this unit's first task in the stage, computed as
    /// `load_index * plane_dim + UNIT_POS_X` in `new_job`.
    unit_position_base: u32,

    /// Number of line-copy tasks this unit performs (`stage lines / unit_count`).
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Total number of loading units (`num_loading_planes * plane_dim`). It is also the
    /// line stride between two consecutive tasks of the same unit.
    #[cube(comptime)]
    unit_count: u32,
    /// Number of elements per line; converts a line index into an element offset.
    #[cube(comptime)]
    line_size: u32,
    /// Which matmul operand this job loads.
    #[cube(comptime)]
    ident: MatmulIdent,
}
92
93#[cube]
94impl<IP: MatrixPrecision> LoadingJob<IP, StridedTilingLayout> for SyncFullStridedJob {
95    fn execute_task<G: GlobalConfig>(
96        this: &mut Self,
97        #[comptime] task_id: u32,
98        global_iter: &GlobalIterator<Line<IP::Global>>,
99        stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
100        #[comptime] config: G,
101    ) {
102        let unit_position = this.unit_position_base + task_id * this.unit_count;
103
104        let layout = FullStageLayout::new(comptime![config.global_memory_config(this.ident)]);
105        let view = global_iter.view().view(layout);
106
107        let line_read = view.read_checked(unit_position * this.line_size);
108
109        stage.as_slice_mut(this.line_size)[unit_position] = Line::cast_from(line_read);
110    }
111
112    fn task_count(this: &Self) -> comptime_type!(u32) {
113        this.num_tasks_per_unit
114    }
115}