// cubecl_convolution/components/global/read/reader/full_reader.rs

use std::marker::PhantomData;

use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_matmul::components::{
    global::{
        GlobalReaderConfig,
        memory::GlobalIterator,
        multi_stage::{JobExecutor, JobIterator, LoadMaxRoundPlaneCount},
        read::{LoadingJob, LoadingValidation, StageBuffer, SyncStrategy, TaskCounter},
    },
    stage::{StridedStageFamily, StridedStageMemory, TilingLayout},
};
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};

use crate::components::global::args::RuntimeArgs;

21pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
22
23#[cube]
24/// A strategy for synchronously loading a full stage memory.
25pub trait FullLoadingStrategy:
26    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
27{
28    /// The layout describing how data is tiled across the stage.
29    type TilingLayout: TilingLayout;
30    /// The synchronization strategy that should be used with this loading strategy
31    type SyncStrategy: SyncStrategy;
32
33    /// The [LoadingJob] for this strategy.
34    type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = StridedStageFamily>;
35
36    /// Returns the job with preliminary calculations done.
37    fn new_job<EG: Numeric, ES: Numeric>(
38        runtime_args: RuntimeArgs,
39        #[comptime] line_size: u32,
40        #[comptime] config: GlobalReaderConfig,
41    ) -> Self::Job<EG, ES>;
42}
43
44#[derive(Clone, CubeType)]
45/// Loads the entire stage memory.
46///
47/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
48/// each Task represents a single data transfer for a specific unit
49pub struct FullStageGlobalReader<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
50    global_iter: GlobalIterator<Line<EG>>,
51    runtime_args: RuntimeArgs,
52    stage: StridedStageMemory<ES, L::TilingLayout>,
53    loading_job: CubeOption<L::Job<EG, ES>>,
54    #[cube(comptime)]
55    _phantom: PhantomData<L>,
56}
57
58#[cube]
59impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> FullStageGlobalReader<EG, ES, L> {
60    /// Create a new SyncFullStageGlobalReader
61    pub fn new(
62        view: View<Line<EG>, Coords2d>,
63        runtime_args: RuntimeArgs,
64        k_step: u32,
65        #[comptime] config: GlobalReaderConfig,
66    ) -> Self {
67        // Maybe make align a property on the strategy, but it's fine to over-align so this works
68        // for now. Swizzling will require more though.
69        let stage = StridedStageMemory::new_aligned(128u32, config.smem_config);
70
71        let global_iter =
72            GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
73
74        let loading_job = match config.precompute_job {
75            true => CubeOption::new_Some(L::new_job::<EG, ES>(
76                runtime_args.clone(),
77                view.line_size(),
78                config,
79            )),
80            false => CubeOption::new_None(),
81        };
82
83        FullStageGlobalReader::<EG, ES, L> {
84            global_iter,
85            runtime_args,
86            stage,
87            loading_job,
88            _phantom: PhantomData::<L>,
89        }
90    }
91
92    /// Give a reader to the loaded stage memory.
93    pub fn stage(&self) -> StridedStageMemory<ES, L::TilingLayout> {
94        self.stage
95    }
96
97    pub fn clear_stage(&mut self, #[comptime] config: GlobalReaderConfig) {
98        self.stage.clear_all(config);
99    }
100
101    pub fn free_stage(self) {
102        unsafe { self.stage.free() };
103    }
104
105    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
106    pub fn advance_view(&mut self) {
107        self.global_iter.advance();
108    }
109
110    /// Accomplish the entire job of loading data into the stage memory
111    pub fn load_stage(
112        &mut self,
113        barrier: &mut SyncBarrier<L::SyncStrategy>,
114        #[comptime] config: GlobalReaderConfig,
115    ) {
116        let mut loading_job = match self.loading_job.clone() {
117            CubeOption::Some(loading_job) => loading_job,
118            CubeOption::None => L::new_job::<EG, ES>(
119                self.runtime_args.clone(),
120                self.global_iter.line_size(),
121                config,
122            ),
123        };
124
125        let len = L::Job::task_count(&loading_job);
126
127        #[unroll]
128        for task_id in 0..len {
129            L::Job::<EG, ES>::execute_task(
130                &mut loading_job,
131                task_id,
132                &self.global_iter,
133                &mut self.stage,
134                barrier,
135                config,
136            );
137        }
138    }
139}
140
141#[cube]
142impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobExecutor<L::SyncStrategy>
143    for FullStageGlobalReader<EG, ES, L>
144{
145    type JobIterator = FullStageJobIterator<EG, ES, L>;
146
147    fn create_job_iterator(
148        this: &Self,
149        #[comptime] _stage_buffer: StageBuffer,
150        #[comptime] config: GlobalReaderConfig,
151    ) -> Self::JobIterator {
152        let view = this.global_iter.view();
153        let job = match this.loading_job.clone() {
154            CubeOption::Some(loading_job) => loading_job,
155            CubeOption::None => {
156                L::new_job::<EG, ES>(this.runtime_args.clone(), view.line_size(), config)
157            }
158        };
159
160        let num_tasks = L::Job::task_count(&job);
161
162        FullStageJobIterator::<EG, ES, L> {
163            job,
164            num_tasks,
165            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
166        }
167    }
168
169    fn execute_task(
170        this: &mut Self,
171        job_iterator: &mut FullStageJobIterator<EG, ES, L>,
172        barrier: &mut SyncBarrier<L::SyncStrategy>,
173        #[comptime] config: GlobalReaderConfig,
174    ) {
175        let task_id = job_iterator.current.read().counter;
176
177        L::Job::<EG, ES>::execute_task(
178            &mut job_iterator.job,
179            task_id,
180            &this.global_iter,
181            &mut this.stage,
182            barrier,
183            config,
184        );
185
186        job_iterator.current.store(TaskCounter {
187            counter: comptime!(task_id + 1u32),
188        });
189    }
190
191    fn execute_all_remaining_tasks(
192        this: &mut Self,
193        job_iterator: &mut Self::JobIterator,
194        barrier: &mut SyncBarrier<L::SyncStrategy>,
195        #[comptime] config: GlobalReaderConfig,
196    ) {
197        let task_counter = job_iterator.current.read().counter;
198
199        #[unroll]
200        for task_id in task_counter..job_iterator.num_tasks {
201            L::Job::<EG, ES>::execute_task(
202                &mut job_iterator.job,
203                task_id,
204                &this.global_iter,
205                &mut this.stage,
206                barrier,
207                config,
208            );
209        }
210
211        job_iterator.current.store(TaskCounter {
212            counter: comptime!(job_iterator.num_tasks),
213        });
214    }
215
216    fn execute_whole_job(
217        this: &mut Self,
218        barrier: &mut SyncBarrier<L::SyncStrategy>,
219        #[comptime] stage_buffer: StageBuffer,
220        #[comptime] config: GlobalReaderConfig,
221    ) {
222        Self::execute_all_remaining_tasks(
223            this,
224            &mut Self::create_job_iterator(this, stage_buffer, config),
225            barrier,
226            config,
227        );
228    }
229}
230
231#[derive(CubeType)]
232/// A comptime iterator over a job for sync full stage reader
233pub struct FullStageJobIterator<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
234    job: L::Job<EG, ES>,
235    #[cube(comptime)]
236    pub num_tasks: u32,
237    pub current: ComptimeCell<TaskCounter>,
238}
239
240#[cube]
241impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobIterator
242    for FullStageJobIterator<EG, ES, L>
243{
244    fn current(this: &Self) -> comptime_type!(u32) {
245        this.current.read().counter
246    }
247
248    fn num_tasks(this: &Self) -> comptime_type!(u32) {
249        this.num_tasks
250    }
251}