// cubecl_linalg/matmul/components/global/load/loader/sync_full_loader.rs

1use std::marker::PhantomData;
2
3use crate::matmul::components::global::Quantization;
4use crate::matmul::components::global::load::LoadingJob;
5use crate::matmul::components::global::tensor_view::TensorReader;
6use crate::matmul::components::global::{self};
7use crate::matmul::components::global::{GlobalConfig, LoadingValidation, single_stage};
8use crate::matmul::components::stage::FullReader;
9use crate::matmul::components::stage::TilingLayout;
10use crate::matmul::components::stage::{self, Stage};
11use crate::matmul::components::{InputIdent, MatmulPrecision};
12use cubecl_core as cubecl;
13use cubecl_core::prelude::*;
14use cubecl_std::tensor::r#virtual::VirtualTensor;
15use cubecl_std::{CubeOption, CubeOptionExpand};
16
17#[cube]
18/// A strategy for fully and synchronously loading a stage.
19pub trait SyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
20    /// The layout describing how data is tiled across the stage.
21    type TilingLayout: TilingLayout;
22
23    /// The [LoadingJob] for this strategy.
24    type Job<MP: MatmulPrecision>: LoadingJob<MP, Self::TilingLayout>;
25
26    /// Returns the job with preliminary calculations done.
27    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
28        #[comptime] ident: InputIdent,
29        #[comptime] config: G,
30    ) -> Self::Job<MP>;
31}
32
33#[derive(CubeType)]
34pub struct SyncFullLoader<MP: MatmulPrecision, S: stage::StageConfig, L: SyncFullLoadingStrategy> {
35    tensor_reader: TensorReader<MP::EI>,
36    stage: Stage<MP::ES, L::TilingLayout>,
37    loading_job: CubeOption<L::Job<MP>>,
38    quantization: CubeOption<Quantization<MP>>,
39    #[cube(comptime)]
40    input_ident: InputIdent,
41    #[cube(comptime)]
42    _phantom: PhantomData<(S, L)>,
43}
44
45#[cube]
46impl<MP: MatmulPrecision, S: stage::StageConfig, L: SyncFullLoadingStrategy>
47    SyncFullLoader<MP, S, L>
48{
49    pub fn new<G: global::GlobalConfig>(
50        tensor: VirtualTensor<MP::EI>,
51        x_offset: u32,
52        y_offset: u32,
53        batch_offset: u32,
54        quantization: CubeOption<Quantization<MP>>,
55        #[comptime] input_ident: InputIdent,
56        #[comptime] config: G,
57    ) -> Self {
58        let stage = Stage::new::<G::SmmConfig>(input_ident.as_ident(), config.to_smm_config());
59        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
60
61        let loading_job = match config.precompute_job() {
62            true => CubeOption::new_Some(L::new_job::<MP, G>(input_ident, config)),
63            false => CubeOption::new_None(),
64        };
65
66        SyncFullLoader::<MP, S, L> {
67            tensor_reader,
68            stage,
69            loading_job,
70            quantization,
71            input_ident,
72            _phantom: PhantomData::<(S, L)>,
73        }
74    }
75
76    pub fn reader(this: &Self) -> FullReader<MP::ES, L::TilingLayout> {
77        FullReader::new(this.stage, this.input_ident)
78    }
79
80    pub fn advance_view(this: &mut Self, k_offset: u32) {
81        this.tensor_reader.update_view(k_offset, this.input_ident);
82    }
83
84    pub fn fill_stage(this: &mut Self, #[comptime] config: single_stage::Config<S>) {
85        let mut loading_job = match this.loading_job {
86            CubeOption::Some(loading_job) => loading_job,
87            CubeOption::None => L::new_job::<MP, single_stage::Config<S>>(this.input_ident, config),
88        };
89
90        let len = L::Job::task_count(&loading_job);
91        for task_id in 0..len {
92            L::Job::<MP>::execute_task::<single_stage::Config<S>>(
93                &mut loading_job,
94                task_id,
95                &this.tensor_reader,
96                &mut this.stage,
97                &this.quantization,
98                config,
99            );
100        }
101    }
102}