cubecl_linalg/matmul/components/global/load/loader/
async_buffer_loader.rs

1use super::BufferId;
2use crate::matmul::components::global::base::GlobalConfig;
3use crate::matmul::components::global::load::AsyncLoadingJob;
4use crate::matmul::components::global::tensor_view::TensorReader;
5use crate::matmul::components::global::{
6    CommonGlobalConfig, CopyMechanism, LoadingValidation, Quantization,
7};
8use crate::matmul::components::stage::BufferReader;
9use crate::matmul::components::stage::TilingLayout;
10use crate::matmul::components::stage::{self, Stage};
11use crate::matmul::components::{InputIdent, MatmulPrecision};
12use core::marker::PhantomData;
13use cubecl_core as cubecl;
14use cubecl_core::prelude::barrier::BarrierLevel;
15use cubecl_core::prelude::*;
16use cubecl_std::tensor::r#virtual::VirtualTensor;
17use cubecl_std::{CubeOption, CubeOptionExpand};
18
19#[cube]
20/// A strategy for asynchronously loading a buffer (partial stage), either eagerly or as a deferred job.
21pub trait AsyncBufferLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
22    /// The layout describing how data is tiled across the stage.
23    type TilingLayout: TilingLayout;
24
25    /// The [LoadingJob] for this strategy.
26    type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;
27
28    /// Returns the job with preliminary calculations done.
29    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
30        #[comptime] buffer_index: u32,
31        #[comptime] ident: InputIdent,
32        #[comptime] config: G,
33    ) -> Self::Job<MP>;
34
35    /// The barrier level at which the copy mechanism works
36    fn barrier_level() -> BarrierLevel;
37}
38
/// Loader that fills the buffers (`BufferId::A` / `BufferId::B`) of a [`Stage`] from a
/// tensor read through a [`TensorReader`], using copy mechanism `CM` and loading
/// strategy `L`.
#[derive(CubeType)]
pub struct AsyncBufferLoader<
    MP: MatmulPrecision,
    S: stage::StageConfig,
    CM: CopyMechanism<MP::ES>,
    L: AsyncBufferLoadingStrategy,
> {
    /// Reader over the source tensor; its view is moved forward by `advance_view`.
    tensor_reader: TensorReader<MP::EI>,
    /// Destination stage; its buffers are filled by `fill_stage` and cleared by `clear_stage`.
    stage: Stage<MP::ES, L::TilingLayout>,
    /// `Some((job for buffer index 0 / BufferId::A, job for buffer index 1 / BufferId::B))`
    /// when `config.precompute_job()` was true at construction; `None` means the job is
    /// rebuilt on every `fill_stage` call.
    loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
    /// Which matmul input this loader handles (comptime).
    #[cube(comptime)]
    input_ident: InputIdent,
    /// Carries the otherwise-unused `S` and `CM` type parameters (comptime only).
    #[cube(comptime)]
    _phantom: PhantomData<(S, CM)>,
}
54
55#[cube]
56impl<
57    MP: MatmulPrecision,
58    S: stage::StageConfig,
59    CM: CopyMechanism<MP::ES>,
60    L: AsyncBufferLoadingStrategy,
61> AsyncBufferLoader<MP, S, CM, L>
62{
63    pub fn new(
64        tensor: VirtualTensor<MP::EI>,
65        x_offset: u32,
66        y_offset: u32,
67        batch_offset: u32,
68        quantization: CubeOption<Quantization<MP>>,
69        #[comptime] input_ident: InputIdent,
70        #[comptime] config: CommonGlobalConfig<S>,
71    ) -> Self {
72        comptime! {
73            if quantization.is_some() {
74                todo!();
75            }
76        }
77
78        let stage = Stage::new::<S>(input_ident.as_ident(), config.to_smm_config());
79        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
80        let loading_job = match config.precompute_job() {
81            true => CubeOption::new_Some((
82                L::new_job::<MP, CommonGlobalConfig<S>>(0u32, input_ident, config),
83                L::new_job::<MP, CommonGlobalConfig<S>>(1u32, input_ident, config),
84            )),
85            false => CubeOption::new_None(),
86        };
87
88        AsyncBufferLoader::<MP, S, CM, L> {
89            tensor_reader,
90            stage,
91            loading_job,
92            input_ident,
93            _phantom: PhantomData::<(S, CM)>,
94        }
95    }
96
97    pub fn reader(
98        this: &Self,
99        #[comptime] buffer_id: BufferId,
100    ) -> BufferReader<MP::ES, L::TilingLayout> {
101        BufferReader::new(this.stage, buffer_id, this.input_ident)
102    }
103
104    pub fn advance_view(this: &mut Self, k_offset: u32) {
105        this.tensor_reader.update_view(k_offset, this.input_ident);
106    }
107
108    pub fn fill_stage(
109        this: &mut Self,
110        mechanism: &CM,
111        #[comptime] buffer_id: BufferId,
112        #[comptime] config: CommonGlobalConfig<S>,
113    ) {
114        let mut loading_job = match this.loading_job {
115            CubeOption::Some(job) => match buffer_id {
116                BufferId::A => job.0,
117                BufferId::B => job.1,
118            },
119            CubeOption::None => match buffer_id {
120                BufferId::A => {
121                    L::new_job::<MP, CommonGlobalConfig<S>>(0u32, this.input_ident, config)
122                }
123                BufferId::B => {
124                    L::new_job::<MP, CommonGlobalConfig<S>>(1u32, this.input_ident, config)
125                }
126            },
127        };
128
129        let len = L::Job::task_count(&loading_job);
130        for task_id in 0..len {
131            L::Job::<MP>::execute_task::<CM, CommonGlobalConfig<S>>(
132                &mut loading_job,
133                task_id,
134                &this.tensor_reader,
135                &mut this.stage,
136                mechanism,
137                config,
138            );
139        }
140    }
141
142    pub fn clear_stage(
143        this: &mut Self,
144        #[comptime] buffer_id: BufferId,
145        #[comptime] config: CommonGlobalConfig<S>,
146    ) {
147        this.stage
148            .clear_buffer::<S>(buffer_id, this.input_ident, config.to_smm_config())
149    }
150}