cubecl_matmul/components/global/read/reader/sync_partial_reader.rs

use std::marker::PhantomData;

use super::StageBuffer;
use super::TaskCounter;
use crate::components::MatmulIdent;
use crate::components::MatrixPrecision;
use crate::components::global::GlobalConfig;
use crate::components::global::memory::GlobalIterator;
use crate::components::global::multi_stage::JobExecutor;
use crate::components::global::multi_stage::JobIterator;
use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
use crate::components::global::read::LoadingJob;
use crate::components::global::read::LoadingValidation;
use crate::components::stage::StridedStage;
use crate::components::stage::TilingLayout;
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};

#[cube]
/// A strategy for synchronously loading a single stage buffer (a partial stage) of stage memory.
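///
/// # Implementor sketch
///
/// A hypothetical strategy, shown only to illustrate the shape of the trait.
/// `RowMajorTilingLayout` and `MyJob` are placeholder names, not types from this crate,
/// and the required `LoadingValidation` and `LoadMaxRoundPlaneCount` impls are omitted.
///
/// ```ignore
/// #[derive(Clone)]
/// pub struct MyStrategy;
///
/// #[cube]
/// impl SyncPartialLoadingStrategy for MyStrategy {
///     type TilingLayout = RowMajorTilingLayout;
///     type Job<IP: MatrixPrecision> = MyJob<IP>;
///
///     fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
///         #[comptime] stage_index: u32,
///         #[comptime] ident: MatmulIdent,
///         #[comptime] line_size: u32,
///         #[comptime] config: G,
///     ) -> Self::Job<IP> {
///         // Precompute, at comptime, the per-unit offsets and task count
///         // for the stage buffer identified by `stage_index`.
///         MyJob::new(stage_index, ident, line_size, config)
///     }
/// }
/// ```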
pub trait SyncPartialLoadingStrategy:
    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
{
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;

    /// The [LoadingJob] for this strategy.
    type Job<IP: MatrixPrecision>: LoadingJob<IP, Self::TilingLayout>;

    /// Returns the job with preliminary calculations done.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] stage_index: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> Self::Job<IP>;
}

#[derive(Clone, CubeType)]
/// Loads a stage from global memory into stage memory using synchronous data movement operations.
///
/// A complete load is referred to as a `Job`, which is divided into `Task`s;
/// each `Task` represents a single data transfer for a specific unit.
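///
/// # Usage sketch
///
/// A minimal, illustrative driver loop. The variables `tensor`, `k_step`, `ident`, and
/// `config` are assumed to come from the surrounding global matmul, and the exact
/// interleaving of loads, synchronizations, and view advances is dictated by the
/// double-buffering algorithm, not by this sketch.
///
/// ```ignore
/// let mut reader = SyncPartialStageGlobalReader::<IP, G, L>::new(tensor, k_step, ident, config);
///
/// // Fill stage buffer A, then hand the stage to the stage matmul.
/// reader.load_stage(StageBuffer::A, config);
/// let stage_a = reader.stage(StageBuffer::A);
///
/// // Fill stage buffer B.
/// reader.load_stage(StageBuffer::B, config);
///
/// // Move the global view to the next k slice before the next round of loads.
/// reader.advance_view();
/// ```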
pub struct SyncPartialStageGlobalReader<
    IP: MatrixPrecision,
    G: GlobalConfig,
    L: SyncPartialLoadingStrategy,
> {
    global_iter: GlobalIterator<Line<IP::Global>>,
    stage_memory: StridedStage<IP::Stage, L::TilingLayout>,
    loading_job: CubeOption<(L::Job<IP>, L::Job<IP>)>,
    #[cube(comptime)]
    ident: MatmulIdent,
    #[cube(comptime)]
    _config: PhantomData<G>,
}

#[cube]
impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy>
    SyncPartialStageGlobalReader<IP, G, L>
{
    /// Create a new SyncPartialStageGlobalReader.
    pub fn new(
        tensor: View<Line<IP::Global>, Coords2d>,
        k_step: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self {
        let stage_memory = StridedStage::new(
            comptime!(ident.into_stage()),
            config.stage_memory_config(ident),
        );
        let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), false);

        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some((
                L::new_job::<IP, G>(0u32, ident, tensor.line_size(), config),
                L::new_job::<IP, G>(1u32, ident, tensor.line_size(), config),
            )),
            false => CubeOption::new_None(),
        };

        SyncPartialStageGlobalReader::<IP, G, L> {
            global_iter,
            stage_memory,
            loading_job,
            ident,
            _config: PhantomData::<G>,
        }
    }

    /// Returns a reader over the loaded stage memory for the given stage buffer.
    pub fn stage(
        &self,
        #[comptime] stage_buffer: StageBuffer,
    ) -> StridedStage<IP::Stage, L::TilingLayout> {
        self.stage_memory.with_buffer_index(stage_buffer.to_index())
    }

    /// Advance the view over global memory along the k dimension by the configured k step.
    pub fn advance_view(&mut self) {
        self.global_iter.advance();
    }

    /// Execute every task of the loading job, filling the given stage buffer of the stage memory.
    pub fn load_stage(&mut self, #[comptime] stage_buffer: StageBuffer, #[comptime] config: G) {
        let mut loading_job = match self.loading_job {
            CubeOption::Some(job) => match stage_buffer {
                StageBuffer::A => job.0,
                StageBuffer::B => job.1,
            },
            CubeOption::None => match stage_buffer {
                StageBuffer::A => {
                    L::new_job::<IP, G>(0u32, self.ident, self.global_iter.line_size(), config)
                }
                StageBuffer::B => {
                    L::new_job::<IP, G>(1u32, self.ident, self.global_iter.line_size(), config)
                }
            },
        };

        let len = L::Job::task_count(&loading_job);

        #[unroll]
        for task_id in 0..len {
            L::Job::<IP>::execute_task::<G>(
                &mut loading_job,
                task_id,
                &self.global_iter,
                &mut self.stage_memory,
                config,
            );
        }
    }
}

#[cube]
impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncPartialLoadingStrategy> JobExecutor<G>
    for SyncPartialStageGlobalReader<IP, G, L>
{
    type JobIterator = SyncPartialJobIterator<IP, L>;

    fn create_job_iterator(
        this: &Self,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: G,
    ) -> Self::JobIterator {
        let view = this.global_iter.view();
        let job = match this.loading_job {
            CubeOption::Some(job) => match stage_buffer {
                StageBuffer::A => job.0,
                StageBuffer::B => job.1,
            },
            CubeOption::None => match stage_buffer {
                StageBuffer::A => L::new_job::<IP, G>(0u32, this.ident, view.line_size(), config),
                StageBuffer::B => L::new_job::<IP, G>(1u32, this.ident, view.line_size(), config),
            },
        };

        let num_tasks = L::Job::task_count(&job);

        SyncPartialJobIterator::<IP, L> {
            job,
            num_tasks,
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

    fn execute_task(
        this: &mut Self,
        job_iterator: &mut SyncPartialJobIterator<IP, L>,
        #[comptime] config: G,
    ) {
        let task_id = job_iterator.current.read().counter;

        L::Job::<IP>::execute_task::<G>(
            &mut job_iterator.job,
            task_id,
            &this.global_iter,
            &mut this.stage_memory,
            config,
        );

        job_iterator.current.store(TaskCounter {
            counter: comptime!(task_id + 1u32),
        });
    }

    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        #[comptime] config: G,
    ) {
        let task_counter = job_iterator.current.read().counter;

        let mut task_id = comptime![task_counter];

        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in task_counter..job_iterator.num_tasks {
            L::Job::<IP>::execute_task::<G>(
                &mut job_iterator.job,
                task_id,
                &this.global_iter,
                &mut this.stage_memory,
                config,
            );
            comptime![task_id += 1];
        }

        job_iterator.current.store(TaskCounter {
            counter: comptime!(job_iterator.num_tasks),
        });
    }

    fn execute_whole_job(
        this: &mut Self,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: G,
    ) {
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_buffer, config),
            config,
        );
    }
}

#[derive(CubeType)]
/// Iterates over the tasks of a partial-stage loading job, tracking the current task index at comptime.
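///
/// # Interleaving sketch
///
/// An illustrative (not authoritative) view of how a double-buffering schedule can drive
/// the executor one task at a time so that loading overlaps with compute; `reader` is a
/// `SyncPartialStageGlobalReader` and `config` its global config.
///
/// ```ignore
/// let mut job = JobExecutor::create_job_iterator(&reader, StageBuffer::A, config);
///
/// while JobIterator::current(&job) < JobIterator::num_tasks(&job) {
///     // One data-transfer task per call; the counter advances at comptime.
///     JobExecutor::execute_task(&mut reader, &mut job, config);
///     // ...interleave a slice of stage compute here...
/// }
/// ```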
pub struct SyncPartialJobIterator<IP: MatrixPrecision, L: SyncPartialLoadingStrategy> {
    job: L::Job<IP>,
    #[cube(comptime)]
    pub num_tasks: u32,
    pub current: ComptimeCell<TaskCounter>,
}

#[cube]
impl<IP: MatrixPrecision, L: SyncPartialLoadingStrategy> JobIterator
    for SyncPartialJobIterator<IP, L>
{
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}