// cubecl_matmul/components/global/read/reader/sync_full_reader.rs
use std::marker::PhantomData;

use crate::components::MatmulIdent;
use crate::components::MatrixPrecision;
use crate::components::global::GlobalConfig;
use crate::components::global::memory::GlobalIterator;
use crate::components::global::multi_stage::JobExecutor;
use crate::components::global::multi_stage::JobIterator;
use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
use crate::components::global::read::LoadingJob;
use crate::components::global::read::LoadingValidation;
use crate::components::global::read::StageBuffer;
use crate::components::global::read::TaskCounter;
use crate::components::stage::StridedStage;
use crate::components::stage::TilingLayout;
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};

23#[cube]
24pub trait SyncFullLoadingStrategy:
26 'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
27{
28 type TilingLayout: TilingLayout;
30
31 type Job<IP: MatrixPrecision>: LoadingJob<IP, Self::TilingLayout>;
33
34 fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
36 #[comptime] ident: MatmulIdent,
37 #[comptime] line_size: u32,
38 #[comptime] config: G,
39 ) -> Self::Job<IP>;
40}
41
42#[derive(Clone, CubeType)]
43pub struct SyncFullStageGlobalReader<
48 IP: MatrixPrecision,
49 G: GlobalConfig,
50 L: SyncFullLoadingStrategy,
51> {
52 global_iter: GlobalIterator<Line<IP::Global>>,
53 stage: StridedStage<IP::Stage, L::TilingLayout>,
54 loading_job: CubeOption<L::Job<IP>>,
55 #[cube(comptime)]
56 ident: MatmulIdent,
57 #[cube(comptime)]
58 _phantom: PhantomData<(G, L)>,
59}
60
61#[cube]
62impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncFullLoadingStrategy>
63 SyncFullStageGlobalReader<IP, G, L>
64{
65 pub fn new(
67 tensor: View<Line<IP::Global>, Coords2d>,
68 k_step: u32,
69 #[comptime] ident: MatmulIdent,
70 #[comptime] config: G,
71 ) -> Self {
72 let stage = StridedStage::new(
73 comptime!(ident.into_stage()),
74 config.stage_memory_config(ident),
75 );
76 let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), false);
77
78 let loading_job = match config.precompute_job() {
79 true => CubeOption::new_Some(L::new_job::<IP, G>(ident, tensor.line_size(), config)),
80 false => CubeOption::new_None(),
81 };
82
83 SyncFullStageGlobalReader::<IP, G, L> {
84 global_iter,
85 stage,
86 loading_job,
87 ident,
88 _phantom: PhantomData::<(G, L)>,
89 }
90 }
91
92 pub fn stage(&self) -> StridedStage<IP::Stage, L::TilingLayout> {
94 self.stage
95 }
96
97 pub fn free_stage(self) {
98 unsafe { self.stage.free() };
99 }
100
101 pub fn advance_view(&mut self) {
103 self.global_iter.advance();
104 }
105
106 pub fn load_stage(&mut self, #[comptime] config: G) {
108 let mut loading_job = match self.loading_job {
109 CubeOption::Some(loading_job) => loading_job,
110 CubeOption::None => {
111 L::new_job::<IP, G>(self.ident, self.global_iter.line_size(), config)
112 }
113 };
114
115 let len = L::Job::task_count(&loading_job);
116
117 #[unroll]
118 for task_id in 0..len {
119 L::Job::<IP>::execute_task::<G>(
120 &mut loading_job,
121 task_id,
122 &self.global_iter,
123 &mut self.stage,
124 config,
125 );
126 }
127 }
128}
129
130#[cube]
131impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncFullLoadingStrategy> JobExecutor<G>
132 for SyncFullStageGlobalReader<IP, G, L>
133{
134 type JobIterator = SyncFullStageJobIterator<IP, L>;
135
136 fn create_job_iterator(
137 this: &Self,
138 #[comptime] _stage_buffer: StageBuffer,
139 #[comptime] config: G,
140 ) -> Self::JobIterator {
141 let view = this.global_iter.view();
142 let job = match this.loading_job {
143 CubeOption::Some(loading_job) => loading_job,
144 CubeOption::None => L::new_job::<IP, G>(this.ident, view.line_size(), config),
145 };
146
147 let num_tasks = L::Job::task_count(&job);
148
149 SyncFullStageJobIterator::<IP, L> {
150 job,
151 num_tasks,
152 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
153 }
154 }
155
156 fn execute_task(
157 this: &mut Self,
158 job_iterator: &mut SyncFullStageJobIterator<IP, L>,
159 #[comptime] config: G,
160 ) {
161 let task_id = job_iterator.current.read().counter;
162
163 L::Job::<IP>::execute_task::<G>(
164 &mut job_iterator.job,
165 task_id,
166 &this.global_iter,
167 &mut this.stage,
168 config,
169 );
170
171 job_iterator.current.store(TaskCounter {
172 counter: comptime!(task_id + 1u32),
173 });
174 }
175
176 fn execute_all_remaining_tasks(
177 this: &mut Self,
178 job_iterator: &mut Self::JobIterator,
179 #[comptime] config: G,
180 ) {
181 let task_counter = job_iterator.current.read().counter;
182
183 let mut task_id = comptime![task_counter];
184
185 #[allow(clippy::explicit_counter_loop)]
186 #[unroll]
187 for _ in task_counter..job_iterator.num_tasks {
188 L::Job::<IP>::execute_task::<G>(
189 &mut job_iterator.job,
190 task_id,
191 &this.global_iter,
192 &mut this.stage,
193 config,
194 );
195 comptime![task_id += 1];
196 }
197
198 job_iterator.current.store(TaskCounter {
199 counter: comptime!(job_iterator.num_tasks),
200 });
201 }
202
203 fn execute_whole_job(
204 this: &mut Self,
205 #[comptime] stage_buffer: StageBuffer,
206 #[comptime] config: G,
207 ) {
208 Self::execute_all_remaining_tasks(
209 this,
210 &mut Self::create_job_iterator(this, stage_buffer, config),
211 config,
212 );
213 }
214}
215
216#[derive(CubeType)]
217pub struct SyncFullStageJobIterator<IP: MatrixPrecision, L: SyncFullLoadingStrategy> {
219 job: L::Job<IP>,
220 #[cube(comptime)]
221 pub num_tasks: u32,
222 pub current: ComptimeCell<TaskCounter>,
223}
224
225#[cube]
226impl<IP: MatrixPrecision, L: SyncFullLoadingStrategy> JobIterator
227 for SyncFullStageJobIterator<IP, L>
228{
229 fn current(this: &Self) -> comptime_type!(u32) {
230 this.current.read().counter
231 }
232
233 fn num_tasks(this: &Self) -> comptime_type!(u32) {
234 this.num_tasks
235 }
236}