cubecl_linalg/matmul/components/global/load/loader/
sync_full_loader.rs1use std::marker::PhantomData;
2
3use crate::matmul::components::global::Quantization;
4use crate::matmul::components::global::load::LoadingJob;
5use crate::matmul::components::global::tensor_view::TensorReader;
6use crate::matmul::components::global::{self};
7use crate::matmul::components::global::{GlobalConfig, LoadingValidation, single_stage};
8use crate::matmul::components::stage::FullReader;
9use crate::matmul::components::stage::TilingLayout;
10use crate::matmul::components::stage::{self, Stage};
11use crate::matmul::components::{InputIdent, MatmulPrecision};
12use cubecl_core as cubecl;
13use cubecl_core::prelude::*;
14use cubecl_std::tensor::r#virtual::VirtualTensor;
15use cubecl_std::{CubeOption, CubeOptionExpand};
16
17#[cube]
18pub trait SyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
20 type TilingLayout: TilingLayout;
22
23 type Job<MP: MatmulPrecision>: LoadingJob<MP, Self::TilingLayout>;
25
26 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
28 #[comptime] ident: InputIdent,
29 #[comptime] config: G,
30 ) -> Self::Job<MP>;
31}
32
33#[derive(CubeType)]
34pub struct SyncFullLoader<MP: MatmulPrecision, S: stage::StageConfig, L: SyncFullLoadingStrategy> {
35 tensor_reader: TensorReader<MP::EI>,
36 stage: Stage<MP::ES, L::TilingLayout>,
37 loading_job: CubeOption<L::Job<MP>>,
38 quantization: CubeOption<Quantization<MP>>,
39 #[cube(comptime)]
40 input_ident: InputIdent,
41 #[cube(comptime)]
42 _phantom: PhantomData<(S, L)>,
43}
44
45#[cube]
46impl<MP: MatmulPrecision, S: stage::StageConfig, L: SyncFullLoadingStrategy>
47 SyncFullLoader<MP, S, L>
48{
49 pub fn new<G: global::GlobalConfig>(
50 tensor: VirtualTensor<MP::EI>,
51 x_offset: u32,
52 y_offset: u32,
53 batch_offset: u32,
54 quantization: CubeOption<Quantization<MP>>,
55 #[comptime] input_ident: InputIdent,
56 #[comptime] config: G,
57 ) -> Self {
58 let stage = Stage::new::<G::SmmConfig>(input_ident.as_ident(), config.to_smm_config());
59 let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
60
61 let loading_job = match config.precompute_job() {
62 true => CubeOption::new_Some(L::new_job::<MP, G>(input_ident, config)),
63 false => CubeOption::new_None(),
64 };
65
66 SyncFullLoader::<MP, S, L> {
67 tensor_reader,
68 stage,
69 loading_job,
70 quantization,
71 input_ident,
72 _phantom: PhantomData::<(S, L)>,
73 }
74 }
75
76 pub fn reader(this: &Self) -> FullReader<MP::ES, L::TilingLayout> {
77 FullReader::new(this.stage, this.input_ident)
78 }
79
80 pub fn advance_view(this: &mut Self, k_offset: u32) {
81 this.tensor_reader.update_view(k_offset, this.input_ident);
82 }
83
84 pub fn fill_stage(this: &mut Self, #[comptime] config: single_stage::Config<S>) {
85 let mut loading_job = match this.loading_job {
86 CubeOption::Some(loading_job) => loading_job,
87 CubeOption::None => L::new_job::<MP, single_stage::Config<S>>(this.input_ident, config),
88 };
89
90 let len = L::Job::task_count(&loading_job);
91 for task_id in 0..len {
92 L::Job::<MP>::execute_task::<single_stage::Config<S>>(
93 &mut loading_job,
94 task_id,
95 &this.tensor_reader,
96 &mut this.stage,
97 &this.quantization,
98 config,
99 );
100 }
101 }
102}