// cubecl_matmul/components/global/read/reader/sync_partial_reader.rs

use std::marker::PhantomData;
2
3use super::StageBuffer;
4use super::TaskCounter;
5use crate::components::MatmulIdent;
6use crate::components::MatrixPrecision;
7use crate::components::global::GlobalConfig;
8use crate::components::global::memory::GlobalIterator;
9use crate::components::global::multi_stage::JobExecutor;
10use crate::components::global::multi_stage::JobIterator;
11use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
12use crate::components::global::read::LoadingJob;
13use crate::components::global::read::LoadingValidation;
14use crate::components::stage::StridedStage;
15use crate::components::stage::TilingLayout;
16use cubecl_core as cubecl;
17use cubecl_core::prelude::*;
18use cubecl_std::{
19 CubeOption, CubeOptionExpand,
20 tensor::{View, layout::Coords2d},
21};
22
/// Strategy for synchronously loading one *partial* stage (one of the two
/// stage buffers, A or B) from global memory.
///
/// An implementation defines how the stage's tiles are laid out and supplies a
/// [`LoadingJob`] whose work is split into a comptime-known number of tasks,
/// so callers can interleave loading with other work (see `JobExecutor`).
#[cube]
pub trait SyncPartialLoadingStrategy:
    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
{
    /// Layout of the tiles inside the stage written by this strategy.
    type TilingLayout: TilingLayout;

    /// The job that performs the copy for one stage buffer, as a sequence of
    /// comptime-enumerable tasks.
    type Job<IP: MatrixPrecision>: LoadingJob<IP, Self::TilingLayout>;

    /// Builds the loading job for the stage buffer at `stage_index`
    /// (0 selects `StageBuffer::A`, 1 selects `StageBuffer::B` in the readers
    /// of this file).
    ///
    /// `line_size` is the line width of the global-memory view being read.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> Self::Job<IP>;
}
42
/// Global-memory reader that synchronously fills a double-buffered
/// (partial) stage using loading strategy `L`.
///
/// Holds the iterator over the global tensor view, the stage memory to write
/// into, and — when the config precomputes jobs — one loading job per stage
/// buffer (`.0` for `StageBuffer::A`, `.1` for `StageBuffer::B`).
#[derive(Clone, CubeType)]
pub struct SyncPartialStageGlobalReader<
    IP: MatrixPrecision,
    G: GlobalConfig,
    L: SyncPartialLoadingStrategy,
