cubecl_matmul/components/global/load/loader/async_full_loader.rs

use std::marker::PhantomData;

use crate::components::global::Quantization;
use crate::components::global::global_memory::TensorReader;
use crate::components::global::load::{AsyncLoadingJob, LoadingValidation};
use crate::components::global::{CopyMechanism, GlobalConfig};
use crate::components::stage::FullStageToTileReader;
use crate::components::stage::TilingLayout;
use crate::components::stage::{self, StageMemory};
use crate::components::{InputIdent, MatmulPrecision};
use cubecl_core as cubecl;
use cubecl_core::prelude::barrier::BarrierLevel;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::VirtualTensor;
use cubecl_std::{CubeOption, CubeOptionExpand};

#[cube]
/// A strategy for fully and asynchronously loading a stage.
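///
/// A minimal sketch of the shape an implementation takes; `CooperativeStrategy`
/// and `CooperativeJob` are hypothetical names and the job construction is
/// elided:
///
/// ```ignore
/// #[derive(CubeType, Clone, Copy)]
/// pub struct CooperativeStrategy;
///
/// // A real strategy must also implement `LoadingValidation`.
/// #[cube]
/// impl AsyncFullLoadingStrategy for CooperativeStrategy {
///     type TilingLayout = ContiguousTilingLayout<RowMajorTilingOrder>;
///     type Job<MP: MatmulPrecision> = CooperativeJob<MP>;
///
///     fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
///         #[comptime] ident: InputIdent,
///         #[comptime] config: G,
///     ) -> CooperativeJob<MP> {
///         // Precompute per-task windows from `config.tiling_scheme()` here.
///     }
///
///     fn barrier_level() -> BarrierLevel {
///         // Cooperative copies synchronize at the cube level.
///         BarrierLevel::cube_coop(0u32)
///     }
/// }
/// ```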
pub trait AsyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;

    /// The [AsyncLoadingJob] for this strategy.
    type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;

    /// Returns the job with preliminary calculations done.
    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
        #[comptime] ident: InputIdent,
        #[comptime] config: G,
    ) -> Self::Job<MP>;

    /// The barrier level at which the copy mechanism works.
    fn barrier_level() -> BarrierLevel;
}

#[derive(CubeType)]
/// Loads the entire stage memory using asynchronous data movement operations.
///
/// A complete load is referred to as a `Job`, which is divided into `Tasks`;
/// each task represents a single data transfer for a specific unit.
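///
/// A rough sketch of the intended lifecycle, at a hypothetical call site where
/// `barrier` is the copy mechanism and `k_step` the stride per iteration:
///
/// ```ignore
/// let mut loader = AsyncFullLoader::<MP, CM, S, L, G>::new(
///     lhs, x_offset, y_offset, batch_offset, quantization, InputIdent::Lhs, config,
/// );
/// for _ in 0..num_loops {
///     AsyncFullLoader::fill_stage(&mut loader, &barrier, config);
///     let reader = AsyncFullLoader::reader(&loader);
///     // ... execute the stage matmul against `reader` ...
///     AsyncFullLoader::advance_view(&mut loader, k_step);
/// }
/// ```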
pub struct AsyncFullLoader<
    MP: MatmulPrecision,
    CM: CopyMechanism<MP::ES>,
    S: stage::StageConfig,
    L: AsyncFullLoadingStrategy,
    G: GlobalConfig,
> {
    tensor_reader: TensorReader<MP::EI>,
    stage_memory: StageMemory<MP::ES, L::TilingLayout>,
    loading_job: CubeOption<L::Job<MP>>,
    #[cube(comptime)]
    ident: InputIdent,
    #[cube(comptime)]
    _phantom: PhantomData<(S, L, CM, G)>,
}

#[cube]
impl<
    MP: MatmulPrecision,
    CM: CopyMechanism<MP::ES>,
    S: stage::StageConfig,
    L: AsyncFullLoadingStrategy,
    G: GlobalConfig,
> AsyncFullLoader<MP, CM, S, L, G>
{
    /// Create a new AsyncFullLoader.
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        batch_offset: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] ident: InputIdent,
        #[comptime] config: G,
    ) -> Self {
        // Quantized inputs are not supported by this loader yet.
        comptime! {
            if quantization.is_some() {
                todo!();
            }
        }

        let mut stage_memory =
            StageMemory::new::<G::StageConfig>(1u32, ident.as_ident(), config.stage_config());
        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);

        // The loading job can be precomputed once here, or rebuilt on every
        // `fill_stage` call when `precompute_job` is disabled.
        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some(L::new_job::<MP, G>(ident, config)),
            false => CubeOption::new_None(),
        };

        // If the stage extends past the tensor's bounds along the checked
        // dimension, zero it out first so out-of-bounds regions read as zeros
        // rather than stale memory.
        match ident {
            InputIdent::Lhs =>
            {
                #[allow(clippy::collapsible_if)]
                if config.check_row_bounds(ident) {
                    if tensor_reader.x_offset.read()
                        > tensor_reader.shape_x - config.tiling_scheme().elements_in_stage_m()
                    {
                        stage_memory.clear_all::<G>(ident, config);
                    }
                }
            }
            InputIdent::Rhs =>
            {
                #[allow(clippy::collapsible_if)]
                if config.check_col_bounds(ident) {
                    if tensor_reader.y_offset.read()
                        > tensor_reader.shape_y - config.tiling_scheme().elements_in_stage_n()
                    {
                        stage_memory.clear_all::<G>(ident, config);
                    }
                }
            }
        }

        AsyncFullLoader::<MP, CM, S, L, G> {
            tensor_reader,
            stage_memory,
            loading_job,
            ident,
            _phantom: PhantomData,
        }
    }

    /// Accomplish the entire job of filling the stage memory.
    pub fn fill_stage(this: &mut Self, mechanism: &CM, #[comptime] config: G) {
        // Reuse the precomputed job if one was stored at construction;
        // otherwise compute it now.
        let mut loading_job = match this.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<MP, G>(this.ident, config),
        };

        // Execute every task of the job; together they transfer the full stage.
        let len = L::Job::task_count(&loading_job);
        for task_id in 0..len {
            L::Job::<MP>::execute_task::<CM, G>(
                &mut loading_job,
                task_id,
                &this.tensor_reader,
                &mut this.stage_memory,
                mechanism,
                config,
            );
        }
    }

    /// Zero out the stage memory.
    pub fn clear_stage(this: &mut Self, #[comptime] config: G) {
        this.stage_memory.clear_all::<G>(this.ident, config)
    }

    /// Returns a reader over the loaded stage memory.
    pub fn reader(this: &Self) -> FullStageToTileReader<MP::ES, L::TilingLayout> {
        FullStageToTileReader::new(this.stage_memory, this.ident)
    }

    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.tensor_reader.update_view(k_offset, this.ident);
    }
}