//! Path: cubecl_matmul/components/global/read/reader/sync_full_reader.rs

1use std::marker::PhantomData;
2
3use crate::components::MatmulIdent;
4use crate::components::MatrixPrecision;
5use crate::components::global::GlobalConfig;
6use crate::components::global::memory::GlobalIterator;
7use crate::components::global::multi_stage::JobExecutor;
8use crate::components::global::multi_stage::JobIterator;
9use crate::components::global::multi_stage::LoadMaxRoundPlaneCount;
10use crate::components::global::read::LoadingJob;
11use crate::components::global::read::LoadingValidation;
12use crate::components::global::read::StageBuffer;
13use crate::components::global::read::TaskCounter;
14use crate::components::stage::StridedStage;
15use crate::components::stage::TilingLayout;
16use cubecl_core as cubecl;
17use cubecl_core::prelude::*;
18use cubecl_std::{
19    CubeOption, CubeOptionExpand,
20    tensor::{View, layout::Coords2d},
21};
22
#[cube]
/// A strategy for synchronously loading a full stage memory.
///
/// Implementors describe *how* a stage is filled (tiling layout and the set of
/// per-unit transfer tasks); the surrounding reader decides *when* tasks run.
pub trait SyncFullLoadingStrategy:
    'static + Send + Sync + Clone + LoadingValidation + LoadMaxRoundPlaneCount
{
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;

    /// The [LoadingJob] for this strategy.
    type Job<IP: MatrixPrecision>: LoadingJob<IP, Self::TilingLayout>;

    /// Returns the job with preliminary calculations done.
    ///
    /// All arguments are comptime: `ident` selects the matmul operand being
    /// loaded, `line_size` is the line (vectorization) width of the global
    /// view, and `config` is the global matmul configuration.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] line_size: u32,
        #[comptime] config: G,
    ) -> Self::Job<IP>;
}
41
#[derive(Clone, CubeType)]
/// Loads the entire stage memory using synchronous data movement operations.
///
/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
/// each Task represents a single data transfer for a specific unit
pub struct SyncFullStageGlobalReader<
    IP: MatrixPrecision,
    G: GlobalConfig,
    L: SyncFullLoadingStrategy,
