Skip to main content

cubek_matmul/components/global/read/reader/
full_reader.rs

1use std::marker::PhantomData;
2
3use crate::{
4    components::global::multi_stage::JobExecutor,
5    components::global::multi_stage::LoadMaxRoundPlaneCount,
6    components::global::read::LoadingJob,
7    components::global::read::LoadingValidation,
8    components::global::read::StageBuffer,
9    components::global::read::SyncStrategy,
10    components::global::read::TaskCounter,
11    components::global::{multi_stage::JobIterator, read::FullLoaderStage},
12    components::stage::TilingLayout,
13    components::{global::memory::GlobalIterator, stage::LoadStageFamily},
14    {components::global::GlobalReaderConfig, launch::RuntimeConfig},
15};
16use cubecl::{
17    prelude::*,
18    std::tensor::{View, layout::Coords2d},
19};
20use cubek_std::tile::TileKind;
21
22pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
23
24#[cube]
25/// A strategy for synchronously loading a full stage memory.
26pub trait FullLoadingStrategy<RC: RuntimeConfig>:
27    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
28{
29    /// The layout describing how data is tiled across the stage.
30    type TilingLayout: TilingLayout;
31    /// The synchronization strategy that should be used with this loading strategy
32    type SyncStrategy: SyncStrategy;
33    type Stage: LoadStageFamily<ReadOnly>;
34    type TileKind: TileKind;
35
36    /// The [LoadingJob] for this strategy.
37    type Job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>: LoadingJob<EG, NG, ES, NS, Self::TilingLayout, Self::SyncStrategy, Stage = Self::Stage>;
38
39    const SHOULD_CLEAR: bool = false;
40
41    /// Returns the job with preliminary calculations done.
42    fn new_job<EG: Numeric, NG: Size, ES: Numeric, NS: Size>(
43        config: RC,
44        #[comptime] config: GlobalReaderConfig,
45    ) -> Self::Job<EG, NG, ES, NS>;
46}
47
48#[derive(Clone, CubeType)]
49/// Loads the entire stage memory.
50///
51/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
52/// each Task represents a single data transfer for a specific unit
53pub struct FullStageGlobalReader<
54    EG: Numeric,
55    NG: Size,
56    ES: Numeric,
57    NS: Size,
58    RC: RuntimeConfig,
59    L: FullLoadingStrategy<RC>,
60> {
61    global_iter: GlobalIterator<Vector<EG, NG>>,
62    runtime_config: RC,
63    stage: FullLoaderStage<RC, L, ES, NS>,
64    loading_job: ComptimeOption<L::Job<EG, NG, ES, NS>>,
65    #[cube(comptime)]
66    _phantom: PhantomData<L>,
67}
68
69#[cube]
70impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
71    FullStageGlobalReader<EG, NG, ES, NS, RC, L>
72{
73    /// Create a new SyncFullStageGlobalReader
74    pub fn new(
75        view: View<Vector<EG, NG>, Coords2d>,
76        runtime_config: RC,
77        k_step: u32,
78        #[comptime] config: GlobalReaderConfig,
79    ) -> Self {
80        // Maybe make align a property on the strategy, but it's fine to over-align so this works
81        // for now. Swizzling will require more though.
82        let stage = L::Stage::create(128usize, config.smem_config);
83
84        let global_iter =
85            GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
86
87        let loading_job = match config.precompute_job {
88            true => ComptimeOption::new_Some(L::new_job::<EG, NG, ES, NS>(
89                runtime_config.clone(),
90                config,
91            )),
92            false => ComptimeOption::new_None(),
93        };
94
95        FullStageGlobalReader::<EG, NG, ES, NS, RC, L> {
96            global_iter,
97            runtime_config,
98            stage,
99            loading_job,
100            _phantom: PhantomData::<L>,
101        }
102    }
103
104    /// Give a reader to the loaded stage memory.
105    pub fn stage(&self) -> FullLoaderStage<RC, L, ES, NS> {
106        L::Stage::with_buffer_index(&self.stage, 0)
107    }
108
109    /// Frees the stage memory for reuse
110    pub fn free_stage(self) {
111        L::Stage::free(&self.stage);
112    }
113
114    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
115    pub fn advance_view(&mut self) {
116        self.global_iter.advance();
117    }
118
119    /// Accomplish the entire job of loading data into the stage memory
120    pub fn load_stage(
121        &mut self,
122        barrier: &mut SyncBarrier<L::SyncStrategy>,
123        #[comptime] config: GlobalReaderConfig,
124    ) {
125        let mut loading_job = self
126            .loading_job
127            .clone()
128            .unwrap_or_else(|| L::new_job::<EG, NG, ES, NS>(self.runtime_config.clone(), config));
129
130        let len = L::Job::task_count(&loading_job);
131
132        #[unroll]
133        for task_id in 0..len {
134            L::Job::<EG, NG, ES, NS>::execute_task(
135                &mut loading_job,
136                task_id,
137                &self.global_iter,
138                &mut self.stage,
139                barrier,
140                config,
141            );
142        }
143    }
144}
145
146#[cube]
147impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
148    JobExecutor<L::SyncStrategy> for FullStageGlobalReader<EG, NG, ES, NS, RC, L>
149{
150    type JobIterator = FullStageJobIterator<EG, NG, ES, NS, RC, L>;
151
152    fn create_job_iterator(
153        this: &Self,
154        #[comptime] _stage_buffer: StageBuffer,
155        #[comptime] config: GlobalReaderConfig,
156    ) -> Self::JobIterator {
157        let job = this
158            .loading_job
159            .clone()
160            .unwrap_or_else(|| L::new_job::<EG, NG, ES, NS>(this.runtime_config.clone(), config));
161
162        let num_tasks = L::Job::task_count(&job);
163
164        FullStageJobIterator::<EG, NG, ES, NS, RC, L> {
165            job,
166            num_tasks,
167            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
168        }
169    }
170
171    fn execute_task(
172        this: &mut Self,
173        job_iterator: &mut FullStageJobIterator<EG, NG, ES, NS, RC, L>,
174        barrier: &mut SyncBarrier<L::SyncStrategy>,
175        #[comptime] config: GlobalReaderConfig,
176    ) {
177        let task_id = job_iterator.current.read().counter.comptime();
178
179        L::Job::<EG, NG, ES, NS>::execute_task(
180            &mut job_iterator.job,
181            task_id,
182            &this.global_iter,
183            &mut this.stage,
184            barrier,
185            config,
186        );
187
188        job_iterator.current.store(TaskCounter {
189            counter: task_id + 1,
190        });
191    }
192
193    fn execute_all_remaining_tasks(
194        this: &mut Self,
195        job_iterator: &mut Self::JobIterator,
196        barrier: &mut SyncBarrier<L::SyncStrategy>,
197        #[comptime] config: GlobalReaderConfig,
198    ) {
199        let task_counter = job_iterator.current.read().counter;
200
201        #[unroll]
202        for task_id in task_counter..job_iterator.num_tasks {
203            L::Job::<EG, NG, ES, NS>::execute_task(
204                &mut job_iterator.job,
205                task_id,
206                &this.global_iter,
207                &mut this.stage,
208                barrier,
209                config,
210            );
211        }
212
213        job_iterator.current.store(TaskCounter {
214            counter: job_iterator.num_tasks,
215        });
216    }
217
218    fn execute_whole_job(
219        this: &mut Self,
220        barrier: &mut SyncBarrier<L::SyncStrategy>,
221        #[comptime] stage_buffer: StageBuffer,
222        #[comptime] config: GlobalReaderConfig,
223    ) {
224        Self::execute_all_remaining_tasks(
225            this,
226            &mut Self::create_job_iterator(this, stage_buffer, config),
227            barrier,
228            config,
229        );
230    }
231}
232
233#[derive(CubeType)]
234/// A comptime iterator over a job for sync full stage reader
235pub struct FullStageJobIterator<
236    EG: Numeric,
237    NG: Size,
238    ES: Numeric,
239    NS: Size,
240    RC: RuntimeConfig,
241    L: FullLoadingStrategy<RC>,
242> {
243    job: L::Job<EG, NG, ES, NS>,
244    #[cube(comptime)]
245    pub num_tasks: u32,
246    pub current: ComptimeCell<TaskCounter>,
247}
248
249#[cube]
250impl<EG: Numeric, NG: Size, ES: Numeric, NS: Size, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
251    JobIterator for FullStageJobIterator<EG, NG, ES, NS, RC, L>
252{
253    fn current(this: &Self) -> comptime_type!(u32) {
254        this.current.read().counter
255    }
256
257    fn num_tasks(this: &Self) -> comptime_type!(u32) {
258        this.num_tasks
259    }
260}