1use std::marker::PhantomData;
2
3use crate::{
4 components::global::multi_stage::JobExecutor,
5 components::global::multi_stage::LoadMaxRoundPlaneCount,
6 components::global::read::LoadingJob,
7 components::global::read::LoadingValidation,
8 components::global::read::StageBuffer,
9 components::global::read::SyncStrategy,
10 components::global::read::TaskCounter,
11 components::global::{multi_stage::JobIterator, read::FullLoaderStage},
12 components::stage::TilingLayout,
13 components::{global::memory::GlobalIterator, stage::LoadStageFamily},
14 {components::global::GlobalReaderConfig, launch::RuntimeConfig},
15};
16use cubecl::{
17 prelude::*,
18 std::tensor::{View, layout::Coords2d},
19};
20use cubek_std::tile::TileKind;
21
/// Shorthand for the barrier type associated with a [`SyncStrategy`] `S`.
///
/// Used below as the type of the `barrier` argument threaded through every loading call.
pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
23
24#[cube]
25pub trait FullLoadingStrategy<RC: RuntimeConfig>:
27 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
28{
29 type TilingLayout: TilingLayout;
31 type SyncStrategy: SyncStrategy;
33 type Stage: LoadStageFamily<ReadOnly, TileKind = Self::TileKind>;
34 type TileKind: TileKind;
35
36 type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>: LoadingJob<EG, NG, ES, NS, Self::TilingLayout, Self::SyncStrategy, Stage = Self::Stage>;
38
39 const SHOULD_CLEAR: bool = false;
40
41 fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
43 config: RC,
44 #[comptime] config: GlobalReaderConfig,
45 ) -> Self::Job<EG, NG, ES, NS>;
46}
47
/// Global reader that fills its stage by running every task of the loading job produced
/// by the [`FullLoadingStrategy`] `L`.
///
/// Type parameters:
/// - `EG`, `NG`: element type and vector size of the global view (`Vector<EG, NG>`).
/// - `ES`, `NS`: presumably the stage element type and vector size. They parameterise
///   `FullLoaderStage` and `L::Job` — TODO confirm.
/// - `RC`: runtime configuration forwarded to `L::new_job`.
/// - `L`: the loading strategy.
#[derive(Clone, CubeType)]
pub struct FullStageGlobalReader<
    EG: Numeric,
    NG: Size,
    ES: Numeric,
    NS: Size,
    RC: RuntimeConfig,
    L: FullLoadingStrategy<RC>,
