cubecl_matmul/components/global/read/reader/full_reader.rs

use std::marker::PhantomData;

use crate::components::StageIdent;
use crate::components::global::GlobalReaderConfig;
use crate::components::global::memory::GlobalIterator;
use crate::components::global::multi_stage::JobExecutor;
use crate::components::global::multi_stage::JobIterator;
use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
use crate::components::global::read::LoadingJob;
use crate::components::global::read::LoadingValidation;
use crate::components::global::read::StageBuffer;
use crate::components::global::read::SyncStrategy;
use crate::components::global::read::TaskCounter;
use crate::components::stage::StridedStageFamily;
use crate::components::stage::StridedStageMemory;
use crate::components::stage::TilingLayout;
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};

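/// Shorthand for the barrier type of a [`SyncStrategy`].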
pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;

#[cube]
/// A strategy for loading a full stage memory.
pub trait FullLoadingStrategy:
    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
{
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;
    /// The synchronization strategy that should be used with this loading strategy.
    type SyncStrategy: SyncStrategy;

    /// The [LoadingJob] for this strategy.
    type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = StridedStageFamily>;

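    /// Whether the stage memory should be zeroed before loading when the global
    /// view does not cover the full stage (see the bounds checks in the reader's
    /// constructor). Defaults to `false`.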
    const SHOULD_CLEAR: bool = false;

    /// Returns the job with preliminary calculations done.
    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES>;
}

#[derive(Clone, CubeType)]
/// Loads the entire stage memory.
///
/// A complete load is referred to as a `Job`, which is divided into `Tasks`;
/// each `Task` represents a single data transfer for a specific unit.
pub struct FullStageGlobalReader<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
    /// Iterator over the global memory view to read from.
    global_iter: GlobalIterator<Line<EG>>,
    /// The stage memory the data is loaded into.
    stage: StridedStageMemory<ES, L::TilingLayout>,
    /// The loading job, precomputed only when `precompute_job` is enabled in the config.
    loading_job: CubeOption<L::Job<EG, ES>>,
    #[cube(comptime)]
    _phantom: PhantomData<L>,
}

#[cube]
impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> FullStageGlobalReader<EG, ES, L> {
    /// Create a new [`FullStageGlobalReader`].
    pub fn new(
        view: View<Line<EG>, Coords2d>,
        k_step: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self {
        // Maybe make align a property on the strategy, but it's fine to over-align so this works
        // for now. Swizzling will require more though.
        let mut stage = StridedStageMemory::new_aligned(128u32, config.smem_config);

        let (shape_row, shape_col) = view.shape();
        let global_iter =
            GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);

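        // When `precompute_job` is enabled, the loading job is built once here and
        // reused for every load; otherwise a job is rebuilt on demand.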
        let loading_job = match config.precompute_job {
            true => CubeOption::new_Some(L::new_job::<EG, ES>(view.line_size(), config)),
            false => CubeOption::new_None(),
        };

        if L::SHOULD_CLEAR {
            // Slices are clamped to the shape, so if the slice size is smaller than the stage size
            // we are partially out of bounds.
            match config.stage_ident {
                StageIdent::Lhs =>
                {
                    #[allow(clippy::collapsible_if)]
                    if config.gmem_config.check_row_bounds {
                        if shape_row < config.smem_config.elements_per_stage_along_row() {
                            stage.clear_all(config);
                        }
                    }
                }
                StageIdent::Rhs =>
                {
                    #[allow(clippy::collapsible_if)]
                    if config.gmem_config.check_col_bounds {
                        if shape_col < config.smem_config.elements_per_stage_along_col() {
                            stage.clear_all(config);
                        }
                    }
                }
                _ => comptime!(unreachable!()),
            }
        }

        FullStageGlobalReader::<EG, ES, L> {
            global_iter,
            stage,
            loading_job,
            _phantom: PhantomData::<L>,
        }
    }

    /// Returns the stage memory this reader loads into.
    pub fn stage(&self) -> StridedStageMemory<ES, L::TilingLayout> {
        self.stage
    }

    /// Zero-fill the entire stage memory.
    pub fn clear_stage(&mut self, #[comptime] config: GlobalReaderConfig) {
        self.stage.clear_all(config);
    }

    /// Free the underlying stage memory.
    pub fn free_stage(self) {
        unsafe { self.stage.free() };
    }

    /// Advance the view over global memory along the k dimension by the `k_step`
    /// provided at construction.
    pub fn advance_view(&mut self) {
        self.global_iter.advance();
    }

    /// Accomplish the entire job of loading data into the stage memory.
    pub fn load_stage(
        &mut self,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] config: GlobalReaderConfig,
    ) {
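        // Reuse the precomputed job when available; otherwise build one for this load.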
        let mut loading_job = match self.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<EG, ES>(self.global_iter.line_size(), config),
        };

        let len = L::Job::task_count(&loading_job);

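        // The task count is known at comptime, so the loop is fully unrolled.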
        #[unroll]
        for task_id in 0..len {
            L::Job::<EG, ES>::execute_task(
                &mut loading_job,
                task_id,
                &self.global_iter,
                &mut self.stage,
                barrier,
                config,
            );
        }
    }
}

#[cube]
impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobExecutor<L::SyncStrategy>
    for FullStageGlobalReader<EG, ES, L>
{
    type JobIterator = FullStageJobIterator<EG, ES, L>;

    fn create_job_iterator(
        this: &Self,
        #[comptime] _stage_buffer: StageBuffer,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::JobIterator {
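        // `_stage_buffer` is unused here: the full reader loads the entire stage as a
        // single job. The precomputed job is reused when available.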
        let view = this.global_iter.view();
        let job = match this.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<EG, ES>(view.line_size(), config),
        };

        let num_tasks = L::Job::task_count(&job);

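        // Iteration starts at task 0; the counter is tracked at comptime.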
        FullStageJobIterator::<EG, ES, L> {
            job,
            num_tasks,
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

    fn execute_task(
        this: &mut Self,
        job_iterator: &mut FullStageJobIterator<EG, ES, L>,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let task_id = job_iterator.current.read().counter;

        L::Job::<EG, ES>::execute_task(
            &mut job_iterator.job,
            task_id,
            &this.global_iter,
            &mut this.stage,
            barrier,
            config,
        );

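        // Advance the comptime task counter.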
        job_iterator.current.store(TaskCounter {
            counter: comptime!(task_id + 1u32),
        });
    }

    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] config: GlobalReaderConfig,
    ) {
        let task_counter = job_iterator.current.read().counter;

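        // Execute every task that has not been run yet for this job.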
        #[unroll]
        for task_id in task_counter..job_iterator.num_tasks {
            L::Job::<EG, ES>::execute_task(
                &mut job_iterator.job,
                task_id,
                &this.global_iter,
                &mut this.stage,
                barrier,
                config,
            );
        }

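        // Record that the job has now been fully executed.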
        job_iterator.current.store(TaskCounter {
            counter: comptime!(job_iterator.num_tasks),
        });
    }

    fn execute_whole_job(
        this: &mut Self,
        barrier: &mut SyncBarrier<L::SyncStrategy>,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: GlobalReaderConfig,
    ) {
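        // A whole job is simply a fresh job iterator with all of its tasks executed.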
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_buffer, config),
            barrier,
            config,
        );
    }
}

#[derive(CubeType)]
/// A comptime iterator over the tasks of a [`FullStageGlobalReader`] loading job.
pub struct FullStageJobIterator<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
    /// The loading job being iterated over.
    job: L::Job<EG, ES>,
    /// Total number of tasks in the job.
    #[cube(comptime)]
    pub num_tasks: u32,
    /// Comptime counter of the next task to execute.
    pub current: ComptimeCell<TaskCounter>,
}

#[cube]
impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobIterator
    for FullStageJobIterator<EG, ES, L>
{
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}