// cubecl_matmul/components/global/read/strategy/sync_full_strided.rs
use crate::components::InvalidConfigError;
2use crate::components::MatmulElems;
3use crate::components::global::read::validate_swizzle_atom_size;
4use crate::components::global::read::{FullLoadingStrategy, stage::FullStageLayout};
5use crate::components::global::{GlobalReaderConfig, RoleRule};
6use crate::components::global::{multi_stage::LoadMaxRoundPlaneCount, read::sync::Synchronous};
7use crate::components::stage::StridedStageFamily;
8use crate::components::stage::{StridedStageMemory, StridedTilingLayout};
9use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
10use cubecl_core as cubecl;
11use cubecl_core::prelude::*;
12use cubecl_std::type_size;
13
14use super::{LoadingJob, LoadingValidation};
15
16#[derive(CubeType, Clone, Copy)]
17pub struct SyncFullStridedLoading {}
20
21impl LoadingValidation for SyncFullStridedLoading {
22 fn check<R: Runtime>(
23 _client: &ComputeClient<R::Server>,
24 config: &GlobalReaderConfig,
25 dtypes: &MatmulElems,
26 ) -> Result<(), InvalidConfigError> {
27 let line_size = config.gmem_config.line_size;
28
29 let num_stage_lines = config.smem_config.elements_per_stage() / line_size;
30 let total_units = config.loading_units_count();
31
32 if !num_stage_lines.is_multiple_of(total_units) {
33 return Err(Box::new(
34 "Too many data will be loaded, resulting in out of bounds.
35 Try setting line size and number of planes so that total unit count {:?} divides number of lines in stage.",
36 ));
37 }
38
39 validate_swizzle_atom_size(config.smem_config, config.stage_ident, dtypes)?;
40 StridedTilingLayout::check(config.smem_config)?;
41
42 Ok(())
43 }
44}
45
46impl LoadMaxRoundPlaneCount for SyncFullStridedLoading {
47 fn max_round_plane_count(
48 elements_per_tile: u32,
49 tiles_per_stage: u32,
50 line_size: u8,
51 plane_dim: u32,
52 ) -> u32 {
53 let elements_per_stage = elements_per_tile * tiles_per_stage;
54 let num_lines = elements_per_stage / line_size as u32;
55 num_lines.div_ceil(plane_dim)
56 }
57}
58
59#[cube]
60impl FullLoadingStrategy for SyncFullStridedLoading {
61 type TilingLayout = StridedTilingLayout;
62 type SyncStrategy = Synchronous;
63 type Job<EG: Numeric, ES: Numeric> = SyncFullStridedJob;
64
65 fn new_job<EG: Numeric, ES: Numeric>(
66 #[comptime] line_size: u32,
67 #[comptime] config: GlobalReaderConfig,
68 ) -> Self::Job<EG, ES> {
69 let num_stage_lines = config.smem_config.elements_per_stage() / line_size;
70 let unit_count = config.loading_planes_count() * config.plane_dim;
71 let num_tasks_per_unit = comptime!(num_stage_lines / unit_count);
72
73 let unit_position_base = RoleRule::new(config.plane_role_config.rule)
74 .load_index(config.specialization_tensor_config)
75 * config.plane_dim
76 + UNIT_POS_X;
77
78 SyncFullStridedJob {
79 unit_position_base,
80 num_tasks_per_unit,
81 unit_count,
82 line_size,
83 }
84 }
85}
86
87#[derive(CubeType, Clone, Copy)]
88pub struct SyncFullStridedJob {
89 unit_position_base: u32,
90
91 #[cube(comptime)]
92 num_tasks_per_unit: u32,
93 #[cube(comptime)]
94 unit_count: u32,
95 #[cube(comptime)]
96 line_size: u32,
97}
98
99#[cube]
100impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, StridedTilingLayout, Synchronous>
101 for SyncFullStridedJob
102{
103 type Stage = StridedStageFamily;
104
105 fn execute_task(
106 this: &mut Self,
107 #[comptime] task_id: u32,
108 global_iter: &GlobalIterator<Line<EG>>,
109 stage: &mut StridedStageMemory<ES, StridedTilingLayout>,
110 _barrier: &mut (),
111 #[comptime] config: GlobalReaderConfig,
112 ) {
113 let unit_position = this.unit_position_base + task_id * this.unit_count;
114
115 let layout = FullStageLayout::new(comptime![config.smem_config]);
116 let view = global_iter.view().view(layout);
117
118 let line_read = view.read_checked(unit_position * this.line_size);
119 let type_size = type_size::<ES>(this.line_size);
120 let stage_offs = stage.swizzle.apply(unit_position, type_size);
121
122 stage.as_slice_mut(this.line_size)[stage_offs] = Line::cast_from(line_read);
123 }
124
125 fn task_count(this: &Self) -> comptime_type!(u32) {
126 this.num_tasks_per_unit
127 }
128}