// cubecl_matmul/components/global/load/loader/sync_partial_loader.rs

use std::marker::PhantomData;

use super::StageIdent;
use super::TaskCounter;
use crate::components::InputIdent;
use crate::components::MatmulPrecision;
use crate::components::global::GlobalConfig;
use crate::components::global::Quantization;
use crate::components::global::global_memory::TensorReader;
use crate::components::global::load::LoadingJob;
use crate::components::global::load::LoadingValidation;
use crate::components::global::multi_stage::JobExecutor;
use crate::components::global::multi_stage::JobIterator;
use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
use crate::components::stage::PartialStageToTileReader;
use crate::components::stage::StageMemory;
use crate::components::stage::TilingLayout;
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::VirtualTensor;
use cubecl_std::{CubeOption, CubeOptionExpand};

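/// Strategy for synchronously loading one stage of stage memory from global memory,
/// expressed as a job made of individually executable tasks.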
#[cube]
pub trait SyncPartialLoadingStrategy:
    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
{
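    /// Layout in which loaded elements are arranged inside stage memory.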
    type TilingLayout: TilingLayout;

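    /// Loading job produced by this strategy for a given matmul precision.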
    type Job<MP: MatmulPrecision>: LoadingJob<MP, Self::TilingLayout>;

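    /// Builds the loading job for the stage at `stage_index` of the given input.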
    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: InputIdent,
        #[comptime] config: G,
    ) -> Self::Job<MP>;
}

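/// Loader that synchronously fills a two-stage stage memory from global memory.
///
/// It wraps a [`TensorReader`] over the global tensor, the backing [`StageMemory`],
/// optional precomputed loading jobs (one per stage, when the config precomputes jobs),
/// and optional quantization parameters.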
#[derive(Clone, CubeType)]
pub struct SyncPartialLoader<MP: MatmulPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy> {
    tensor_reader: TensorReader<MP::EI>,
    stage_memory: StageMemory<MP::ES, L::TilingLayout>,
    loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
    quantization: CubeOption<Quantization<MP>>,
    #[cube(comptime)]
    input_ident: InputIdent,
    #[cube(comptime)]
    _config: PhantomData<G>,
}

#[cube]
impl<MP: MatmulPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy>
    SyncPartialLoader<MP, G, L>
{
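    /// Creates a loader over `tensor` at the given offsets, allocating stage memory for
    /// two stages and precomputing one job per stage when `config.precompute_job()` is set.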
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        batch_offset: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Self {
        let stage_memory =
            StageMemory::new::<G::StageConfig>(2u32, input_ident.as_ident(), config.stage_config());
        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);

        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some((
                L::new_job::<MP, G>(0u32, input_ident, config),
                L::new_job::<MP, G>(1u32, input_ident, config),
            )),
            false => CubeOption::new_None(),
        };

        SyncPartialLoader::<MP, G, L> {
            tensor_reader,
            stage_memory,
            loading_job,
            quantization,
            input_ident,
            _config: PhantomData::<G>,
        }
    }

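    /// Returns a tile reader over the stage selected by `stage_ident`.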
    pub fn reader(
        this: &Self,
        #[comptime] stage_ident: StageIdent,
    ) -> PartialStageToTileReader<MP::ES, L::TilingLayout> {
        PartialStageToTileReader::new(this.stage_memory, stage_ident, this.input_ident)
    }

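    /// Advances the global-memory view by `k_offset` along the k dimension.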
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.tensor_reader.update_view(k_offset, this.input_ident);
    }

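    /// Fills the stage selected by `stage_ident` by executing every task of the
    /// corresponding loading job (precomputed or created on the fly).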
    pub fn fill_stage(this: &mut Self, #[comptime] stage_ident: StageIdent, #[comptime] config: G) {
        let mut loading_job = match this.loading_job {
            CubeOption::Some(job) => match stage_ident {
                StageIdent::A => job.0,
                StageIdent::B => job.1,
            },
            CubeOption::None => match stage_ident {
                StageIdent::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
                StageIdent::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
            },
        };

        let len = L::Job::task_count(&loading_job);

        let mut task_id = comptime![0u32];

        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in 0..len {
            L::Job::<MP>::execute_task::<G>(
                &mut loading_job,
                task_id,
                &this.tensor_reader,
                &mut this.stage_memory,
                &this.quantization,
                config,
            );
            comptime![task_id += 1];
        }
    }
}

#[cube]
impl<MP: MatmulPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy> JobExecutor<G>
    for SyncPartialLoader<MP, G, L>
{
    type JobIterator = SyncPartialLoaderJobIterator<MP, L>;

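    /// Builds a job iterator for the selected stage, reusing the precomputed job when
    /// available, with its task counter starting at zero.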
    fn create_job_iterator(
        this: &Self,
        #[comptime] stage_ident: StageIdent,
        #[comptime] config: G,
    ) -> Self::JobIterator {
        let job = match this.loading_job {
            CubeOption::Some(job) => match stage_ident {
                StageIdent::A => job.0,
                StageIdent::B => job.1,
            },
            CubeOption::None => match stage_ident {
                StageIdent::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
                StageIdent::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
            },
        };

        let num_tasks = L::Job::task_count(&job);

        SyncPartialLoaderJobIterator::<MP, L> {
            job,
            num_tasks,
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

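    /// Executes the next task of the job and advances the comptime task counter.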
    fn execute_task(
        this: &mut Self,
        job_iterator: &mut SyncPartialLoaderJobIterator<MP, L>,
        #[comptime] config: G,
    ) {
        let task_id = job_iterator.current.read().counter;

        L::Job::<MP>::execute_task::<G>(
            &mut job_iterator.job,
            task_id,
            &this.tensor_reader,
            &mut this.stage_memory,
            &this.quantization,
            config,
        );

        job_iterator.current.store(TaskCounter {
            counter: comptime!(task_id + 1u32),
        });
    }

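    /// Executes all remaining tasks of the job, from the current counter up to `num_tasks`.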
    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        #[comptime] config: G,
    ) {
        let task_counter = job_iterator.current.read().counter;

        let mut task_id = comptime![task_counter];

        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in task_counter..job_iterator.num_tasks {
            L::Job::<MP>::execute_task::<G>(
                &mut job_iterator.job,
                task_id,
                &this.tensor_reader,
                &mut this.stage_memory,
                &this.quantization,
                config,
            );
            comptime![task_id += 1];
        }

        job_iterator.current.store(TaskCounter {
            counter: comptime!(job_iterator.num_tasks),
        });
    }

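    /// Creates a fresh job iterator for the stage and executes all of its tasks.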
    fn execute_whole_job(
        this: &mut Self,
        #[comptime] stage_ident: StageIdent,
        #[comptime] config: G,
    ) {
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_ident, config),
            config,
        );
    }
}

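/// Iterator over the tasks of a loading job, tracking progress with a comptime counter.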
#[derive(CubeType)]
pub struct SyncPartialLoaderJobIterator<MP: MatmulPrecision, L: SyncPartialLoadingStrategy> {
    job: L::Job<MP>,
    #[cube(comptime)]
    pub num_tasks: u32,
    pub current: ComptimeCell<TaskCounter>,
}

#[cube]
impl<MP: MatmulPrecision, L: SyncPartialLoadingStrategy> JobIterator
    for SyncPartialLoaderJobIterator<MP, L>
{
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}