// cubecl_linalg/matmul/components/global/load/loader/async_full_loader.rs

1use std::marker::PhantomData;
2
3use crate::matmul::components::global::load::AsyncLoadingJob;
4use crate::matmul::components::global::tensor_view::TensorReader;
5use crate::matmul::components::global::{CopyMechanism, GlobalConfig, LoadingValidation};
6use crate::matmul::components::global::{Quantization, single_stage};
7use crate::matmul::components::stage::FullReader;
8use crate::matmul::components::stage::TilingLayout;
9use crate::matmul::components::stage::{self, Stage};
10use crate::matmul::components::{Ident, InputIdent, MatmulPrecision, global};
11use cubecl_core as cubecl;
12use cubecl_core::prelude::barrier::BarrierLevel;
13use cubecl_core::prelude::*;
14use cubecl_std::tensor::r#virtual::VirtualTensor;
15use cubecl_std::{CubeOption, CubeOptionExpand};
16
17#[cube]
18/// A strategy for fully and asynchronously loading a stage.
19pub trait AsyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
20    /// The layout describing how data is tiled across the stage.
21    type TilingLayout: TilingLayout;
22
23    /// The [LoadingJob] for this strategy.
24    type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;
25
26    /// Returns the job with preliminary calculations done.
27    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
28        #[comptime] ident: InputIdent,
29        #[comptime] config: G,
30    ) -> Self::Job<MP>;
31
32    /// The barrier level at which the copy mechanism works
33    fn barrier_level() -> BarrierLevel;
34}
35
36#[derive(CubeType)]
37pub struct AsyncLoader<
38    MP: MatmulPrecision,
39    CM: CopyMechanism<MP::ES>,
40    S: stage::StageConfig,
41    L: AsyncFullLoadingStrategy,
42> {
43    tensor_reader: TensorReader<MP::EI>,
44    stage: Stage<MP::ES, L::TilingLayout>,
45    loading_job: CubeOption<L::Job<MP>>,
46    #[cube(comptime)]
47    ident: InputIdent,
48    #[cube(comptime)]
49    _phantom: PhantomData<(S, L, CM)>,
50}
51
52#[cube]
53impl<
54    MP: MatmulPrecision,
55    CM: CopyMechanism<MP::ES>,
56    S: stage::StageConfig,
57    L: AsyncFullLoadingStrategy,
58> AsyncLoader<MP, CM, S, L>
59{
60    pub fn new<G: global::GlobalConfig>(
61        tensor: VirtualTensor<MP::EI>,
62        x_offset: u32,
63        y_offset: u32,
64        batch_offset: u32,
65        quantization: CubeOption<Quantization<MP>>,
66        #[comptime] ident: InputIdent,
67        #[comptime] config: G,
68    ) -> Self {
69        comptime! {
70            if quantization.is_some() {
71                todo!();
72            }
73        }
74
75        let mut stage = Stage::new::<G::SmmConfig>(ident.as_ident(), config.to_smm_config());
76
77        let loading_job = match config.precompute_job() {
78            true => CubeOption::new_Some(L::new_job::<MP, G>(ident, config)),
79            false => CubeOption::new_None(),
80        };
81
82        match ident {
83            InputIdent::Lhs =>
84            {
85                #[allow(clippy::collapsible_if)]
86                if config.check_row_bounds(ident) {
87                    if x_offset
88                        > tensor.shape(tensor.rank() - 2)
89                            - config.tiling_dimensions(Ident::Lhs).total_row()
90                    {
91                        stage.clear::<G::SmmConfig>(ident, config.to_smm_config());
92                    }
93                }
94            }
95            InputIdent::Rhs =>
96            {
97                #[allow(clippy::collapsible_if)]
98                if config.check_col_bounds(ident) {
99                    if y_offset
100                        > tensor.shape(tensor.rank() - 1)
101                            - config.tiling_dimensions(Ident::Rhs).total_col()
102                    {
103                        stage.clear::<G::SmmConfig>(ident, config.to_smm_config());
104                    }
105                }
106            }
107        }
108
109        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
110
111        AsyncLoader::<MP, CM, S, L> {
112            tensor_reader,
113            stage,
114            loading_job,
115            ident,
116            _phantom: PhantomData::<(S, L, CM)>,
117        }
118    }
119
120    pub fn fill_stage(
121        this: &mut Self,
122        mechanism: &CM,
123        #[comptime] config: single_stage::Config<S>,
124    ) {
125        let mut loading_job = match this.loading_job {
126            CubeOption::Some(loading_job) => loading_job,
127            CubeOption::None => L::new_job::<MP, single_stage::Config<S>>(this.ident, config),
128        };
129
130        let len = L::Job::task_count(&loading_job);
131        for task_id in 0..len {
132            L::Job::<MP>::execute_task::<CM, single_stage::Config<S>>(
133                &mut loading_job,
134                task_id,
135                &this.tensor_reader,
136                &mut this.stage,
137                mechanism,
138                config,
139            );
140        }
141    }
142
143    pub fn clear_stage(this: &mut Self, #[comptime] config: single_stage::Config<S>) {
144        this.stage.clear::<S>(this.ident, config.to_smm_config())
145    }
146
147    pub fn reader(this: &Self) -> FullReader<MP::ES, L::TilingLayout> {
148        FullReader::new(this.stage, this.ident)
149    }
150
151    pub fn advance_view(this: &mut Self, k_offset: u32) {
152        this.tensor_reader.update_view(k_offset, this.ident);
153    }
154}