> {
    /// Iterator over the global view; moved forward by `advance_view`.
    global_iter: GlobalIterator<Vector<EG, NG>>,
    /// Runtime config handed to `L::new_job` whenever a job is built on demand.
    runtime_config: RC,
    /// Stage the loading tasks operate on.
    stage: FullLoaderStage<RC, L, ES, NS>,
    /// Job built once in `new` when `config.precompute_job` is set.
    ///
    /// Otherwise it is `None`, and each load builds a fresh job.
    loading_job: ComptimeOption<L::Job<EG, NG, ES, NS>>,
    /// Comptime marker tying the strategy type `L` to the reader.
    #[cube(comptime)]
    _phantom: PhantomData<L>,
}
68
69#[cube]
70impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
71 FullStageGlobalReader<EG, NG, ES, NS, RC, L>
72{
73 pub fn new(
75 view: View<Vector<EG, NG>, Coords2d>,
76 runtime_config: RC,
77 k_step: u32,
78 #[comptime] config: GlobalReaderConfig,
79 ) -> Self {
80 let stage = L::Stage::create(128usize, config.smem_config);
83
84 let global_iter =
85 GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
86
87 let loading_job = match config.precompute_job {
88 true => ComptimeOption::new_Some(L::new_job::<EG, NG, ES, NS>(
89 runtime_config.clone(),
90 config,
91 )),
92 false => ComptimeOption::new_None(),
93 };
94
95 FullStageGlobalReader::<EG, NG, ES, NS, RC, L> {
96 global_iter,
97 runtime_config,
98 stage,
99 loading_job,
100 _phantom: PhantomData::<L>,
101 }
102 }
103
104 pub fn stage(&self) -> FullLoaderStage<RC, L, ES, NS> {
106 L::Stage::with_buffer_index(&self.stage, 0)
107 }
108
109 pub fn free_stage(self) {
111 L::Stage::free(&self.stage);
112 }
113
114 pub fn advance_view(&mut self) {
116 self.global_iter.advance();
117 }
118
119 pub fn load_stage(
121 &mut self,
122 barrier: &mut SyncBarrier<L::SyncStrategy>,
123 #[comptime] config: GlobalReaderConfig,
124 ) {
125 let mut loading_job = self
126 .loading_job
127 .clone()
128 .unwrap_or_else(|| L::new_job::<EG, NG, ES, NS>(self.runtime_config.clone(), config));
129
130 let len = L::Job::task_count(&loading_job);
131
132 #[unroll]
133 for task_id in 0..len {
134 L::Job::<EG, NG, ES, NS>::execute_task(
135 &mut loading_job,
136 task_id,
137 &self.global_iter,
138 &mut self.stage,
139 barrier,
140 config,
141 );
142 }
143 }
144}
145
146#[cube]
147impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
148 JobExecutor<L::SyncStrategy> for FullStageGlobalReader<EG, NG, ES, NS, RC, L>
149{
150 type JobIterator = FullStageJobIterator<EG, NG, ES, NS, RC, L>;
151
152 fn create_job_iterator(
153 this: &Self,
154 #[comptime] _stage_buffer: StageBuffer,
155 #[comptime] config: GlobalReaderConfig,
156 ) -> Self::JobIterator {
157 let job = this
158 .loading_job
159 .clone()
160 .unwrap_or_else(|| L::new_job::<EG, NG, ES, NS>(this.runtime_config.clone(), config));
161
162 let num_tasks = L::Job::task_count(&job);
163
164 FullStageJobIterator::<EG, NG, ES, NS, RC, L> {
165 job,
166 num_tasks,
167 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
168 }
169 }
170
171 fn execute_task(
172 this: &mut Self,
173 job_iterator: &mut FullStageJobIterator<EG, NG, ES, NS, RC, L>,
174 barrier: &mut SyncBarrier<L::SyncStrategy>,
175 #[comptime] config: GlobalReaderConfig,
176 ) {
177 let task_id = job_iterator.current.read().counter.comptime();
178
179 L::Job::<EG, NG, ES, NS>::execute_task(
180 &mut job_iterator.job,
181 task_id,
182 &this.global_iter,
183 &mut this.stage,
184 barrier,
185 config,
186 );
187
188 job_iterator.current.store(TaskCounter {
189 counter: task_id + 1,
190 });
191 }
192
193 fn execute_all_remaining_tasks(
194 this: &mut Self,
195 job_iterator: &mut Self::JobIterator,
196 barrier: &mut SyncBarrier<L::SyncStrategy>,
197 #[comptime] config: GlobalReaderConfig,
198 ) {
199 let task_counter = job_iterator.current.read().counter;
200
201 #[unroll]
202 for task_id in task_counter..job_iterator.num_tasks {
203 L::Job::<EG, NG, ES, NS>::execute_task(
204 &mut job_iterator.job,
205 task_id,
206 &this.global_iter,
207 &mut this.stage,
208 barrier,
209 config,
210 );
211 }
212
213 job_iterator.current.store(TaskCounter {
214 counter: job_iterator.num_tasks,
215 });
216 }
217
218 fn execute_whole_job(
219 this: &mut Self,
220 barrier: &mut SyncBarrier<L::SyncStrategy>,
221 #[comptime] stage_buffer: StageBuffer,
222 #[comptime] config: GlobalReaderConfig,
223 ) {
224 Self::execute_all_remaining_tasks(
225 this,
226 &mut Self::create_job_iterator(this, stage_buffer, config),
227 barrier,
228 config,
229 );
230 }
231}
232
/// Step-by-step cursor over the tasks of a [`FullLoadingStrategy`] loading job.
///
/// It is created by `create_job_iterator` on [`FullStageGlobalReader`].
#[derive(CubeType)]
pub struct FullStageJobIterator<
    EG: Numeric,
    NG: Size,
    ES: Numeric,
    NS: Size,
    RC: RuntimeConfig,
    L: FullLoadingStrategy<RC>,
> {
    /// The loading job whose tasks are executed.
    job: L::Job<EG, NG, ES, NS>,
    /// Total number of tasks in `job` (comptime), taken from `task_count` at creation.
    #[cube(comptime)]
    pub num_tasks: u32,
    /// Comptime index of the next task to execute.
    pub current: ComptimeCell<TaskCounter>,
}
248
#[cube]
impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
    JobIterator for FullStageJobIterator<EG, NG, ES, NS, RC, L>
{
    /// Comptime index of the next task to run.
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    /// Comptime total number of tasks in the job.
    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}