// cubecl_matmul/components/global/load/loader/sync_partial_loader.rs

1use std::marker::PhantomData;
2
3use super::StageIdent;
4use super::TaskCounter;
5use crate::components::InputIdent;
6use crate::components::MatmulPrecision;
7use crate::components::global::GlobalConfig;
8use crate::components::global::Quantization;
9use crate::components::global::global_memory::TensorReader;
10use crate::components::global::load::LoadingJob;
11use crate::components::global::load::LoadingValidation;
12use crate::components::global::multi_stage::JobExecutor;
13use crate::components::global::multi_stage::JobIterator;
14use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
15use crate::components::stage::PartialStageToTileReader;
16use crate::components::stage::StageMemory;
17use crate::components::stage::TilingLayout;
18use cubecl_core as cubecl;
19use cubecl_core::prelude::*;
20use cubecl_std::tensor::r#virtual::VirtualTensor;
21use cubecl_std::{CubeOption, CubeOptionExpand};
22
#[cube]
/// A strategy for synchronously loading partial stage memory.
///
/// Implementors describe *how* data moves from global memory into one of the
/// two stage buffers; the loader below drives the strategy by creating and
/// executing its [LoadingJob]s.
pub trait SyncPartialLoadingStrategy:
    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
{
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;

    /// The [LoadingJob] for this strategy.
    type Job<MP: MatmulPrecision>: LoadingJob<MP, Self::TilingLayout>;

    /// Returns the job with preliminary calculations done.
    ///
    /// * `stage_index` - which of the stage buffers this job targets
    ///   (the loader uses `0` and `1` for a double-buffered stage).
    /// * `ident` - whether the load is for the Lhs or Rhs input.
    /// * `config` - the global matmul config; all arguments are comptime.
    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: InputIdent,
        #[comptime] config: G,
    ) -> Self::Job<MP>;
}
41
#[derive(Clone, CubeType)]
/// Loads a stage from stage memory using synchronous data movement operations.
///
/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
/// each Task represents a single data transfer for a specific unit
pub struct SyncPartialLoader<MP: MatmulPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy> {
    // View over the source tensor in global memory.
    tensor_reader: TensorReader<MP::EI>,
    // Destination stage memory; created with 2 buffers (see `new`).
    stage_memory: StageMemory<MP::ES, L::TilingLayout>,
    // Precomputed jobs for stage 0 and stage 1, present only when
    // `config.precompute_job()` is true; otherwise jobs are built on demand.
    loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
    // Optional quantization parameters forwarded to each task execution.
    quantization: CubeOption<Quantization<MP>>,
    #[cube(comptime)]
    // Whether this loader feeds the Lhs or Rhs input.
    input_ident: InputIdent,
    #[cube(comptime)]
    // Carries the `G: GlobalConfig` type parameter without storing a value.
    _config: PhantomData<G>,
}
57
#[cube]
impl<MP: MatmulPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy>
    SyncPartialLoader<MP, G, L>
{
    /// Create a new SyncPartialLoader
    ///
    /// Allocates a stage memory with 2 buffers and, when
    /// `config.precompute_job()` is enabled, eagerly builds one job per
    /// stage (index 0 and 1) so `fill_stage` can skip job construction.
    ///
    /// `x_offset`/`y_offset`/`batch_offset` position the reader's view into
    /// the global tensor.
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        batch_offset: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Self {
        let stage_memory =
            StageMemory::new::<G::StageConfig>(2u32, input_ident.as_ident(), config.stage_config());
        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);

        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some((
                L::new_job::<MP, G>(0u32, input_ident, config),
                L::new_job::<MP, G>(1u32, input_ident, config),
            )),
            false => CubeOption::new_None(),
        };

        SyncPartialLoader::<MP, G, L> {
            tensor_reader,
            stage_memory,
            loading_job,
            quantization,
            input_ident,
            _config: PhantomData::<G>,
        }
    }

    /// Give a reader to the loaded stage memory.
    ///
    /// The reader targets the buffer selected by `stage_ident` within the
    /// double-buffered stage memory.
    pub fn reader(
        this: &Self,
        #[comptime] stage_ident: StageIdent,
    ) -> PartialStageToTileReader<MP::ES, L::TilingLayout> {
        PartialStageToTileReader::new(this.stage_memory, stage_ident, this.input_ident)
    }

    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.tensor_reader.update_view(k_offset, this.input_ident);
    }

    /// Accomplish the entire job of filling the stage memory
    ///
    /// Selects (or builds, if not precomputed) the job for `stage_ident`
    /// — StageIdent::A maps to stage index 0 and StageIdent::B to index 1 —
    /// then executes all of its tasks in a fully unrolled loop.
    pub fn fill_stage(this: &mut Self, #[comptime] stage_ident: StageIdent, #[comptime] config: G) {
        let mut loading_job = match this.loading_job {
            CubeOption::Some(job) => match stage_ident {
                StageIdent::A => job.0,
                StageIdent::B => job.1,
            },
            CubeOption::None => match stage_ident {
                StageIdent::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
                StageIdent::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
            },
        };

        let len = L::Job::task_count(&loading_job);

        // Comptime counter: the loop below is unrolled, so each iteration
        // receives a distinct constant task index.
        let mut task_id = comptime![0u32];

        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in 0..len {
            L::Job::<MP>::execute_task::<G>(
                &mut loading_job,
                task_id,
                &this.tensor_reader,
                &mut this.stage_memory,
                &this.quantization,
                config,
            );
            comptime![task_id += 1];
        }
    }
}
139
#[cube]
/// Task-by-task execution of a stage load, allowing the caller to interleave
/// loading tasks with other work across multiple stages.
impl<MP: MatmulPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy> JobExecutor<G>
    for SyncPartialLoader<MP, G, L>
{
    type JobIterator = SyncPartialLoaderJobIterator<MP, L>;

    /// Build a job iterator for the given stage, starting at task 0.
    ///
    /// Reuses the precomputed job when available, otherwise constructs one
    /// on the spot (StageIdent::A -> stage index 0, StageIdent::B -> 1).
    fn create_job_iterator(
        this: &Self,
        #[comptime] stage_ident: StageIdent,
        #[comptime] config: G,
    ) -> Self::JobIterator {
        let job = match this.loading_job {
            CubeOption::Some(job) => match stage_ident {
                StageIdent::A => job.0,
                StageIdent::B => job.1,
            },
            CubeOption::None => match stage_ident {
                StageIdent::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
                StageIdent::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
            },
        };

        let num_tasks = L::Job::task_count(&job);

        SyncPartialLoaderJobIterator::<MP, L> {
            job,
            num_tasks,
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

    /// Execute the single task at the iterator's current position, then
    /// advance the comptime counter.
    ///
    /// NOTE(review): there is no guard against `task_id >= num_tasks`;
    /// presumably callers are expected not to over-run the job — confirm.
    fn execute_task(
        this: &mut Self,
        job_iterator: &mut SyncPartialLoaderJobIterator<MP, L>,
        #[comptime] config: G,
    ) {
        let task_id = job_iterator.current.read().counter;

        L::Job::<MP>::execute_task::<G>(
            &mut job_iterator.job,
            task_id,
            &this.tensor_reader,
            &mut this.stage_memory,
            &this.quantization,
            config,
        );

        job_iterator.current.store(TaskCounter {
            counter: comptime!(task_id + 1u32),
        });
    }

    /// Run every task from the iterator's current position to the end,
    /// then set the counter to `num_tasks`.
    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        #[comptime] config: G,
    ) {
        let task_counter = job_iterator.current.read().counter;

        // Comptime counter: the unrolled loop gives each iteration a
        // distinct constant task index starting at the saved position.
        let mut task_id = comptime![task_counter];

        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in task_counter..job_iterator.num_tasks {
            L::Job::<MP>::execute_task::<G>(
                &mut job_iterator.job,
                task_id,
                &this.tensor_reader,
                &mut this.stage_memory,
                &this.quantization,
                config,
            );
            comptime![task_id += 1];
        }

        job_iterator.current.store(TaskCounter {
            counter: comptime!(job_iterator.num_tasks),
        });
    }

    /// Convenience: create a fresh iterator for `stage_ident` and run it to
    /// completion in one call.
    fn execute_whole_job(
        this: &mut Self,
        #[comptime] stage_ident: StageIdent,
        #[comptime] config: G,
    ) {
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_ident, config),
            config,
        );
    }
}
232
#[derive(CubeType)]
/// Accomplish the entire job of filling the stage
///
/// Tracks progress through a single stage-load job: the job itself, the
/// total task count, and a comptime cursor of the next task to execute.
pub struct SyncPartialLoaderJobIterator<MP: MatmulPrecision, L: SyncPartialLoadingStrategy> {
    // The loading job whose tasks are executed one at a time.
    job: L::Job<MP>,
    #[cube(comptime)]
    // Total number of tasks in `job` (comptime constant).
    pub num_tasks: u32,
    // Comptime cursor: index of the next task to run.
    pub current: ComptimeCell<TaskCounter>,
}
241
#[cube]
impl<MP: MatmulPrecision, L: SyncPartialLoadingStrategy> JobIterator
    for SyncPartialLoaderJobIterator<MP, L>
{
    /// Index of the next task to execute (comptime value).
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    /// Total number of tasks in the underlying job (comptime value).
    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}