cubecl_matmul/components/global/read/reader/
partial_reader.rs1use super::StageBuffer;
2use super::TaskCounter;
3use crate::components::global::GlobalReaderConfig;
4use crate::components::global::memory::GlobalIterator;
5use crate::components::global::multi_stage::JobExecutor;
6use crate::components::global::multi_stage::JobIterator;
7use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
8use crate::components::global::read::LoadingJob;
9use crate::components::global::read::LoadingValidation;
10use crate::components::global::read::SyncBarrier;
11use crate::components::global::read::SyncStrategy;
12use crate::components::stage::LoadStageFamily;
13use crate::components::stage::StageFamily;
14use crate::components::stage::TilingLayout;
15use cubecl_core as cubecl;
16use cubecl_core::prelude::*;
17use cubecl_std::{
18 CubeOption, CubeOptionExpand,
19 tensor::{View, layout::Coords2d},
20};
21
22pub type LoaderStage<L, IP> = <<L as PartialLoadingStrategy>::Stage as StageFamily>::Stage<
23 IP,
24 <L as PartialLoadingStrategy>::TilingLayout,
25>;
26
27#[cube]
28pub trait PartialLoadingStrategy:
30 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
31{
32 type TilingLayout: TilingLayout;
34 type SyncStrategy: SyncStrategy;
35 type Stage: LoadStageFamily<ReadOnly>;
36
37 type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = Self::Stage>;
39
40 fn new_job<EG: Numeric, ES: Numeric>(
42 #[comptime] stage_index: u32,
43 #[comptime] line_size: u32,
44 #[comptime] config: GlobalReaderConfig,
45 ) -> Self::Job<EG, ES>;
46}
47
48#[derive(Clone, CubeType)]
49#[allow(clippy::type_complexity)]
50pub struct PartialStageGlobalReader<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> {
55 global_iter: GlobalIterator<Line<EG>>,
56 stage_memory: LoaderStage<L, ES>,
57 loading_job: CubeOption<(L::Job<EG, ES>, L::Job<EG, ES>)>,
58}
59
60#[cube]
61impl<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> PartialStageGlobalReader<EG, ES, L> {
62 pub fn new(
64 tensor: View<Line<EG>, Coords2d>,
65 k_step: u32,
66 #[comptime] config: GlobalReaderConfig,
67 ) -> Self {
68 let stage_memory = L::Stage::create(128u32, config.smem_config);
69 let global_iter =
70 GlobalIterator::new(tensor, k_step, config.gmem_config.view_direction, false);
71
72 let loading_job = match config.precompute_job {
73 true => CubeOption::new_Some((
74 L::new_job::<EG, ES>(0u32, tensor.line_size(), config),
75 L::new_job::<EG, ES>(1u32, tensor.line_size(), config),
76 )),
77 false => CubeOption::new_None(),
78 };
79
80 PartialStageGlobalReader::<EG, ES, L> {
81 global_iter,
82 stage_memory,
83 loading_job,
84 }
85 }
86
87 pub fn stage(&self, #[comptime] stage_buffer: StageBuffer) -> LoaderStage<L, ES> {
89 L::Stage::with_buffer_index(&self.stage_memory, stage_buffer.to_index())
90 }
91
92 pub fn free_stage(self) {
94 L::Stage::free(&self.stage_memory);
95 }
96
97 pub fn advance_view(&mut self) {
99 self.global_iter.advance();
100 }
101
102 pub fn load_stage(
104 &mut self,
105 barrier: &mut SyncBarrier<L::SyncStrategy>,
106 #[comptime] stage_buffer: StageBuffer,
107 #[comptime] config: GlobalReaderConfig,
108 ) {
109 let mut loading_job = match self.loading_job {
110 CubeOption::Some(job) => match stage_buffer {
111 StageBuffer::A => job.0,
112 StageBuffer::B => job.1,
113 },
114 CubeOption::None => match stage_buffer {
115 StageBuffer::A => L::new_job::<EG, ES>(0u32, self.global_iter.line_size(), config),
116 StageBuffer::B => L::new_job::<EG, ES>(1u32, self.global_iter.line_size(), config),
117 },
118 };
119
120 let len = L::Job::task_count(&loading_job);
121
122 #[unroll]
123 for task_id in 0..len {
124 L::Job::<EG, ES>::execute_task(
125 &mut loading_job,
126 task_id,
127 &self.global_iter,
128 &mut self.stage_memory,
129 barrier,
130 config,
131 );
132 }
133 }
134}
135
136#[cube]
137impl<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> JobExecutor<L::SyncStrategy>
138 for PartialStageGlobalReader<EG, ES, L>
139{
140 type JobIterator = PartialJobIterator<EG, ES, L>;
141
142 fn create_job_iterator(
143 this: &Self,
144 #[comptime] stage_buffer: StageBuffer,
145 #[comptime] config: GlobalReaderConfig,
146 ) -> Self::JobIterator {
147 let view = this.global_iter.view();
148 let job = match this.loading_job {
149 CubeOption::Some(job) => match stage_buffer {
150 StageBuffer::A => job.0,
151 StageBuffer::B => job.1,
152 },
153 CubeOption::None => match stage_buffer {
154 StageBuffer::A => L::new_job::<EG, ES>(0u32, view.line_size(), config),
155 StageBuffer::B => L::new_job::<EG, ES>(1u32, view.line_size(), config),
156 },
157 };
158
159 let num_tasks = L::Job::task_count(&job);
160
161 PartialJobIterator::<EG, ES, L> {
162 job,
163 num_tasks,
164 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
165 }
166 }
167
168 fn execute_task(
169 this: &mut Self,
170 job_iterator: &mut PartialJobIterator<EG, ES, L>,
171 barrier: &mut SyncBarrier<L::SyncStrategy>,
172 #[comptime] config: GlobalReaderConfig,
173 ) {
174 let task_id = job_iterator.current.read().counter;
175
176 L::Job::<EG, ES>::execute_task(
177 &mut job_iterator.job,
178 task_id,
179 &this.global_iter,
180 &mut this.stage_memory,
181 barrier,
182 config,
183 );
184
185 job_iterator.current.store(TaskCounter {
186 counter: comptime!(task_id + 1u32),
187 });
188 }
189
190 fn execute_all_remaining_tasks(
191 this: &mut Self,
192 job_iterator: &mut Self::JobIterator,
193 barrier: &mut SyncBarrier<L::SyncStrategy>,
194 #[comptime] config: GlobalReaderConfig,
195 ) {
196 let task_counter = job_iterator.current.read().counter;
197
198 #[unroll]
199 for task_id in task_counter..job_iterator.num_tasks {
200 L::Job::<EG, ES>::execute_task(
201 &mut job_iterator.job,
202 task_id,
203 &this.global_iter,
204 &mut this.stage_memory,
205 barrier,
206 config,
207 );
208 }
209
210 job_iterator.current.store(TaskCounter {
211 counter: comptime!(job_iterator.num_tasks),
212 });
213 }
214
215 fn execute_whole_job(
216 this: &mut Self,
217 barrier: &mut SyncBarrier<L::SyncStrategy>,
218 #[comptime] stage_buffer: StageBuffer,
219 #[comptime] config: GlobalReaderConfig,
220 ) {
221 Self::execute_all_remaining_tasks(
222 this,
223 &mut Self::create_job_iterator(this, stage_buffer, config),
224 barrier,
225 config,
226 );
227 }
228}
229
230#[derive(CubeType)]
231pub struct PartialJobIterator<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> {
233 job: L::Job<EG, ES>,
234 #[cube(comptime)]
235 pub num_tasks: u32,
236 pub current: ComptimeCell<TaskCounter>,
237}
238
239#[cube]
240impl<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> JobIterator
241 for PartialJobIterator<EG, ES, L>
242{
243 fn current(this: &Self) -> comptime_type!(u32) {
244 this.current.read().counter
245 }
246
247 fn num_tasks(this: &Self) -> comptime_type!(u32) {
248 this.num_tasks
249 }
250}