cubecl_matmul/components/global/read/reader/async_full_reader.rs

use std::marker::PhantomData;

use crate::components::global::memory::GlobalIterator;
use crate::components::global::read::{AsyncLoadingJob, LoadingValidation};
use crate::components::global::{CopyMechanism, GlobalConfig};
use crate::components::stage::TilingLayout;
use crate::components::stage::{self, StridedStage};
use crate::components::{MatmulIdent, MatrixPrecision};
use cubecl_core as cubecl;
use cubecl_core::prelude::barrier::BarrierLevel;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};

#[cube]
/// A strategy for fully and asynchronously loading a stage.
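///
/// A hypothetical implementation sketch; `MyStrategy`, `MyTilingLayout`, and `MyJob`
/// are illustrative placeholders rather than types provided by this crate, and the
/// `BarrierLevel` constructor shown is an assumption to be matched to the copy mechanism:
///
/// ```ignore
/// #[derive(CubeType, Clone, Copy)]
/// pub struct MyStrategy;
///
/// // (The required `LoadingValidation` impl is omitted here.)
/// #[cube]
/// impl AsyncFullLoadingStrategy for MyStrategy {
///     type TilingLayout = MyTilingLayout;
///     type Job<IP: MatrixPrecision> = MyJob<IP>;
///
///     fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
///         #[comptime] ident: MatmulIdent,
///         #[comptime] config: G,
///     ) -> Self::Job<IP> {
///         // Precompute per-task offsets here so `execute_task` stays cheap.
///         MyJob::new(ident, config)
///     }
///
///     fn barrier_level() -> BarrierLevel {
///         BarrierLevel::cube_manual(0u32)
///     }
/// }
/// ```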
pub trait AsyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;

    /// The [AsyncLoadingJob] for this strategy.
    type Job<IP: MatrixPrecision>: AsyncLoadingJob<IP, Self::TilingLayout>;

    /// Returns the job with preliminary calculations done.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self::Job<IP>;

    /// The barrier level at which the copy mechanism works.
    fn barrier_level() -> BarrierLevel;
}

#[derive(CubeType)]
/// Loads the entire stage memory using asynchronous data movement operations.
///
/// A complete load is referred to as a `Job`, which is divided into `Task`s;
/// each `Task` represents a single data transfer for a specific unit.
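///
/// A minimal usage sketch, assuming the caller already has a global-memory `view`,
/// a copy `mechanism`, and a `config`; the concrete type parameters `MyPrecision`,
/// `MyCopy`, `MyStageConfig`, `MyLoading`, and `MyConfig` are illustrative
/// placeholders, not types provided by this module:
///
/// ```ignore
/// let mut reader =
///     AsyncFullStageGlobalReader::<MyPrecision, MyCopy, MyStageConfig, MyLoading, MyConfig>::new(
///         view,
///         k_step,
///         MatmulIdent::Lhs,
///         config,
///     );
///
/// for _ in 0..num_stages {
///     // Copy one full stage from global memory, then hand it to the stage matmul.
///     reader.load_stage(&mechanism, config);
///     let stage = reader.stage();
///     // ... synchronize with the copy mechanism, then run the stage matmul on `stage` ...
///     reader.advance_view();
/// }
/// ```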
pub struct AsyncFullStageGlobalReader<
    IP: MatrixPrecision,
    CM: CopyMechanism,
    S: stage::StageConfig,
    L: AsyncFullLoadingStrategy,
    G: GlobalConfig,
> {
    tensor_reader: GlobalIterator<Line<IP::Global>>,
    stage_memory: StridedStage<IP::Stage, L::TilingLayout>,
    loading_job: CubeOption<L::Job<IP>>,
    #[cube(comptime)]
    ident: MatmulIdent,
    #[cube(comptime)]
    _phantom: PhantomData<(S, L, CM, G)>,
}

#[cube]
impl<
    IP: MatrixPrecision,
    CM: CopyMechanism,
    S: stage::StageConfig,
    L: AsyncFullLoadingStrategy,
    G: GlobalConfig,
> AsyncFullStageGlobalReader<IP, CM, S, L, G>
{
    /// Create a new AsyncFullStageGlobalReader
    pub fn new(
        view: View<Line<IP::Global>, Coords2d>,
        k_step: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self {
        let mut stage_memory = StridedStage::new(
            comptime!(ident.into_stage()),
            config.stage_memory_config(ident),
        );
        let (shape_row, shape_col) = view.shape();
        let tensor_reader = GlobalIterator::new(view, k_step, ident.view_direction(), true);

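        // Precompute the loading job now if the config asks for it; otherwise it is
        // rebuilt lazily on every `load_stage` call.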
        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some(L::new_job::<IP, G>(ident, config)),
            false => CubeOption::new_None(),
        };

        // Slices are clamped to the shape, so if the slice size is smaller than the stage size
        // we are partially out of bounds.
        match ident {
            MatmulIdent::Lhs =>
            {
                #[allow(clippy::collapsible_if)]
                if config.check_row_bounds(ident) {
                    if shape_row < config.tiling_scheme().elements_in_stage_m() {
                        stage_memory.clear_all::<G>(ident, config);
                    }
                }
            }
            MatmulIdent::Rhs =>
            {
                #[allow(clippy::collapsible_if)]
                if config.check_col_bounds(ident) {
                    if shape_col < config.tiling_scheme().elements_in_stage_n() {
                        stage_memory.clear_all::<G>(ident, config);
                    }
                }
            }
            MatmulIdent::Out => comptime!(unreachable!()),
        }

        AsyncFullStageGlobalReader::<IP, CM, S, L, G> {
            tensor_reader,
            stage_memory,
            loading_job,
            ident,
            _phantom: PhantomData,
        }
    }

    /// Accomplish the entire job of loading data into the stage memory
    pub fn load_stage(&mut self, mechanism: &CM, #[comptime] config: G) {
        let mut loading_job = match self.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<IP, G>(self.ident, config),
        };

        let len = L::Job::task_count(&loading_job);
        for task_id in 0..len {
            L::Job::<IP>::execute_task::<CM, G>(
                &mut loading_job,
                task_id,
                &self.tensor_reader,
                &mut self.stage_memory,
                mechanism,
                config,
            );
        }
    }

    /// Zero out the stage memory
    pub fn clear_stage(&mut self, #[comptime] config: G) {
        self.stage_memory.clear_all::<G>(self.ident, config)
    }

    /// Return a reader over the loaded stage memory.
    pub fn stage(&self) -> StridedStage<IP::Stage, L::TilingLayout> {
        self.stage_memory
    }

    /// Advance the view over global memory along the k dimension by the `k_step` given at construction.
    pub fn advance_view(&mut self) {
        self.tensor_reader.advance();
    }
}