cubecl_matmul/components/global/read/strategy/
sync_full_strided.rs1use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
2use crate::components::global::read::{SyncFullLoadingStrategy, stage::FullStageLayout};
3use crate::components::global::{GlobalConfig, RoleRule};
4use crate::components::stage::{StridedStage, StridedTilingLayout};
5use crate::components::{InvalidConfigError, MatmulIdent};
6use crate::components::{MatrixPrecision, TilingScheme};
7use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
8use cubecl_core as cubecl;
9use cubecl_core::prelude::*;
10
11use super::{LoadingJob, LoadingValidation};
12
/// Loading strategy implementing [`SyncFullLoadingStrategy`] that spreads all lines of the
/// stage over every loading unit.
///
/// Each unit starts at its own line index and then advances by the total number of
/// loading units: task `t` of a unit loads line `unit_position_base + t * unit_count`.
#[derive(CubeType, Clone, Copy)]
pub struct SyncFullStridedLoading {}
17
18impl LoadingValidation for SyncFullStridedLoading {
19 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
20 let line_size = config.global_line_size(ident);
21
22 let num_stage_lines = config.tiling_scheme().elements_in_stage(ident) / line_size;
23 let total_units = config.num_loading_planes(ident) * config.plane_dim();
24
25 if !num_stage_lines.is_multiple_of(total_units) {
26 return Err(Box::new(
27 "Too many data will be loaded, resulting in out of bounds.
28 Try setting line size and number of planes so that total unit count {:?} divides number of lines in stage.",
29 ));
30 }
31
32 StridedTilingLayout::check(config.global_memory_config(ident))?;
33
34 Ok(())
35 }
36}
37
38impl LoadMaxRoundPlaneCount for SyncFullStridedLoading {
39 fn max_round_plane_count(
40 tiling_scheme: &TilingScheme,
41 ident: MatmulIdent,
42 line_size: u8,
43 plane_dim: u32,
44 ) -> u32 {
45 let num_lines = tiling_scheme.elements_in_stage(ident) / line_size as u32;
46 num_lines.div_ceil(plane_dim)
47 }
48}
49
#[cube]
impl SyncFullLoadingStrategy for SyncFullStridedLoading {
    type TilingLayout = StridedTilingLayout;
    type Job<IP: MatrixPrecision> = SyncFullStridedJob;

    /// Builds the loading job for the calling unit.
    ///
    /// The stage is treated as `elements_in_stage / line_size` lines, shared by
    /// `num_loading_planes * plane_dim` units. Each unit gets `num_tasks_per_unit` tasks
    /// and starts at line `unit_position_base`.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> Self::Job<IP> {
        let num_stage_lines = config.tiling_scheme().elements_in_stage(ident) / line_size;
        let unit_count = config.num_loading_planes(ident) * config.plane_dim();
        // Integer division. `LoadingValidation::check` requires the stage line count (computed
        // with `global_line_size`) to be a multiple of `unit_count`, so this is presumably
        // exact. That assumes callers pass that same line size here — confirm at call sites.
        let num_tasks_per_unit = comptime!(num_stage_lines / unit_count);

        // NOTE(review): presumably `load_index` is this plane's index among the loading
        // planes and UNIT_POS_X is the lane within the plane — confirm against RoleRule.
        let unit_position_base = RoleRule::new(config.role_rule_config())
            .load_index(ident, config.specialized_loading_sides())
            * config.plane_dim()
            + UNIT_POS_X;

        SyncFullStridedJob {
            unit_position_base,
            num_tasks_per_unit,
            unit_count,
            line_size,
            ident,
        }
    }
}
78
/// Per-unit loading job produced by [`SyncFullStridedLoading`].
///
/// All fields except `unit_position_base` are compile-time (`#[cube(comptime)]`) values.
#[derive(CubeType, Clone, Copy)]
pub struct SyncFullStridedJob {
    /// Index of the first stage line loaded by this unit (its task 0).
    unit_position_base: u32,

    /// Number of tasks, i.e. lines, each unit loads.
    #[cube(comptime)]
    num_tasks_per_unit: u32,
    /// Total number of loading units; also the line stride between a unit's tasks.
    #[cube(comptime)]
    unit_count: u32,
    /// Line width, used to size the stage slice and to scale read positions.
    #[cube(comptime)]
    line_size: u32,
    /// Matmul operand identifier used for config lookups.
    #[cube(comptime)]
    ident: MatmulIdent,
}
92
#[cube]
impl<IP: MatrixPrecision> LoadingJob<IP, StridedTilingLayout> for SyncFullStridedJob {
    /// Loads one line from the global view into the stage.
    ///
    /// Task `task_id` of this unit handles line `unit_position_base + task_id * unit_count`.
    /// The value read as `Line<IP::Global>` is cast to the stage element type `IP::Stage`.
    fn execute_task<G: GlobalConfig>(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<IP::Global>>,
        stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
        #[comptime] config: G,
    ) {
        let unit_position = this.unit_position_base + task_id * this.unit_count;

        let layout = FullStageLayout::new(comptime![config.global_memory_config(this.ident)]);
        let view = global_iter.view().view(layout);

        // The read position is the line index scaled by the line size (presumably an element
        // offset in the view). `read_checked` presumably bounds-checks the read — confirm.
        let line_read = view.read_checked(unit_position * this.line_size);

        // The stage slice is taken with the same line size, so it is indexed by line.
        stage.as_slice_mut(this.line_size)[unit_position] = Line::cast_from(line_read);
    }

    /// Number of `execute_task` calls each unit makes (a compile-time value).
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks_per_unit
    }
}