cubecl_matmul/components/global/read/reader/partial_reader.rs

1use super::StageBuffer;
2use super::TaskCounter;
3use crate::components::global::GlobalReaderConfig;
4use crate::components::global::memory::GlobalIterator;
5use crate::components::global::multi_stage::JobExecutor;
6use crate::components::global::multi_stage::JobIterator;
7use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
8use crate::components::global::read::LoadingJob;
9use crate::components::global::read::LoadingValidation;
10use crate::components::global::read::SyncBarrier;
11use crate::components::global::read::SyncStrategy;
12use crate::components::stage::LoadStageFamily;
13use crate::components::stage::StageFamily;
14use crate::components::stage::TilingLayout;
15use cubecl_core as cubecl;
16use cubecl_core::prelude::*;
17use cubecl_std::{
18    CubeOption, CubeOptionExpand,
19    tensor::{View, layout::Coords2d},
20};
21
22pub type LoaderStage<L, IP> = <<L as PartialLoadingStrategy>::Stage as StageFamily>::Stage<
23    IP,
24    <L as PartialLoadingStrategy>::TilingLayout,
25>;
26
27#[cube]
28/// A strategy for loading partial stage memory
29pub trait PartialLoadingStrategy:
30    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
31{
32    /// The layout describing how data is tiled across the stage.
33    type TilingLayout: TilingLayout;
34    type SyncStrategy: SyncStrategy;
35    type Stage: LoadStageFamily<ReadOnly>;
36
37    /// The [LoadingJob] for this strategy.
38    type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = Self::Stage>;
39
40    /// Returns the job with preliminary calculations done.
41    fn new_job<EG: Numeric, ES: Numeric>(
42        #[comptime] stage_index: u32,
43        #[comptime] line_size: u32,
44        #[comptime] config: GlobalReaderConfig,
45    ) -> Self::Job<EG, ES>;
46}
47
48#[derive(Clone, CubeType)]
49#[allow(clippy::type_complexity)]
50/// Loads a stage from stage memory using synchronous data movement operations.
51///
52/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
53/// each Task represents a single data transfer for a specific unit
54pub struct PartialStageGlobalReader<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> {
55    global_iter: GlobalIterator<Line<EG>>,
56    stage_memory: LoaderStage<L, ES>,
57    loading_job: CubeOption<(L::Job<EG, ES>, L::Job<EG, ES>)>,
58}
59
60#[cube]
61impl<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> PartialStageGlobalReader<EG, ES, L> {
62    /// Create a new SyncPartialStageGlobalReader
63    pub fn new(
64        tensor: View<Line<EG>, Coords2d>,
65        k_step: u32,
66        #[comptime] config: GlobalReaderConfig,
67    ) -> Self {
68        let stage_memory = L::Stage::create(128u32, config.smem_config);
69        let global_iter =
70            GlobalIterator::new(tensor, k_step, config.gmem_config.view_direction, false);
71
72        let loading_job = match config.precompute_job {
73            true => CubeOption::new_Some((
74                L::new_job::<EG, ES>(0u32, tensor.line_size(), config),
75                L::new_job::<EG, ES>(1u32, tensor.line_size(), config),
76            )),
77            false => CubeOption::new_None(),
78        };
79
80        PartialStageGlobalReader::<EG, ES, L> {
81            global_iter,
82            stage_memory,
83            loading_job,
84        }
85    }
86
87    /// Give a reader to the loaded stage memory.
88    pub fn stage(&self, #[comptime] stage_buffer: StageBuffer) -> LoaderStage<L, ES> {
89        L::Stage::with_buffer_index(&self.stage_memory, stage_buffer.to_index())
90    }
91
92    /// Frees the stage memory for reuse
93    pub fn free_stage(self) {
94        L::Stage::free(&self.stage_memory);
95    }
96
97    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
98    pub fn advance_view(&mut self) {
99        self.global_iter.advance();
100    }
101
102    /// Accomplish the entire job of loading data into the stage memory
103    pub fn load_stage(
104        &mut self,
105        barrier: &mut SyncBarrier<L::SyncStrategy>,
106        #[comptime] stage_buffer: StageBuffer,
107        #[comptime] config: GlobalReaderConfig,
108    ) {
109        let mut loading_job = match self.loading_job {
110            CubeOption::Some(job) => match stage_buffer {
111                StageBuffer::A => job.0,
112                StageBuffer::B => job.1,
113            },
114            CubeOption::None => match stage_buffer {
115                StageBuffer::A => L::new_job::<EG, ES>(0u32, self.global_iter.line_size(), config),
116                StageBuffer::B => L::new_job::<EG, ES>(1u32, self.global_iter.line_size(), config),
117            },
118        };
119
120        let len = L::Job::task_count(&loading_job);
121
122        #[unroll]
123        for task_id in 0..len {
124            L::Job::<EG, ES>::execute_task(
125                &mut loading_job,
126                task_id,
127                &self.global_iter,
128                &mut self.stage_memory,
129                barrier,
130                config,
131            );
132        }
133    }
134}
135
136#[cube]
137impl<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> JobExecutor<L::SyncStrategy>
138    for PartialStageGlobalReader<EG, ES, L>
139{
140    type JobIterator = PartialJobIterator<EG, ES, L>;
141
142    fn create_job_iterator(
143        this: &Self,
144        #[comptime] stage_buffer: StageBuffer,
145        #[comptime] config: GlobalReaderConfig,
146    ) -> Self::JobIterator {
147        let view = this.global_iter.view();
148        let job = match this.loading_job {
149            CubeOption::Some(job) => match stage_buffer {
150                StageBuffer::A => job.0,
151                StageBuffer::B => job.1,
152            },
153            CubeOption::None => match stage_buffer {
154                StageBuffer::A => L::new_job::<EG, ES>(0u32, view.line_size(), config),
155                StageBuffer::B => L::new_job::<EG, ES>(1u32, view.line_size(), config),
156            },
157        };
158
159        let num_tasks = L::Job::task_count(&job);
160
161        PartialJobIterator::<EG, ES, L> {
162            job,
163            num_tasks,
164            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
165        }
166    }
167
168    fn execute_task(
169        this: &mut Self,
170        job_iterator: &mut PartialJobIterator<EG, ES, L>,
171        barrier: &mut SyncBarrier<L::SyncStrategy>,
172        #[comptime] config: GlobalReaderConfig,
173    ) {
174        let task_id = job_iterator.current.read().counter;
175
176        L::Job::<EG, ES>::execute_task(
177            &mut job_iterator.job,
178            task_id,
179            &this.global_iter,
180            &mut this.stage_memory,
181            barrier,
182            config,
183        );
184
185        job_iterator.current.store(TaskCounter {
186            counter: comptime!(task_id + 1u32),
187        });
188    }
189
190    fn execute_all_remaining_tasks(
191        this: &mut Self,
192        job_iterator: &mut Self::JobIterator,
193        barrier: &mut SyncBarrier<L::SyncStrategy>,
194        #[comptime] config: GlobalReaderConfig,
195    ) {
196        let task_counter = job_iterator.current.read().counter;
197
198        #[unroll]
199        for task_id in task_counter..job_iterator.num_tasks {
200            L::Job::<EG, ES>::execute_task(
201                &mut job_iterator.job,
202                task_id,
203                &this.global_iter,
204                &mut this.stage_memory,
205                barrier,
206                config,
207            );
208        }
209
210        job_iterator.current.store(TaskCounter {
211            counter: comptime!(job_iterator.num_tasks),
212        });
213    }
214
215    fn execute_whole_job(
216        this: &mut Self,
217        barrier: &mut SyncBarrier<L::SyncStrategy>,
218        #[comptime] stage_buffer: StageBuffer,
219        #[comptime] config: GlobalReaderConfig,
220    ) {
221        Self::execute_all_remaining_tasks(
222            this,
223            &mut Self::create_job_iterator(this, stage_buffer, config),
224            barrier,
225            config,
226        );
227    }
228}
229
230#[derive(CubeType)]
231/// Accomplish the entire job of filling the stage
232pub struct PartialJobIterator<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> {
233    job: L::Job<EG, ES>,
234    #[cube(comptime)]
235    pub num_tasks: u32,
236    pub current: ComptimeCell<TaskCounter>,
237}
238
239#[cube]
240impl<EG: Numeric, ES: Numeric, L: PartialLoadingStrategy> JobIterator
241    for PartialJobIterator<EG, ES, L>
242{
243    fn current(this: &Self) -> comptime_type!(u32) {
244        this.current.read().counter
245    }
246
247    fn num_tasks(this: &Self) -> comptime_type!(u32) {
248        this.num_tasks
249    }
250}