cubecl_matmul/components/global/load/loader/
async_partial_loader.rs

1use super::StageIdent;
2use crate::components::global::base::GlobalConfig;
3use crate::components::global::global_memory::TensorReader;
4use crate::components::global::load::{AsyncLoadingJob, LoadingValidation};
5use crate::components::global::multi_stage::double_buffering::DoubleBufferingGlobalConfig;
6use crate::components::global::{CopyMechanism, Quantization};
7use crate::components::stage::PartialStageToTileReader;
8use crate::components::stage::TilingLayout;
9use crate::components::stage::{self, StageMemory};
10use crate::components::{InputIdent, MatmulPrecision};
11use core::marker::PhantomData;
12use cubecl_core as cubecl;
13use cubecl_core::prelude::barrier::BarrierLevel;
14use cubecl_core::prelude::*;
15use cubecl_std::tensor::r#virtual::VirtualTensor;
16use cubecl_std::{CubeOption, CubeOptionExpand};
17
#[cube]
/// A strategy for asynchronously loading a stage of stage memory.
///
/// Implementors define *how* the data for one buffer of a (double-buffered)
/// stage is transferred from global memory, via the job/task mechanism of
/// [AsyncLoadingJob].
pub trait AsyncPartialLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
    /// The layout describing how data is tiled across the stage.
    type TilingLayout: TilingLayout;

    /// The [LoadingJob] for this strategy.
    ///
    /// A job bundles the precomputed per-unit transfer tasks for one buffer.
    type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;

    /// Returns the job with preliminary calculations done.
    ///
    /// - `buffer_index`: which buffer of the double-buffered stage this job
    ///   fills (0 or 1 — see the call sites in `AsyncBufferLoader`).
    /// - `ident`: whether this loader feeds the Lhs or Rhs input.
    /// - `config`: the global matmul config used for the precomputation.
    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
        #[comptime] buffer_index: u32,
        #[comptime] ident: InputIdent,
        #[comptime] config: G,
    ) -> Self::Job<MP>;

    /// The barrier level at which the copy mechanism works.
    fn barrier_level() -> BarrierLevel;
}
37
#[derive(CubeType)]
/// Loads a stage from stage memory using asynchronous data movement operations.
///
/// A complete load is referred to as a `Job`, which is divided into `Tasks`—
/// each Task represents a single data transfer for a specific unit.
pub struct AsyncBufferLoader<
    MP: MatmulPrecision,
    S: stage::StageConfig,
    CM: CopyMechanism<MP::ES>,
    L: AsyncPartialLoadingStrategy,
> {
    // View over the global-memory tensor this loader reads from.
    tensor_reader: TensorReader<MP::EI>,
    // Destination stage memory (created with 2 stages for double buffering — see `new`).
    stage_memory: StageMemory<MP::ES, L::TilingLayout>,
    // Precomputed jobs for buffer 0 and buffer 1 respectively; `None` when
    // `precompute_job` is disabled and jobs are built on demand in `fill_stage`.
    loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
    #[cube(comptime)]
    // Whether this loader feeds the Lhs or Rhs input of the matmul.
    input_ident: InputIdent,
    #[cube(comptime)]
    // Carries the `S` (stage config) and `CM` (copy mechanism) type parameters
    // without storing values of those types.
    _phantom: PhantomData<(S, CM)>,
}
57
#[cube]
impl<
    MP: MatmulPrecision,
    S: stage::StageConfig,
    CM: CopyMechanism<MP::ES>,
    L: AsyncPartialLoadingStrategy,
> AsyncBufferLoader<MP, S, CM, L>
{
    /// Create a new [`AsyncBufferLoader`].
    ///
    /// `x_offset` / `y_offset` / `batch_offset` position the reader's view in
    /// global memory; `input_ident` selects Lhs vs Rhs behavior.
    ///
    /// # Panics (at comptime / kernel expansion)
    /// Quantization is not yet supported: a `Some` `quantization` hits `todo!()`.
    pub fn new(
        tensor: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        batch_offset: u32,
        quantization: CubeOption<Quantization<MP>>,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: DoubleBufferingGlobalConfig<S>,
    ) -> Self {
        // Two stages: one per buffer of the double-buffering scheme.
        let stage_memory =
            StageMemory::new::<S>(2u32, input_ident.as_ident(), config.stage_config());
        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);

        // Quantized inputs are not implemented for this loader yet; fail at
        // kernel-expansion time rather than silently ignoring the option.
        comptime! {
            if quantization.is_some() {
                todo!();
            }
        }

        // Optionally precompute both buffers' jobs up front (trades comptime
        // work / register pressure for avoiding re-derivation in `fill_stage`).
        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some((
                L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(0u32, input_ident, config),
                L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(1u32, input_ident, config),
            )),
            false => CubeOption::new_None(),
        };

        AsyncBufferLoader::<MP, S, CM, L> {
            tensor_reader,
            stage_memory,
            loading_job,
            input_ident,
            _phantom: PhantomData::<(S, CM)>,
        }
    }

    /// Give a reader to the loaded stage memory.
    ///
    /// `stage_ident` selects which of the two buffers the returned reader views.
    pub fn reader(
        this: &Self,
        #[comptime] stage_ident: StageIdent,
    ) -> PartialStageToTileReader<MP::ES, L::TilingLayout> {
        PartialStageToTileReader::new(this.stage_memory, stage_ident, this.input_ident)
    }

    /// Advance the view over global memory along the k dimension by a specified offset, `k_offset`.
    pub fn advance_view(this: &mut Self, k_offset: u32) {
        this.tensor_reader.update_view(k_offset, this.input_ident);
    }

    /// Accomplish the entire job of filling the stage memory.
    ///
    /// Selects (or builds, when jobs were not precomputed) the job for the
    /// requested buffer, then executes each of its tasks through `mechanism`.
    pub fn fill_stage(
        this: &mut Self,
        mechanism: &CM,
        #[comptime] stage_ident: StageIdent,
        #[comptime] config: DoubleBufferingGlobalConfig<S>,
    ) {
        // Stage A maps to buffer 0, stage B to buffer 1 — mirroring the
        // buffer indices used when precomputing jobs in `new`.
        let mut loading_job = match this.loading_job {
            CubeOption::Some(job) => match stage_ident {
                StageIdent::A => job.0,
                StageIdent::B => job.1,
            },
            CubeOption::None => match stage_ident {
                StageIdent::A => {
                    L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(0u32, this.input_ident, config)
                }
                StageIdent::B => {
                    L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(1u32, this.input_ident, config)
                }
            },
        };

        // Execute every transfer task of the job; each task is one async data
        // movement performed by the copy mechanism.
        let len = L::Job::task_count(&loading_job);
        for task_id in 0..len {
            L::Job::<MP>::execute_task::<CM, DoubleBufferingGlobalConfig<S>>(
                &mut loading_job,
                task_id,
                &this.tensor_reader,
                &mut this.stage_memory,
                mechanism,
                config,
            );
        }
    }

    /// Zero out the stage memory for the given buffer.
    pub fn clear_stage(
        this: &mut Self,
        #[comptime] stage_ident: StageIdent,
        #[comptime] config: DoubleBufferingGlobalConfig<S>,
    ) {
        this.stage_memory
            .clear_stage::<DoubleBufferingGlobalConfig<S>>(stage_ident, this.input_ident, config)
    }
}