cubecl_matmul/components/global/read/reader/
full_reader.rs1use std::marker::PhantomData;
2
3use crate::components::StageIdent;
4use crate::components::global::GlobalReaderConfig;
5use crate::components::global::memory::GlobalIterator;
6use crate::components::global::multi_stage::JobExecutor;
7use crate::components::global::multi_stage::JobIterator;
8use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
9use crate::components::global::read::LoadingJob;
10use crate::components::global::read::LoadingValidation;
11use crate::components::global::read::StageBuffer;
12use crate::components::global::read::SyncStrategy;
13use crate::components::global::read::TaskCounter;
14use crate::components::stage::StridedStageFamily;
15use crate::components::stage::StridedStageMemory;
16use crate::components::stage::TilingLayout;
17use cubecl_core as cubecl;
18use cubecl_core::prelude::*;
19use cubecl_std::{
20 CubeOption, CubeOptionExpand,
21 tensor::{View, layout::Coords2d},
22};
23
24pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
25
26#[cube]
27pub trait FullLoadingStrategy:
29 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
30{
31 type TilingLayout: TilingLayout;
33 type SyncStrategy: SyncStrategy;
35
36 type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = StridedStageFamily>;
38
39 const SHOULD_CLEAR: bool = false;
40
41 fn new_job<EG: Numeric, ES: Numeric>(
43 #[comptime] line_size: u32,
44 #[comptime] config: GlobalReaderConfig,
45 ) -> Self::Job<EG, ES>;
46}
47
48#[derive(Clone, CubeType)]
49pub struct FullStageGlobalReader<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
54 global_iter: GlobalIterator<Line<EG>>,
55 stage: StridedStageMemory<ES, L::TilingLayout>,
56 loading_job: CubeOption<L::Job<EG, ES>>,
57 #[cube(comptime)]
58 _phantom: PhantomData<L>,
59}
60
61#[cube]
62impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> FullStageGlobalReader<EG, ES, L> {
63 pub fn new(
65 view: View<Line<EG>, Coords2d>,
66 k_step: u32,
67 #[comptime] config: GlobalReaderConfig,
68 ) -> Self {
69 let mut stage = StridedStageMemory::new_aligned(128u32, config.smem_config);
72
73 let (shape_row, shape_col) = view.shape();
74 let global_iter =
75 GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
76
77 let loading_job = match config.precompute_job {
78 true => CubeOption::new_Some(L::new_job::<EG, ES>(view.line_size(), config)),
79 false => CubeOption::new_None(),
80 };
81
82 if L::SHOULD_CLEAR {
83 match config.stage_ident {
86 StageIdent::Lhs =>
87 {
88 #[allow(clippy::collapsible_if)]
89 if config.gmem_config.check_row_bounds {
90 if shape_row < config.smem_config.elements_per_stage_along_row() {
91 stage.clear_all(config);
92 }
93 }
94 }
95 StageIdent::Rhs =>
96 {
97 #[allow(clippy::collapsible_if)]
98 if config.gmem_config.check_col_bounds {
99 if shape_col < config.smem_config.elements_per_stage_along_col() {
100 stage.clear_all(config);
101 }
102 }
103 }
104 _ => comptime!(unreachable!()),
105 }
106 }
107
108 FullStageGlobalReader::<EG, ES, L> {
109 global_iter,
110 stage,
111 loading_job,
112 _phantom: PhantomData::<L>,
113 }
114 }
115
116 pub fn stage(&self) -> StridedStageMemory<ES, L::TilingLayout> {
118 self.stage
119 }
120
121 pub fn clear_stage(&mut self, #[comptime] config: GlobalReaderConfig) {
122 self.stage.clear_all(config);
123 }
124
125 pub fn free_stage(self) {
126 unsafe { self.stage.free() };
127 }
128
129 pub fn advance_view(&mut self) {
131 self.global_iter.advance();
132 }
133
134 pub fn load_stage(
136 &mut self,
137 barrier: &mut SyncBarrier<L::SyncStrategy>,
138 #[comptime] config: GlobalReaderConfig,
139 ) {
140 let mut loading_job = match self.loading_job {
141 CubeOption::Some(loading_job) => loading_job,
142 CubeOption::None => L::new_job::<EG, ES>(self.global_iter.line_size(), config),
143 };
144
145 let len = L::Job::task_count(&loading_job);
146
147 #[unroll]
148 for task_id in 0..len {
149 L::Job::<EG, ES>::execute_task(
150 &mut loading_job,
151 task_id,
152 &self.global_iter,
153 &mut self.stage,
154 barrier,
155 config,
156 );
157 }
158 }
159}
160
161#[cube]
162impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobExecutor<L::SyncStrategy>
163 for FullStageGlobalReader<EG, ES, L>
164{
165 type JobIterator = FullStageJobIterator<EG, ES, L>;
166
167 fn create_job_iterator(
168 this: &Self,
169 #[comptime] _stage_buffer: StageBuffer,
170 #[comptime] config: GlobalReaderConfig,
171 ) -> Self::JobIterator {
172 let view = this.global_iter.view();
173 let job = match this.loading_job {
174 CubeOption::Some(loading_job) => loading_job,
175 CubeOption::None => L::new_job::<EG, ES>(view.line_size(), config),
176 };
177
178 let num_tasks = L::Job::task_count(&job);
179
180 FullStageJobIterator::<EG, ES, L> {
181 job,
182 num_tasks,
183 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
184 }
185 }
186
187 fn execute_task(
188 this: &mut Self,
189 job_iterator: &mut FullStageJobIterator<EG, ES, L>,
190 barrier: &mut SyncBarrier<L::SyncStrategy>,
191 #[comptime] config: GlobalReaderConfig,
192 ) {
193 let task_id = job_iterator.current.read().counter;
194
195 L::Job::<EG, ES>::execute_task(
196 &mut job_iterator.job,
197 task_id,
198 &this.global_iter,
199 &mut this.stage,
200 barrier,
201 config,
202 );
203
204 job_iterator.current.store(TaskCounter {
205 counter: comptime!(task_id + 1u32),
206 });
207 }
208
209 fn execute_all_remaining_tasks(
210 this: &mut Self,
211 job_iterator: &mut Self::JobIterator,
212 barrier: &mut SyncBarrier<L::SyncStrategy>,
213 #[comptime] config: GlobalReaderConfig,
214 ) {
215 let task_counter = job_iterator.current.read().counter;
216
217 #[unroll]
218 for task_id in task_counter..job_iterator.num_tasks {
219 L::Job::<EG, ES>::execute_task(
220 &mut job_iterator.job,
221 task_id,
222 &this.global_iter,
223 &mut this.stage,
224 barrier,
225 config,
226 );
227 }
228
229 job_iterator.current.store(TaskCounter {
230 counter: comptime!(job_iterator.num_tasks),
231 });
232 }
233
234 fn execute_whole_job(
235 this: &mut Self,
236 barrier: &mut SyncBarrier<L::SyncStrategy>,
237 #[comptime] stage_buffer: StageBuffer,
238 #[comptime] config: GlobalReaderConfig,
239 ) {
240 Self::execute_all_remaining_tasks(
241 this,
242 &mut Self::create_job_iterator(this, stage_buffer, config),
243 barrier,
244 config,
245 );
246 }
247}
248
249#[derive(CubeType)]
250pub struct FullStageJobIterator<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
252 job: L::Job<EG, ES>,
253 #[cube(comptime)]
254 pub num_tasks: u32,
255 pub current: ComptimeCell<TaskCounter>,
256}
257
258#[cube]
259impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobIterator
260 for FullStageJobIterator<EG, ES, L>
261{
262 fn current(this: &Self) -> comptime_type!(u32) {
263 this.current.read().counter
264 }
265
266 fn num_tasks(this: &Self) -> comptime_type!(u32) {
267 this.num_tasks
268 }
269}