> {
    // Iterator over lines of the global tensor, advanced once per k-step.
    global_iter: GlobalIterator<Line<IP::Global>>,
    // Destination stage memory, tiled according to the strategy's layout.
    stage_memory: StridedStage<IP::Stage, L::TilingLayout>,
    // Precomputed jobs for buffers (A, B) when `config.precompute_job()` is
    // true; `None` means jobs are created on demand at load time.
    loading_job: CubeOption<(L::Job<IP>, L::Job<IP>)>,
    // Which matmul operand (comptime) this reader loads.
    #[cube(comptime)]
    ident: MatmulIdent,
    // Ties the reader to its `GlobalConfig` type without storing a value.
    #[cube(comptime)]
    _config: PhantomData<G>,
}
61
#[cube]
impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy>
    SyncPartialStageGlobalReader<IP, G, L>
{
    /// Creates a reader over `tensor`.
    ///
    /// Allocates the stage memory from the config, wraps `tensor` in a
    /// [`GlobalIterator`] that advances by `k_step` along
    /// `ident.view_direction()`, and — if `config.precompute_job()` — builds
    /// both stage-buffer jobs up front (index 0 for A, 1 for B).
    pub fn new(
        tensor: View<Line<IP::Global>, Coords2d>,
        k_step: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self {
        let stage_memory = StridedStage::new(
            comptime!(ident.into_stage()),
            config.stage_memory_config(ident),
        );
        // `false`: NOTE(review) — presumably disables view wrap-around/cycling;
        // confirm against GlobalIterator::new.
        let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), false);

        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some((
                L::new_job::<IP, G>(0u32, ident, tensor.line_size(), config),
                L::new_job::<IP, G>(1u32, ident, tensor.line_size(), config),
            )),
            false => CubeOption::new_None(),
        };

        SyncPartialStageGlobalReader::<IP, G, L> {
            global_iter,
            stage_memory,
            loading_job,
            ident,
            _config: PhantomData::<G>,
        }
    }

    /// Returns a view of the stage memory restricted to the requested
    /// stage buffer (A or B).
    pub fn stage(
        &self,
        #[comptime] stage_buffer: StageBuffer,
    ) -> StridedStage<IP::Stage, L::TilingLayout> {
        self.stage_memory.with_buffer_index(stage_buffer.to_index())
    }

    /// Advances the global iterator to the next k-step window.
    pub fn advance_view(&mut self) {
        self.global_iter.advance();
    }

    /// Loads the selected stage buffer in full: picks the precomputed job for
    /// that buffer (or creates one on the fly when jobs are not precomputed),
    /// then executes every task of the job in an unrolled loop.
    pub fn load_stage(&mut self, #[comptime] stage_buffer: StageBuffer, #[comptime] config: G) {
        let mut loading_job = match self.loading_job {
            CubeOption::Some(job) => match stage_buffer {
                StageBuffer::A => job.0,
                StageBuffer::B => job.1,
            },
            CubeOption::None => match stage_buffer {
                StageBuffer::A => {
                    L::new_job::<IP, G>(0u32, self.ident, self.global_iter.line_size(), config)
                }
                StageBuffer::B => {
                    L::new_job::<IP, G>(1u32, self.ident, self.global_iter.line_size(), config)
                }
            },
        };

        // Task count is comptime, so this loop unrolls fully.
        let len = L::Job::task_count(&loading_job);

        #[unroll]
        for task_id in 0..len {
            L::Job::<IP>::execute_task::<G>(
                &mut loading_job,
                task_id,
                &self.global_iter,
                &mut self.stage_memory,
                config,
            );
        }
    }
}
140
/// Exposes the loading work as a resumable sequence of tasks, so callers can
/// interleave individual load tasks with other work instead of loading the
/// whole stage at once (contrast with `load_stage`).
#[cube]
impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy> JobExecutor<G>
    for SyncPartialStageGlobalReader<IP, G, L>
{
    type JobIterator = SyncPartialJobIterator<IP, L>;

    /// Builds a job iterator for the given stage buffer: reuses the
    /// precomputed job when available, otherwise creates one, and starts the
    /// comptime task counter at 0.
    fn create_job_iterator(
        this: &Self,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: G,
    ) -> Self::JobIterator {
        let view = this.global_iter.view();
        let job = match this.loading_job {
            CubeOption::Some(job) => match stage_buffer {
                StageBuffer::A => job.0,
                StageBuffer::B => job.1,
            },
            CubeOption::None => match stage_buffer {
                StageBuffer::A => L::new_job::<IP, G>(0u32, this.ident, view.line_size(), config),
                StageBuffer::B => L::new_job::<IP, G>(1u32, this.ident, view.line_size(), config),
            },
        };

        let num_tasks = L::Job::task_count(&job);

        SyncPartialJobIterator::<IP, L> {
            job,
            num_tasks,
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

    /// Executes exactly one task — the one at the iterator's current comptime
    /// counter — then advances the counter by one.
    fn execute_task(
        this: &mut Self,
        job_iterator: &mut SyncPartialJobIterator<IP, L>,
        #[comptime] config: G,
    ) {
        let task_id = job_iterator.current.read().counter;

        L::Job::<IP>::execute_task::<G>(
            &mut job_iterator.job,
            task_id,
            &this.global_iter,
            &mut this.stage_memory,
            config,
        );

        // Counter arithmetic happens at comptime; the store records progress
        // for subsequent calls.
        job_iterator.current.store(TaskCounter {
            counter: comptime!(task_id + 1u32),
        });
    }

    /// Runs every task from the current counter up to `num_tasks`, then marks
    /// the iterator as fully consumed.
    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        #[comptime] config: G,
    ) {
        let task_counter = job_iterator.current.read().counter;

        // Manual comptime counter: the unrolled loop variable itself is not
        // used, `task_id` is incremented at comptime each iteration.
        let mut task_id = comptime![task_counter];

        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in task_counter..job_iterator.num_tasks {
            L::Job::<IP>::execute_task::<G>(
                &mut job_iterator.job,
                task_id,
                &this.global_iter,
                &mut this.stage_memory,
                config,
            );
            comptime![task_id += 1];
        }

        job_iterator.current.store(TaskCounter {
            counter: comptime!(job_iterator.num_tasks),
        });
    }

    /// Convenience: creates a fresh job iterator for `stage_buffer` and runs
    /// all of its tasks in one call.
    fn execute_whole_job(
        this: &mut Self,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: G,
    ) {
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_buffer, config),
            config,
        );
    }
}
232
/// Resumable iterator over the tasks of one partial-stage loading job.
///
/// `num_tasks` and the `current` cursor are comptime values, so progress
/// through the job is tracked at kernel-expansion time rather than at runtime.
#[derive(CubeType)]
pub struct SyncPartialJobIterator<IP: MatrixPrecision, L: SyncPartialLoadingStrategy> {
    // The loading job whose tasks are being executed.
    job: L::Job<IP>,
    // Total number of tasks in the job (comptime).
    #[cube(comptime)]
    pub num_tasks: u32,
    // Comptime cursor: index of the next task to execute.
    pub current: ComptimeCell<TaskCounter>,
}
241
#[cube]
impl<IP: MatrixPrecision, L: SyncPartialLoadingStrategy> JobIterator
    for SyncPartialJobIterator<IP, L>
{
    /// Index of the next task to execute (comptime).
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    /// Total number of tasks in the job (comptime).
    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}