cubek_convolution/components/global/read/reader/
full_reader.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl::std::{
5    CubeOption, CubeOptionExpand,
6    tensor::{View, layout::Coords2d},
7};
8use cubek_matmul::components::{
9    global::{
10        GlobalReaderConfig,
11        memory::GlobalIterator,
12        multi_stage::{JobExecutor, JobIterator, LoadMaxRoundPlaneCount},
13        read::{LoadingJob, LoadingValidation, StageBuffer, SyncStrategy, TaskCounter},
14    },
15    stage::{StridedStageFamily, StridedStageMemory, TilingLayout},
16};
17
18use crate::components::global::args::RuntimeArgs;
19
20pub type SyncBarrier<S> = <S as SyncStrategy>::Barrier;
21
22#[cube]
23/// A strategy for synchronously loading a full stage memory.
24pub trait FullLoadingStrategy:
25    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
26{
27    /// The layout describing how data is tiled across the stage.
28    type TilingLayout: TilingLayout;
29    /// The synchronization strategy that should be used with this loading strategy
30    type SyncStrategy: SyncStrategy;
31
32    /// The [LoadingJob] for this strategy.
33    type Job<EG: Numeric, ES: Numeric>: LoadingJob<EG, ES, Self::TilingLayout, Self::SyncStrategy, Stage = StridedStageFamily>;
34
35    /// Returns the job with preliminary calculations done.
36    fn new_job<EG: Numeric, ES: Numeric>(
37        runtime_args: RuntimeArgs,
38        #[comptime] line_size: LineSize,
39        #[comptime] config: GlobalReaderConfig,
40    ) -> Self::Job<EG, ES>;
41}
42
43#[derive(Clone, CubeType)]
44/// Loads the entire stage memory.
45///
46/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
47/// each Task represents a single data transfer for a specific unit
48pub struct FullStageGlobalReader<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
49    global_iter: GlobalIterator<Line<EG>>,
50    runtime_args: RuntimeArgs,
51    stage: StridedStageMemory<ES, L::TilingLayout>,
52    loading_job: CubeOption<L::Job<EG, ES>>,
53    #[cube(comptime)]
54    _phantom: PhantomData<L>,
55}
56
57#[cube]
58impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> FullStageGlobalReader<EG, ES, L> {
59    /// Create a new SyncFullStageGlobalReader
60    pub fn new(
61        view: View<Line<EG>, Coords2d>,
62        runtime_args: RuntimeArgs,
63        k_step: u32,
64        #[comptime] config: GlobalReaderConfig,
65    ) -> Self {
66        // Maybe make align a property on the strategy, but it's fine to over-align so this works
67        // for now. Swizzling will require more though.
68        let stage = StridedStageMemory::new_aligned(128usize, config.smem_config);
69
70        let global_iter =
71            GlobalIterator::new(view, k_step, config.gmem_config.view_direction, false);
72
73        let loading_job = match config.precompute_job {
74            true => CubeOption::new_Some(L::new_job::<EG, ES>(
75                runtime_args.clone(),
76                view.line_size(),
77                config,
78            )),
79            false => CubeOption::new_None(),
80        };
81
82        FullStageGlobalReader::<EG, ES, L> {
83            global_iter,
84            runtime_args,
85            stage,
86            loading_job,
87            _phantom: PhantomData::<L>,
88        }
89    }
90
91    /// Give a reader to the loaded stage memory.
92    pub fn stage(&self) -> StridedStageMemory<ES, L::TilingLayout> {
93        self.stage
94    }
95
96    pub fn clear_stage(&mut self, #[comptime] config: GlobalReaderConfig) {
97        self.stage.clear_all(config);
98    }
99
100    pub fn free_stage(self) {
101        unsafe { self.stage.free() };
102    }
103
104    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
105    pub fn advance_view(&mut self) {
106        self.global_iter.advance();
107    }
108
109    /// Accomplish the entire job of loading data into the stage memory
110    pub fn load_stage(
111        &mut self,
112        barrier: &mut SyncBarrier<L::SyncStrategy>,
113        #[comptime] config: GlobalReaderConfig,
114    ) {
115        let mut loading_job = match self.loading_job.clone() {
116            CubeOption::Some(loading_job) => loading_job,
117            CubeOption::None => L::new_job::<EG, ES>(
118                self.runtime_args.clone(),
119                self.global_iter.line_size(),
120                config,
121            ),
122        };
123
124        let len = L::Job::task_count(&loading_job);
125
126        #[unroll]
127        for task_id in 0..len {
128            L::Job::<EG, ES>::execute_task(
129                &mut loading_job,
130                task_id,
131                &self.global_iter,
132                &mut self.stage,
133                barrier,
134                config,
135            );
136        }
137    }
138}
139
140#[cube]
141impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobExecutor<L::SyncStrategy>
142    for FullStageGlobalReader<EG, ES, L>
143{
144    type JobIterator = FullStageJobIterator<EG, ES, L>;
145
146    fn create_job_iterator(
147        this: &Self,
148        #[comptime] _stage_buffer: StageBuffer,
149        #[comptime] config: GlobalReaderConfig,
150    ) -> Self::JobIterator {
151        let view = this.global_iter.view();
152        let job = match this.loading_job.clone() {
153            CubeOption::Some(loading_job) => loading_job,
154            CubeOption::None => {
155                L::new_job::<EG, ES>(this.runtime_args.clone(), view.line_size(), config)
156            }
157        };
158
159        let num_tasks = L::Job::task_count(&job);
160
161        FullStageJobIterator::<EG, ES, L> {
162            job,
163            num_tasks,
164            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
165        }
166    }
167
168    fn execute_task(
169        this: &mut Self,
170        job_iterator: &mut FullStageJobIterator<EG, ES, L>,
171        barrier: &mut SyncBarrier<L::SyncStrategy>,
172        #[comptime] config: GlobalReaderConfig,
173    ) {
174        let task_id = job_iterator.current.read().counter.comptime();
175
176        L::Job::<EG, ES>::execute_task(
177            &mut job_iterator.job,
178            task_id,
179            &this.global_iter,
180            &mut this.stage,
181            barrier,
182            config,
183        );
184
185        job_iterator.current.store(TaskCounter {
186            counter: task_id + 1,
187        });
188    }
189
190    fn execute_all_remaining_tasks(
191        this: &mut Self,
192        job_iterator: &mut Self::JobIterator,
193        barrier: &mut SyncBarrier<L::SyncStrategy>,
194        #[comptime] config: GlobalReaderConfig,
195    ) {
196        let task_counter = job_iterator.current.read().counter;
197
198        #[unroll]
199        for task_id in task_counter..job_iterator.num_tasks {
200            L::Job::<EG, ES>::execute_task(
201                &mut job_iterator.job,
202                task_id,
203                &this.global_iter,
204                &mut this.stage,
205                barrier,
206                config,
207            );
208        }
209
210        job_iterator.current.store(TaskCounter {
211            counter: job_iterator.num_tasks,
212        });
213    }
214
215    fn execute_whole_job(
216        this: &mut Self,
217        barrier: &mut SyncBarrier<L::SyncStrategy>,
218        #[comptime] stage_buffer: StageBuffer,
219        #[comptime] config: GlobalReaderConfig,
220    ) {
221        Self::execute_all_remaining_tasks(
222            this,
223            &mut Self::create_job_iterator(this, stage_buffer, config),
224            barrier,
225            config,
226        );
227    }
228}
229
230#[derive(CubeType)]
231/// A comptime iterator over a job for sync full stage reader
232pub struct FullStageJobIterator<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> {
233    job: L::Job<EG, ES>,
234    #[cube(comptime)]
235    pub num_tasks: u32,
236    pub current: ComptimeCell<TaskCounter>,
237}
238
239#[cube]
240impl<EG: Numeric, ES: Numeric, L: FullLoadingStrategy> JobIterator
241    for FullStageJobIterator<EG, ES, L>
242{
243    fn current(this: &Self) -> comptime_type!(u32) {
244        this.current.read().counter
245    }
246
247    fn num_tasks(this: &Self) -> comptime_type!(u32) {
248        this.num_tasks
249    }
250}