cubecl_convolution/components/global/read/reader/
full_reader.rs1use std::marker::PhantomData;
2
3use cubecl_core as cubecl;
4use cubecl_core::prelude::*;
5use cubecl_matmul::components::{
6 global::{
7 GlobalReaderConfig,
8 memory::GlobalIterator,
9 multi_stage::{JobExecutor, JobIterator, LoadMaxRoundPlaneCount},
10 read::{LoadingJob, LoadingValidation, StageBuffer, SyncStrategy, TaskCounter},
11 },
12 stage::{StridedStageFamily, StridedStageMemory, TilingLayout},
13};
14use cubecl_std::{
15 CubeOption, CubeOptionExpand,
16 tensor::{View, layout::Coords2d},
17};
18
19use crate::components::global::args::RuntimeArgs;
20
21pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
22
23#[cube]
24pub trait FullLoadingStrategy:
26 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
27{
28 type TilingLayout: TilingLayout;
30 type SyncStrategy: SyncStrategy;
32
33 type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = StridedStageFamily>;
35
36 fn new_job<EG: Numeric, ES: Numeric>(
38 runtime_args: RuntimeArgs,
39 #[comptime] line_size: u32,
40 #[comptime] config: GlobalReaderConfig,
41 ) -> Self::Job<EG, ES>;
42}
43
44#[derive(Clone, CubeType)]
45pub struct FullStageGlobalReader<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
50 global_iter: GlobalIterator<Line<EG>>,
51 runtime_args: RuntimeArgs,
52 stage: StridedStageMemory<ES, L::TilingLayout>,
53 loading_job: CubeOption<L::Job<EG, ES>>,
54 #[cube(comptime)]
55 _phantom: PhantomData<L>,
56}
57
58#[cube]
59impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> FullStageGlobalReader<EG, ES, L> {
60 pub fn new(
62 view: View<Line<EG>, Coords2d>,
63 runtime_args: RuntimeArgs,
64 k_step: u32,
65 #[comptime] config: GlobalReaderConfig,
66 ) -> Self {
67 let stage = StridedStageMemory::new_aligned(128u32, config.smem_config);
70
71 let global_iter =
72 GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
73
74 let loading_job = match config.precompute_job {
75 true => CubeOption::new_Some(L::new_job::<EG, ES>(
76 runtime_args.clone(),
77 view.line_size(),
78 config,
79 )),
80 false => CubeOption::new_None(),
81 };
82
83 FullStageGlobalReader::<EG, ES, L> {
84 global_iter,
85 runtime_args,
86 stage,
87 loading_job,
88 _phantom: PhantomData::<L>,
89 }
90 }
91
92 pub fn stage(&self) -> StridedStageMemory<ES, L::TilingLayout> {
94 self.stage
95 }
96
97 pub fn clear_stage(&mut self, #[comptime] config: GlobalReaderConfig) {
98 self.stage.clear_all(config);
99 }
100
101 pub fn free_stage(self) {
102 unsafe { self.stage.free() };
103 }
104
105 pub fn advance_view(&mut self) {
107 self.global_iter.advance();
108 }
109
110 pub fn load_stage(
112 &mut self,
113 barrier: &mut SyncBarrier<L::SyncStrategy>,
114 #[comptime] config: GlobalReaderConfig,
115 ) {
116 let mut loading_job = match self.loading_job.clone() {
117 CubeOption::Some(loading_job) => loading_job,
118 CubeOption::None => L::new_job::<EG, ES>(
119 self.runtime_args.clone(),
120 self.global_iter.line_size(),
121 config,
122 ),
123 };
124
125 let len = L::Job::task_count(&loading_job);
126
127 #[unroll]
128 for task_id in 0..len {
129 L::Job::<EG, ES>::execute_task(
130 &mut loading_job,
131 task_id,
132 &self.global_iter,
133 &mut self.stage,
134 barrier,
135 config,
136 );
137 }
138 }
139}
140
141#[cube]
142impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobExecutor<L::SyncStrategy>
143 for FullStageGlobalReader<EG, ES, L>
144{
145 type JobIterator = FullStageJobIterator<EG, ES, L>;
146
147 fn create_job_iterator(
148 this: &Self,
149 #[comptime] _stage_buffer: StageBuffer,
150 #[comptime] config: GlobalReaderConfig,
151 ) -> Self::JobIterator {
152 let view = this.global_iter.view();
153 let job = match this.loading_job.clone() {
154 CubeOption::Some(loading_job) => loading_job,
155 CubeOption::None => {
156 L::new_job::<EG, ES>(this.runtime_args.clone(), view.line_size(), config)
157 }
158 };
159
160 let num_tasks = L::Job::task_count(&job);
161
162 FullStageJobIterator::<EG, ES, L> {
163 job,
164 num_tasks,
165 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
166 }
167 }
168
169 fn execute_task(
170 this: &mut Self,
171 job_iterator: &mut FullStageJobIterator<EG, ES, L>,
172 barrier: &mut SyncBarrier<L::SyncStrategy>,
173 #[comptime] config: GlobalReaderConfig,
174 ) {
175 let task_id = job_iterator.current.read().counter;
176
177 L::Job::<EG, ES>::execute_task(
178 &mut job_iterator.job,
179 task_id,
180 &this.global_iter,
181 &mut this.stage,
182 barrier,
183 config,
184 );
185
186 job_iterator.current.store(TaskCounter {
187 counter: comptime!(task_id + 1u32),
188 });
189 }
190
191 fn execute_all_remaining_tasks(
192 this: &mut Self,
193 job_iterator: &mut Self::JobIterator,
194 barrier: &mut SyncBarrier<L::SyncStrategy>,
195 #[comptime] config: GlobalReaderConfig,
196 ) {
197 let task_counter = job_iterator.current.read().counter;
198
199 #[unroll]
200 for task_id in task_counter..job_iterator.num_tasks {
201 L::Job::<EG, ES>::execute_task(
202 &mut job_iterator.job,
203 task_id,
204 &this.global_iter,
205 &mut this.stage,
206 barrier,
207 config,
208 );
209 }
210
211 job_iterator.current.store(TaskCounter {
212 counter: comptime!(job_iterator.num_tasks),
213 });
214 }
215
216 fn execute_whole_job(
217 this: &mut Self,
218 barrier: &mut SyncBarrier<L::SyncStrategy>,
219 #[comptime] stage_buffer: StageBuffer,
220 #[comptime] config: GlobalReaderConfig,
221 ) {
222 Self::execute_all_remaining_tasks(
223 this,
224 &mut Self::create_job_iterator(this, stage_buffer, config),
225 barrier,
226 config,
227 );
228 }
229}
230
231#[derive(CubeType)]
232pub struct FullStageJobIterator<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
234 job: L::Job<EG, ES>,
235 #[cube(comptime)]
236 pub num_tasks: u32,
237 pub current: ComptimeCell<TaskCounter>,
238}
239
240#[cube]
241impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobIterator
242 for FullStageJobIterator<EG, ES, L>
243{
244 fn current(this: &Self) -> comptime_type!(u32) {
245 this.current.read().counter
246 }
247
248 fn num_tasks(this: &Self) -> comptime_type!(u32) {
249 this.num_tasks
250 }
251}