Skip to main content

cubek_matmul/components/global/read/reader/
full_reader.rs

1use std::marker::PhantomData;
2
3use crate::components::global::multi_stage::JobExecutor;
4use crate::components::global::read::LoadingJob;
5use crate::components::global::read::LoadingValidation;
6use crate::components::global::read::StageBuffer;
7use crate::components::global::read::SyncStrategy;
8use crate::components::global::read::TaskCounter;
9use crate::components::global::{multi_stage::JobIterator, read::FullLoaderStage};
10use crate::components::stage::TilingLayout;
11use crate::components::{global::memory::GlobalIterator, stage::LoadStageFamily};
12use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, tile::io::TileKind};
13use crate::{components::global::GlobalReaderConfig, launch::RuntimeConfig};
14use cubecl::prelude::*;
15use cubecl::std::{
16    CubeOption, CubeOptionExpand,
17    tensor::{View, layout::Coords2d},
18};
19
/// Shorthand for the barrier type associated with the synchronization strategy `S`
/// (i.e. `<S as SyncStrategy>::Barrier`).
pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
21
22#[cube]
23/// A strategy for synchronously loading a full stage memory.
24pub trait FullLoadingStrategy<RC: RuntimeConfig>:
25    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
26{
27    /// The layout describing how data is tiled across the stage.
28    type TilingLayout: TilingLayout;
29    /// The synchronization strategy that should be used with this loading strategy
30    type SyncStrategy: SyncStrategy;
31    type Stage: LoadStageFamily<ReadOnly, TileKind = Self::TileKind>;
32    type TileKind: TileKind;
33
34    /// The [LoadingJob] for this strategy.
35    type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = Self::Stage>;
36
37    const SHOULD_CLEAR: bool = false;
38
39    /// Returns the job with preliminary calculations done.
40    fn new_job<EG: Numeric, ES: Numeric>(
41        config: RC,
42        #[comptime] line_size: LineSize,
43        #[comptime] config: GlobalReaderConfig,
44    ) -> Self::Job<EG, ES>;
45}
46
#[derive(Clone, CubeType)]
/// Loads the entire stage memory.
///
/// A complete load is referred to as a `Job`, which is divided into `Tasks`;
/// each Task represents a single data transfer for a specific unit.
///
/// Type parameters:
/// - `EG`: numeric element type of the lines read from the global memory view.
/// - `ES`: numeric type forwarded to the stage and to the loading job.
/// - `RC`: runtime configuration type passed on to the loading strategy.
/// - `L`: the [`FullLoadingStrategy`] that defines the stage, job and sync types.
pub struct FullStageGlobalReader<
    EG: Numeric,
    ES: Numeric,
    RC: RuntimeConfig,
    L: FullLoadingStrategy<RC>,
