cubek_matmul/components/global/read/reader/
full_reader.rs1use std::marker::PhantomData;
2
3use crate::components::global::multi_stage::JobExecutor;
4use crate::components::global::read::LoadingJob;
5use crate::components::global::read::LoadingValidation;
6use crate::components::global::read::StageBuffer;
7use crate::components::global::read::SyncStrategy;
8use crate::components::global::read::TaskCounter;
9use crate::components::global::{multi_stage::JobIterator, read::FullLoaderStage};
10use crate::components::stage::TilingLayout;
11use crate::components::{global::memory::GlobalIterator, stage::LoadStageFamily};
12use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, tile::io::TileKind};
13use crate::{components::global::GlobalReaderConfig, launch::RuntimeConfig};
14use cubecl::prelude::*;
15use cubecl::std::{
16 CubeOption, CubeOptionExpand,
17 tensor::{View, layout::Coords2d},
18};
19
20pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
21
22#[cube]
23pub trait FullLoadingStrategy<RC: RuntimeConfig>:
25 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
26{
27 type TilingLayout: TilingLayout;
29 type SyncStrategy: SyncStrategy;
31 type Stage: LoadStageFamily<ReadOnly, TileKind = Self::TileKind>;
32 type TileKind: TileKind;
33
34 type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = Self::Stage>;
36
37 const SHOULD_CLEAR: bool = false;
38
39 fn new_job<EG: Numeric, ES: Numeric>(
41 config: RC,
42 #[comptime] line_size: LineSize,
43 #[comptime] config: GlobalReaderConfig,
44 ) -> Self::Job<EG, ES>;
45}
46
47#[derive(Clone, CubeType)]
48pub struct FullStageGlobalReader<
53 EG: Numeric,
54 ES: Numeric,
55 RC: RuntimeConfig,
56 L: FullLoadingStrategy<RC>,
57> {
58 global_iter: GlobalIterator<Line<EG>>,
59 runtime_config: RC,
60 stage: FullLoaderStage<RC, L, ES>,
61 loading_job: CubeOption<L::Job<EG, ES>>,
62 #[cube(comptime)]
63 _phantom: PhantomData<L>,
64}
65
66#[cube]
67impl<EG: Numeric, ES: Numeric, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
68 FullStageGlobalReader<EG, ES, RC, L>
69{
70 pub fn new(
72 view: View<Line<EG>, Coords2d>,
73 runtime_config: RC,
74 k_step: u32,
75 #[comptime] config: GlobalReaderConfig,
76 ) -> Self {
77 let stage = L::Stage::create(128usize, config.smem_config);
80
81 let global_iter =
82 GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
83
84 let loading_job = match config.precompute_job {
85 true => CubeOption::new_Some(L::new_job::<EG, ES>(
86 runtime_config.clone(),
87 view.line_size(),
88 config,
89 )),
90 false => CubeOption::new_None(),
91 };
92
93 FullStageGlobalReader::<EG, ES, RC, L> {
94 global_iter,
95 runtime_config,
96 stage,
97 loading_job,
98 _phantom: PhantomData::<L>,
99 }
100 }
101
102 pub fn stage(&self) -> FullLoaderStage<RC, L, ES> {
104 L::Stage::with_buffer_index(&self.stage, 0)
105 }
106
107 pub fn free_stage(self) {
109 L::Stage::free(&self.stage);
110 }
111
112 pub fn advance_view(&mut self) {
114 self.global_iter.advance();
115 }
116
117 pub fn load_stage(
119 &mut self,
120 barrier: &mut SyncBarrier<L::SyncStrategy>,
121 #[comptime] config: GlobalReaderConfig,
122 ) {
123 let mut loading_job = match self.loading_job.clone() {
124 CubeOption::Some(loading_job) => loading_job,
125 CubeOption::None => L::new_job::<EG, ES>(
126 self.runtime_config.clone(),
127 self.global_iter.line_size(),
128 config,
129 ),
130 };
131
132 let len = L::Job::task_count(&loading_job);
133
134 #[unroll]
135 for task_id in 0..len {
136 L::Job::<EG, ES>::execute_task(
137 &mut loading_job,
138 task_id,
139 &self.global_iter,
140 &mut self.stage,
141 barrier,
142 config,
143 );
144 }
145 }
146}
147
148#[cube]
149impl<EG: Numeric, ES: Numeric, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
150 JobExecutor<L::SyncStrategy> for FullStageGlobalReader<EG, ES, RC, L>
151{
152 type JobIterator = FullStageJobIterator<EG, ES, RC, L>;
153
154 fn create_job_iterator(
155 this: &Self,
156 #[comptime] _stage_buffer: StageBuffer,
157 #[comptime] config: GlobalReaderConfig,
158 ) -> Self::JobIterator {
159 let view = this.global_iter.view();
160 let job = match this.loading_job.clone() {
161 CubeOption::Some(loading_job) => loading_job,
162 CubeOption::None => {
163 L::new_job::<EG, ES>(this.runtime_config.clone(), view.line_size(), config)
164 }
165 };
166
167 let num_tasks = L::Job::task_count(&job);
168
169 FullStageJobIterator::<EG, ES, RC, L> {
170 job,
171 num_tasks,
172 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
173 }
174 }
175
176 fn execute_task(
177 this: &mut Self,
178 job_iterator: &mut FullStageJobIterator<EG, ES, RC, L>,
179 barrier: &mut SyncBarrier<L::SyncStrategy>,
180 #[comptime] config: GlobalReaderConfig,
181 ) {
182 let task_id = job_iterator.current.read().counter.comptime();
183
184 L::Job::<EG, ES>::execute_task(
185 &mut job_iterator.job,
186 task_id,
187 &this.global_iter,
188 &mut this.stage,
189 barrier,
190 config,
191 );
192
193 job_iterator.current.store(TaskCounter {
194 counter: task_id + 1,
195 });
196 }
197
198 fn execute_all_remaining_tasks(
199 this: &mut Self,
200 job_iterator: &mut Self::JobIterator,
201 barrier: &mut SyncBarrier<L::SyncStrategy>,
202 #[comptime] config: GlobalReaderConfig,
203 ) {
204 let task_counter = job_iterator.current.read().counter;
205
206 #[unroll]
207 for task_id in task_counter..job_iterator.num_tasks {
208 L::Job::<EG, ES>::execute_task(
209 &mut job_iterator.job,
210 task_id,
211 &this.global_iter,
212 &mut this.stage,
213 barrier,
214 config,
215 );
216 }
217
218 job_iterator.current.store(TaskCounter {
219 counter: job_iterator.num_tasks,
220 });
221 }
222
223 fn execute_whole_job(
224 this: &mut Self,
225 barrier: &mut SyncBarrier<L::SyncStrategy>,
226 #[comptime] stage_buffer: StageBuffer,
227 #[comptime] config: GlobalReaderConfig,
228 ) {
229 Self::execute_all_remaining_tasks(
230 this,
231 &mut Self::create_job_iterator(this, stage_buffer, config),
232 barrier,
233 config,
234 );
235 }
236}
237
238#[derive(CubeType)]
239pub struct FullStageJobIterator<
241 EG: Numeric,
242 ES: Numeric,
243 RC: RuntimeConfig,
244 L: FullLoadingStrategy<RC>,
245> {
246 job: L::Job<EG, ES>,
247 #[cube(comptime)]
248 pub num_tasks: u32,
249 pub current: ComptimeCell<TaskCounter>,
250}
251
252#[cube]
253impl<EG: Numeric, ES: Numeric, RC: RuntimeConfig, L: FullLoadingStrategy<RC>> JobIterator
254 for FullStageJobIterator<EG, ES, RC, L>
255{
256 fn current(this: &Self) -> comptime_type!(u32) {
257 this.current.read().counter
258 }
259
260 fn num_tasks(this: &Self) -> comptime_type!(u32) {
261 this.num_tasks
262 }
263}