cubek_convolution/components/global/read/reader/
full_reader.rs1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl::std::{
5 CubeOption, CubeOptionExpand,
6 tensor::{View, layout::Coords2d},
7};
8use cubek_matmul::components::{
9 global::{
10 GlobalReaderConfig,
11 memory::GlobalIterator,
12 multi_stage::{JobExecutor, JobIterator, LoadMaxRoundPlaneCount},
13 read::{LoadingJob, LoadingValidation, StageBuffer, SyncStrategy, TaskCounter},
14 },
15 stage::{StridedStageFamily, StridedStageMemory, TilingLayout},
16};
17
18use crate::components::global::args::RuntimeArgs;
19
/// Barrier type associated with the synchronization strategy `S`
/// (shorthand for `<S as SyncStrategy>::Barrier`).
pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
21
22#[cube]
23pub trait FullLoadingStrategy:
25 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
26{
27 type TilingLayout: TilingLayout;
29 type SyncStrategy: SyncStrategy;
31
32 type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = StridedStageFamily>;
34
35 fn new_job<EG: Numeric, ES: Numeric>(
37 runtime_args: RuntimeArgs,
38 #[comptime] line_size: LineSize,
39 #[comptime] config: GlobalReaderConfig,
40 ) -> Self::Job<EG, ES>;
41}
42
/// Global reader that fills a [`StridedStageMemory`] from a global [`View`], using the
/// loading jobs produced by the strategy `L`.
#[derive(Clone, CubeType)]
pub struct FullStageGlobalReader<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
    /// Iterator over the global view; moved forward by `advance_view`.
    global_iter: GlobalIterator<Line<EG>>,
    /// Runtime arguments kept so a loading job can be built on demand.
    runtime_args: RuntimeArgs,
    /// Stage memory the loading jobs write into.
    stage: StridedStageMemory<ES, L::TilingLayout>,
    /// Job built once at construction when `config.precompute_job` is set; `None` otherwise,
    /// in which case a fresh job is built every time one is needed.
    loading_job: CubeOption<L::Job<EG, ES>>,
    /// Comptime-only marker binding the strategy type `L`.
    #[cube(comptime)]
    _phantom: PhantomData<L>,
}
56
57#[cube]
58impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> FullStageGlobalReader<EG, ES, L> {
59 pub fn new(
61 view: View<Line<EG>, Coords2d>,
62 runtime_args: RuntimeArgs,
63 k_step: u32,
64 #[comptime] config: GlobalReaderConfig,
65 ) -> Self {
66 let stage = StridedStageMemory::new_aligned(128usize, config.smem_config);
69
70 let global_iter =
71 GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
72
73 let loading_job = match config.precompute_job {
74 true => CubeOption::new_Some(L::new_job::<EG, ES>(
75 runtime_args.clone(),
76 view.line_size(),
77 config,
78 )),
79 false => CubeOption::new_None(),
80 };
81
82 FullStageGlobalReader::<EG, ES, L> {
83 global_iter,
84 runtime_args,
85 stage,
86 loading_job,
87 _phantom: PhantomData::<L>,
88 }
89 }
90
91 pub fn stage(&self) -> StridedStageMemory<ES, L::TilingLayout> {
93 self.stage
94 }
95
96 pub fn clear_stage(&mut self, #[comptime] config: GlobalReaderConfig) {
97 self.stage.clear_all(config);
98 }
99
100 pub fn free_stage(self) {
101 unsafe { self.stage.free() };
102 }
103
104 pub fn advance_view(&mut self) {
106 self.global_iter.advance();
107 }
108
109 pub fn load_stage(
111 &mut self,
112 barrier: &mut SyncBarrier<L::SyncStrategy>,
113 #[comptime] config: GlobalReaderConfig,
114 ) {
115 let mut loading_job = match self.loading_job.clone() {
116 CubeOption::Some(loading_job) => loading_job,
117 CubeOption::None => L::new_job::<EG, ES>(
118 self.runtime_args.clone(),
119 self.global_iter.line_size(),
120 config,
121 ),
122 };
123
124 let len = L::Job::task_count(&loading_job);
125
126 #[unroll]
127 for task_id in 0..len {
128 L::Job::<EG, ES>::execute_task(
129 &mut loading_job,
130 task_id,
131 &self.global_iter,
132 &mut self.stage,
133 barrier,
134 config,
135 );
136 }
137 }
138}
139
140#[cube]
141impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobExecutor<L::SyncStrategy>
142 for FullStageGlobalReader<EG, ES, L>
143{
144 type JobIterator = FullStageJobIterator<EG, ES, L>;
145
146 fn create_job_iterator(
147 this: &Self,
148 #[comptime] _stage_buffer: StageBuffer,
149 #[comptime] config: GlobalReaderConfig,
150 ) -> Self::JobIterator {
151 let view = this.global_iter.view();
152 let job = match this.loading_job.clone() {
153 CubeOption::Some(loading_job) => loading_job,
154 CubeOption::None => {
155 L::new_job::<EG, ES>(this.runtime_args.clone(), view.line_size(), config)
156 }
157 };
158
159 let num_tasks = L::Job::task_count(&job);
160
161 FullStageJobIterator::<EG, ES, L> {
162 job,
163 num_tasks,
164 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
165 }
166 }
167
168 fn execute_task(
169 this: &mut Self,
170 job_iterator: &mut FullStageJobIterator<EG, ES, L>,
171 barrier: &mut SyncBarrier<L::SyncStrategy>,
172 #[comptime] config: GlobalReaderConfig,
173 ) {
174 let task_id = job_iterator.current.read().counter.comptime();
175
176 L::Job::<EG, ES>::execute_task(
177 &mut job_iterator.job,
178 task_id,
179 &this.global_iter,
180 &mut this.stage,
181 barrier,
182 config,
183 );
184
185 job_iterator.current.store(TaskCounter {
186 counter: task_id + 1,
187 });
188 }
189
190 fn execute_all_remaining_tasks(
191 this: &mut Self,
192 job_iterator: &mut Self::JobIterator,
193 barrier: &mut SyncBarrier<L::SyncStrategy>,
194 #[comptime] config: GlobalReaderConfig,
195 ) {
196 let task_counter = job_iterator.current.read().counter;
197
198 #[unroll]
199 for task_id in task_counter..job_iterator.num_tasks {
200 L::Job::<EG, ES>::execute_task(
201 &mut job_iterator.job,
202 task_id,
203 &this.global_iter,
204 &mut this.stage,
205 barrier,
206 config,
207 );
208 }
209
210 job_iterator.current.store(TaskCounter {
211 counter: job_iterator.num_tasks,
212 });
213 }
214
215 fn execute_whole_job(
216 this: &mut Self,
217 barrier: &mut SyncBarrier<L::SyncStrategy>,
218 #[comptime] stage_buffer: StageBuffer,
219 #[comptime] config: GlobalReaderConfig,
220 ) {
221 Self::execute_all_remaining_tasks(
222 this,
223 &mut Self::create_job_iterator(this, stage_buffer, config),
224 barrier,
225 config,
226 );
227 }
228}
229
/// Progress state over the tasks of one loading job, produced and consumed by the
/// [`JobExecutor`] implementation of [`FullStageGlobalReader`].
#[derive(CubeType)]
pub struct FullStageJobIterator<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
    /// Job whose tasks are executed.
    job: L::Job<EG, ES>,
    /// Total number of tasks in `job`, known at compile time.
    #[cube(comptime)]
    pub num_tasks: u32,
    /// Comptime index of the next task to execute; starts at 0 and is advanced as tasks run.
    pub current: ComptimeCell<TaskCounter>,
}
238
#[cube]
impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobIterator
    for FullStageJobIterator<EG, ES, L>
{
    /// Comptime index of the next task to execute (equals `num_tasks` once all tasks ran).
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    /// Comptime total number of tasks in the underlying job.
    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}