> {
    // Iterator over the global memory view that tasks read from.
    global_iter: GlobalIterator<Line<EG>>,
    // Runtime configuration, kept so a job can be built lazily when none was precomputed.
    runtime_config: RC,
    // The stage that loading tasks write into.
    stage: FullLoaderStage<RC, L, ES>,
    // The job built at construction time when `config.precompute_job` is set;
    // `None` otherwise, in which case a job is created on each load.
    loading_job: CubeOption<L::Job<EG, ES>>,
    // Zero-sized marker tying the reader to its loading strategy `L`.
    #[cube(comptime)]
    _phantom: PhantomData<L>,
}
65
#[cube]
impl<EG: Numeric, ES: Numeric, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
    FullStageGlobalReader<EG, ES, RC, L>
{
    /// Create a new `FullStageGlobalReader`.
    ///
    /// - `view`: the global memory view to read from.
    /// - `runtime_config`: runtime configuration forwarded to the loading strategy.
    /// - `k_step`: step handed to the [`GlobalIterator`] built over `view`.
    /// - `config`: comptime reader configuration; its `precompute_job` flag decides
    ///   whether the loading job is built here or lazily at load time.
    pub fn new(
        view: View<Line<EG>, Coords2d>,
        runtime_config: RC,
        k_step: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self {
        // Maybe make align a property on the strategy, but it's fine to over-align so this works
        // for now. Swizzling will require more though.
        let stage = L::Stage::create(128usize, config.smem_config);

        let global_iter =
            GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);

        // `precompute_job` is a comptime flag: either build the job once here and
        // store it, or store `None` and let each load build its own job.
        let loading_job = match config.precompute_job {
            true => CubeOption::new_Some(L::new_job::<EG, ES>(
                runtime_config.clone(),
                view.line_size(),
                config,
            )),
            false => CubeOption::new_None(),
        };

        FullStageGlobalReader::<EG, ES, RC, L> {
            global_iter,
            runtime_config,
            stage,
            loading_job,
            _phantom: PhantomData::<L>,
        }
    }

    /// Give a reader to the loaded stage memory.
    ///
    /// The returned stage is obtained through `with_buffer_index` with index `0`.
    pub fn stage(&self) -> FullLoaderStage<RC, L, ES> {
        L::Stage::with_buffer_index(&self.stage, 0)
    }

    /// Frees the stage memory for reuse. Consumes the reader.
    pub fn free_stage(self) {
        L::Stage::free(&self.stage);
    }

    /// Advance the view over global memory by one step of the underlying
    /// [`GlobalIterator`].
    // NOTE(review): the step is presumably the `k_step` given to `new` (along the k
    // dimension); this method itself takes no offset — confirm in `GlobalIterator`.
    pub fn advance_view(&mut self) {
        self.global_iter.advance();
    }

    /// Accomplish the entire job of loading data into the stage memory.
    ///
    /// Uses the precomputed job when one was stored at construction, otherwise
    /// builds a fresh job, then executes every task of that job in order.
    ///
    /// - `barrier`: the barrier of the strategy's sync strategy, passed to every task.
    /// - `config`: comptime reader configuration.
    pub fn load_stage(
        &mut self,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Work on a clone of the stored job (if any) so `self.loading_job` is left untouched.
        let mut loading_job = match self.loading_job.clone() {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<EG, ES>(
                self.runtime_config.clone(),
                self.global_iter.line_size(),
                config,
            ),
        };

        let len = L::Job::task_count(&loading_job);

        // The loop over tasks is unrolled; each iteration runs one task of the job.
        #[unroll]
        for task_id in 0..len {
            L::Job::<EG, ES>::execute_task(
                &mut loading_job,
                task_id,
                &self.global_iter,
                &mut self.stage,
                barrier,
                config,
            );
        }
    }
}
147
#[cube]
/// [`JobExecutor`] implementation that lets the reader's job be run task by task
/// through a [`FullStageJobIterator`] instead of all at once.
impl<EG: Numeric, ES: Numeric, RC: RuntimeConfig, L: FullLoadingStrategy<RC>>
    JobExecutor<L::SyncStrategy> for FullStageGlobalReader<EG, ES, RC, L>
{
    type JobIterator = FullStageJobIterator<EG, ES, RC, L>;

    /// Creates an iterator positioned at the first task of the job.
    ///
    /// The job is the precomputed one when available, otherwise a freshly built
    /// one. `_stage_buffer` is not used by this implementation.
    fn create_job_iterator(
        this: &Self,
        #[comptime] _stage_buffer: StageBuffer,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::JobIterator {
        let view = this.global_iter.view();
        let job = match this.loading_job.clone() {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => {
                L::new_job::<EG, ES>(this.runtime_config.clone(), view.line_size(), config)
            }
        };

        let num_tasks = L::Job::task_count(&job);

        FullStageJobIterator::<EG, ES, RC, L> {
            job,
            num_tasks,
            // The task counter starts at zero.
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

    /// Executes the task the iterator currently points at, then advances the
    /// iterator by one.
    // NOTE(review): the counter is not compared against `num_tasks` here; callers
    // are presumably expected to check `current < num_tasks` first — confirm.
    fn execute_task(
        this: &mut Self,
        job_iterator: &mut FullStageJobIterator<EG, ES, RC, L>,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Read the counter as a comptime value.
        let task_id = job_iterator.current.read().counter.comptime();

        L::Job::<EG, ES>::execute_task(
            &mut job_iterator.job,
            task_id,
            &this.global_iter,
            &mut this.stage,
            barrier,
            config,
        );

        job_iterator.current.store(TaskCounter {
            counter: task_id + 1,
        });
    }

    /// Executes every task from the iterator's current position up to (but not
    /// including) `num_tasks`, then marks the iterator as exhausted.
    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let task_counter = job_iterator.current.read().counter;

        #[unroll]
        for task_id in task_counter..job_iterator.num_tasks {
            L::Job::<EG, ES>::execute_task(
                &mut job_iterator.job,
                task_id,
                &this.global_iter,
                &mut this.stage,
                barrier,
                config,
            );
        }

        // Leave the counter at `num_tasks` so no task is left to run.
        job_iterator.current.store(TaskCounter {
            counter: job_iterator.num_tasks,
        });
    }

    /// Runs a whole job from start to finish: creates a fresh job iterator and
    /// executes all of its tasks.
    fn execute_whole_job(
        this: &mut Self,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: GlobalReaderConfig,
    ) {
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_buffer, config),
            barrier,
            config,
        );
    }
}
237
#[derive(CubeType)]
/// A comptime iterator over a job for sync full stage reader.
///
/// Holds the job being executed together with the total number of tasks and the
/// index of the next task to run.
pub struct FullStageJobIterator<
    EG: Numeric,
    ES: Numeric,
    RC: RuntimeConfig,
    L: FullLoadingStrategy<RC>,
> {
    // The loading job whose tasks are being executed.
    job: L::Job<EG, ES>,
    // Total number of tasks in `job`, as returned by `task_count` when the iterator was created.
    #[cube(comptime)]
    pub num_tasks: u32,
    // Cell holding the index of the next task to execute.
    pub current: ComptimeCell<TaskCounter>,
}
251
#[cube]
/// [`JobIterator`] accessors exposing the iterator's progress as comptime values.
impl<EG: Numeric, ES: Numeric, RC: RuntimeConfig, L: FullLoadingStrategy<RC>> JobIterator
    for FullStageJobIterator<EG, ES, RC, L>
{
    /// Returns the index of the next task to execute.
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    /// Returns the total number of tasks in the job.
    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}