cubecl_matmul/components/global/read/strategy/
async_full_cyclic.rs1use std::marker::PhantomData;
2
3use crate::components::{
4 InvalidConfigError, MatmulElems, MatrixLayout,
5 global::{
6 GlobalReaderConfig, RoleRule,
7 memory::{GlobalIterator, load_window_in_tile},
8 multi_stage::LoadMaxRoundPlaneCount,
9 read::{
10 FullLoadingStrategy, LoadingJob, async_barrier::AsyncBarrier, validate_async_barrier,
11 validate_noswizzle,
12 },
13 },
14 stage::{
15 ContiguousTilingLayout, StridedStageFamily, StridedStageMemory, TilingOrder,
16 TilingValidation,
17 },
18};
19use cubecl_core::prelude::{barrier::Barrier, *};
20use cubecl_core::{self as cubecl};
21
22use super::LoadingValidation;
23
24#[derive(CubeType, Clone, Copy)]
25pub struct AsyncFullCyclicLoading<T: TilingOrder> {
28 #[cube(comptime)]
29 _phantom: PhantomData<T>,
30}
31
32impl<T: TilingOrder> LoadingValidation for AsyncFullCyclicLoading<T> {
33 fn check<R: Runtime>(
34 client: &ComputeClient<R::Server>,
35 config: &GlobalReaderConfig,
36 _dtypes: &MatmulElems,
37 ) -> Result<(), InvalidConfigError> {
38 let total_units = config.loading_planes_count() * config.plane_dim;
39 let num_slices =
40 config.smem_config.elements_per_tile_along_row * config.smem_config.tiles_per_stage();
41
42 if num_slices >= total_units && !num_slices.is_multiple_of(total_units) {
43 return Err(Box::new(format!(
44 "Number of units ({total_units:?}) must divide number of slices ({num_slices:?}). Would require units doing different numbers of slices"
45 )));
46 }
47
48 ContiguousTilingLayout::<T>::check(config.smem_config)?;
49 validate_async_barrier::<R>(client)?;
50 validate_noswizzle(config.smem_config)?;
51
52 Ok(())
53 }
54}
55
56impl<TO: TilingOrder> LoadMaxRoundPlaneCount for AsyncFullCyclicLoading<TO> {
57 fn max_round_plane_count(
58 _elements_per_tile: u32,
59 _tiles_per_stage: u32,
60 _line_size: u8,
61 _plane_dim: u32,
62 ) -> u32 {
63 4
66 }
67}
68
69#[cube]
70impl<TO: TilingOrder> FullLoadingStrategy for AsyncFullCyclicLoading<TO> {
71 type TilingLayout = ContiguousTilingLayout<TO>;
72 type SyncStrategy = AsyncBarrier;
73 type Job<EG: Numeric, ES: Numeric> = AsyncFullCyclicJob;
74
75 const SHOULD_CLEAR: bool = true;
76
77 fn new_job<EG: Numeric, ES: Numeric>(
78 #[comptime] line_size: u32,
79 #[comptime] config: GlobalReaderConfig,
80 ) -> AsyncFullCyclicJob {
81 let total_units = config.loading_units_count();
82
83 let (num_slices_per_tile, slice_length_in_lines) = match config.gmem_config.matrix_layout {
84 MatrixLayout::RowMajor => (
85 config.smem_config.elements_per_tile_along_row,
86 config.smem_config.elements_per_tile_along_col / line_size,
87 ),
88 MatrixLayout::ColMajor => (
89 config.smem_config.elements_per_tile_along_col,
90 config.smem_config.elements_per_tile_along_row / line_size,
91 ),
92 };
93
94 let num_slices = comptime!(num_slices_per_tile * config.smem_config.tiles_per_stage());
95 let num_tasks_per_unit = num_slices.div_ceil(total_units);
96
97 let unit_id = RoleRule::new(config.plane_role_config.rule)
98 .load_index(config.specialization_tensor_config)
99 * config.plane_dim
100 + UNIT_POS_X;
101
102 AsyncFullCyclicJob {
103 unit_id,
104 num_tasks_per_unit,
105 total_units,
106 num_slices,
107 num_slices_per_tile,
108 slice_length_in_lines,
109 line_size,
110 }
111 }
112}
113
114#[derive(CubeType, Clone, Copy)]
115pub struct AsyncFullCyclicJob {
116 unit_id: u32,
117
118 #[cube(comptime)]
119 num_tasks_per_unit: u32,
120 #[cube(comptime)]
121 total_units: u32,
122 #[cube(comptime)]
123 num_slices: u32,
124 #[cube(comptime)]
125 num_slices_per_tile: u32,
126 #[cube(comptime)]
127 slice_length_in_lines: u32,
128 #[cube(comptime)]
129 line_size: u32,
130}
131
132#[cube]
133impl<EG: Numeric, ES: Numeric, TO: TilingOrder>
134 LoadingJob<EG, ES, ContiguousTilingLayout<TO>, AsyncBarrier> for AsyncFullCyclicJob
135{
136 type Stage = StridedStageFamily;
137
138 fn execute_task(
139 this: &mut Self,
140 #[comptime] task_id: u32,
141 global_iter: &GlobalIterator<Line<EG>>,
142 stage: &mut StridedStageMemory<ES, ContiguousTilingLayout<TO>>,
143 barrier: &mut Barrier,
144 #[comptime] config: GlobalReaderConfig,
145 ) {
146 let slice_index = this.unit_id + this.total_units * task_id;
147
148 let nth_tile = slice_index / this.num_slices_per_tile;
149 let (tile_x, tile_y) =
150 ContiguousTilingLayout::<TO>::to_x_y(nth_tile, comptime!(config.smem_config));
151 let nth_slice = slice_index % this.num_slices_per_tile;
152
153 if slice_index < this.num_slices {
155 let window = load_window_in_tile(
156 &global_iter.view(),
157 (tile_x, tile_y),
158 nth_slice,
159 config.smem_config,
160 config.gmem_config,
161 );
162
163 let slice_destination_offset =
165 (nth_tile * this.num_slices_per_tile + nth_slice) * this.slice_length_in_lines;
166
167 let mut destination = stage.as_slice_mut(this.line_size).slice_mut(
169 slice_destination_offset,
170 slice_destination_offset + this.slice_length_in_lines,
171 );
172
173 barrier.memcpy_async(&window.try_cast_unchecked(), &mut destination);
174 }
175 }
176
177 fn task_count(this: &Self) -> comptime_type!(u32) {
178 this.num_tasks_per_unit
179 }
180}