> {
    /// Iterator over the global-memory view, advanced along k between stages.
    global_iter: GlobalIterator<Line<IP::Global>>,
    /// Destination stage memory, tiled according to the strategy's layout.
    stage: StridedStage<IP::Stage, L::TilingLayout>,
    /// Precomputed loading job when `config.precompute_job()` is true;
    /// `None` means the job is rebuilt on each `load_stage` call.
    loading_job: CubeOption<L::Job<IP>>,
    #[cube(comptime)]
    ident: MatmulIdent,
    #[cube(comptime)]
    _phantom: PhantomData<(G, L)>,
}
60
#[cube]
impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncFullLoadingStrategy>
    SyncFullStageGlobalReader<IP, G, L>
{
    /// Create a new SyncFullStageGlobalReader
    ///
    /// `tensor` is the 2D global-memory view to read from, `k_step` is the
    /// stride by which the view advances along k, and `ident`/`config` are
    /// comptime parameters selecting the operand and global configuration.
    pub fn new(
        tensor: View<Line<IP::Global>, Coords2d>,
        k_step: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self {
        // Allocate the stage memory for this operand's stage configuration.
        let stage = StridedStage::new(
            comptime!(ident.into_stage()),
            config.stage_memory_config(ident),
        );
        let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), false);

        // Optionally precompute the loading job once at construction so
        // `load_stage` can skip rebuilding it on every iteration.
        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some(L::new_job::<IP, G>(ident, tensor.line_size(), config)),
            false => CubeOption::new_None(),
        };

        SyncFullStageGlobalReader::<IP, G, L> {
            global_iter,
            stage,
            loading_job,
            ident,
            _phantom: PhantomData::<(G, L)>,
        }
    }

    /// Give a reader to the loaded stage memory.
    pub fn stage(&self) -> StridedStage<IP::Stage, L::TilingLayout> {
        self.stage
    }

    /// Release the stage memory backing this reader.
    pub fn free_stage(self) {
        // SAFETY: presumably the stage must no longer be read after this call;
        // callers are expected to invoke this only once the stage's data has
        // been fully consumed — TODO confirm against StridedStage::free's contract.
        unsafe { self.stage.free() };
    }

    /// Advance the view over global memory along the k dimension by the
    /// `k_step` configured at construction.
    pub fn advance_view(&mut self) {
        self.global_iter.advance();
    }

    /// Accomplish the entire job of loading data into the stage memory
    pub fn load_stage(&mut self, #[comptime] config: G) {
        // Reuse the precomputed job if available; otherwise build one now
        // using the current view's line size.
        let mut loading_job = match self.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => {
                L::new_job::<IP, G>(self.ident, self.global_iter.line_size(), config)
            }
        };

        let len = L::Job::task_count(&loading_job);

        // Execute every task of the job; the loop is unrolled since the task
        // count is known at compile time.
        #[unroll]
        for task_id in 0..len {
            L::Job::<IP>::execute_task::<G>(
                &mut loading_job,
                task_id,
                &self.global_iter,
                &mut self.stage,
                config,
            );
        }
    }
}
129
#[cube]
impl<IP: MatrixPrecision, G: GlobalConfig, L: SyncFullLoadingStrategy> JobExecutor<G>
    for SyncFullStageGlobalReader<IP, G, L>
{
    type JobIterator = SyncFullStageJobIterator<IP, L>;

    /// Build a job iterator positioned at task 0, reusing the precomputed
    /// job when available. `_stage_buffer` is unused: a full-stage reader
    /// always targets the whole stage.
    fn create_job_iterator(
        this: &Self,
        #[comptime] _stage_buffer: StageBuffer,
        #[comptime] config: G,
    ) -> Self::JobIterator {
        let view = this.global_iter.view();
        let job = match this.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<IP, G>(this.ident, view.line_size(), config),
        };

        let num_tasks = L::Job::task_count(&job);

        SyncFullStageJobIterator::<IP, L> {
            job,
            num_tasks,
            // Comptime cursor tracking how many tasks have been executed.
            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
        }
    }

    /// Execute the single next task and advance the comptime task counter.
    fn execute_task(
        this: &mut Self,
        job_iterator: &mut SyncFullStageJobIterator<IP, L>,
        #[comptime] config: G,
    ) {
        let task_id = job_iterator.current.read().counter;

        L::Job::<IP>::execute_task::<G>(
            &mut job_iterator.job,
            task_id,
            &this.global_iter,
            &mut this.stage,
            config,
        );

        // Counter arithmetic happens at comptime; each call expands to the
        // next concrete task.
        job_iterator.current.store(TaskCounter {
            counter: comptime!(task_id + 1u32),
        });
    }

    /// Run every task from the current counter position to the end of the job.
    fn execute_all_remaining_tasks(
        this: &mut Self,
        job_iterator: &mut Self::JobIterator,
        #[comptime] config: G,
    ) {
        let task_counter = job_iterator.current.read().counter;

        let mut task_id = comptime![task_counter];

        // The loop variable is unused because `task_id` is a comptime counter
        // incremented manually; the range only fixes the iteration count.
        #[allow(clippy::explicit_counter_loop)]
        #[unroll]
        for _ in task_counter..job_iterator.num_tasks {
            L::Job::<IP>::execute_task::<G>(
                &mut job_iterator.job,
                task_id,
                &this.global_iter,
                &mut this.stage,
                config,
            );
            comptime![task_id += 1];
        }

        // Mark the job as fully consumed.
        job_iterator.current.store(TaskCounter {
            counter: comptime!(job_iterator.num_tasks),
        });
    }

    /// Convenience: create a fresh iterator and run the whole job in one call.
    fn execute_whole_job(
        this: &mut Self,
        #[comptime] stage_buffer: StageBuffer,
        #[comptime] config: G,
    ) {
        Self::execute_all_remaining_tasks(
            this,
            &mut Self::create_job_iterator(this, stage_buffer, config),
            config,
        );
    }
}
215
#[derive(CubeType)]
/// A comptime iterator over a job for sync full stage reader
pub struct SyncFullStageJobIterator<IP: MatrixPrecision, L: SyncFullLoadingStrategy> {
    /// The loading job whose tasks are being iterated.
    job: L::Job<IP>,
    #[cube(comptime)]
    /// Total number of tasks in the job (comptime constant).
    pub num_tasks: u32,
    /// Comptime cursor: index of the next task to execute.
    pub current: ComptimeCell<TaskCounter>,
}
224
#[cube]
impl<IP: MatrixPrecision, L: SyncFullLoadingStrategy> JobIterator
    for SyncFullStageJobIterator<IP, L>
{
    /// Index of the next task to execute (comptime value).
    fn current(this: &Self) -> comptime_type!(u32) {
        this.current.read().counter
    }

    /// Total number of tasks in the job (comptime value).
    fn num_tasks(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}