cubecl_matmul/components/global/load/loader/
async_partial_loader.rs1use super::StageIdent;
2use crate::components::global::base::GlobalConfig;
3use crate::components::global::global_memory::TensorReader;
4use crate::components::global::load::{AsyncLoadingJob, LoadingValidation};
5use crate::components::global::multi_stage::double_buffering::DoubleBufferingGlobalConfig;
6use crate::components::global::{CopyMechanism, Quantization};
7use crate::components::stage::PartialStageToTileReader;
8use crate::components::stage::TilingLayout;
9use crate::components::stage::{self, StageMemory};
10use crate::components::{InputIdent, MatmulPrecision};
11use core::marker::PhantomData;
12use cubecl_core as cubecl;
13use cubecl_core::prelude::barrier::BarrierLevel;
14use cubecl_core::prelude::*;
15use cubecl_std::tensor::r#virtual::VirtualTensor;
16use cubecl_std::{CubeOption, CubeOptionExpand};
17
18#[cube]
19pub trait AsyncPartialLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
21 type TilingLayout: TilingLayout;
23
24 type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;
26
27 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
29 #[comptime] buffer_index: u32,
30 #[comptime] ident: InputIdent,
31 #[comptime] config: G,
32 ) -> Self::Job<MP>;
33
34 fn barrier_level() -> BarrierLevel;
36}
37
38#[derive(CubeType)]
39pub struct AsyncBufferLoader<
44 MP: MatmulPrecision,
45 S: stage::StageConfig,
46 CM: CopyMechanism<MP::ES>,
47 L: AsyncPartialLoadingStrategy,
48> {
49 tensor_reader: TensorReader<MP::EI>,
50 stage_memory: StageMemory<MP::ES, L::TilingLayout>,
51 loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
52 #[cube(comptime)]
53 input_ident: InputIdent,
54 #[cube(comptime)]
55 _phantom: PhantomData<(S, CM)>,
56}
57
58#[cube]
59impl<
60 MP: MatmulPrecision,
61 S: stage::StageConfig,
62 CM: CopyMechanism<MP::ES>,
63 L: AsyncPartialLoadingStrategy,
64> AsyncBufferLoader<MP, S, CM, L>
65{
66 pub fn new(
68 tensor: VirtualTensor<MP::EI>,
69 x_offset: u32,
70 y_offset: u32,
71 batch_offset: u32,
72 quantization: CubeOption<Quantization<MP>>,
73 #[comptime] input_ident: InputIdent,
74 #[comptime] config: DoubleBufferingGlobalConfig<S>,
75 ) -> Self {
76 let stage_memory =
77 StageMemory::new::<S>(2u32, input_ident.as_ident(), config.stage_config());
78 let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
79
80 comptime! {
81 if quantization.is_some() {
82 todo!();
83 }
84 }
85
86 let loading_job = match config.precompute_job() {
87 true => CubeOption::new_Some((
88 L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(0u32, input_ident, config),
89 L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(1u32, input_ident, config),
90 )),
91 false => CubeOption::new_None(),
92 };
93
94 AsyncBufferLoader::<MP, S, CM, L> {
95 tensor_reader,
96 stage_memory,
97 loading_job,
98 input_ident,
99 _phantom: PhantomData::<(S, CM)>,
100 }
101 }
102
103 pub fn reader(
105 this: &Self,
106 #[comptime] stage_ident: StageIdent,
107 ) -> PartialStageToTileReader<MP::ES, L::TilingLayout> {
108 PartialStageToTileReader::new(this.stage_memory, stage_ident, this.input_ident)
109 }
110
111 pub fn advance_view(this: &mut Self, k_offset: u32) {
113 this.tensor_reader.update_view(k_offset, this.input_ident);
114 }
115
116 pub fn fill_stage(
118 this: &mut Self,
119 mechanism: &CM,
120 #[comptime] stage_ident: StageIdent,
121 #[comptime] config: DoubleBufferingGlobalConfig<S>,
122 ) {
123 let mut loading_job = match this.loading_job {
124 CubeOption::Some(job) => match stage_ident {
125 StageIdent::A => job.0,
126 StageIdent::B => job.1,
127 },
128 CubeOption::None => match stage_ident {
129 StageIdent::A => {
130 L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(0u32, this.input_ident, config)
131 }
132 StageIdent::B => {
133 L::new_job::<MP, DoubleBufferingGlobalConfig<S>>(1u32, this.input_ident, config)
134 }
135 },
136 };
137
138 let len = L::Job::task_count(&loading_job);
139 for task_id in 0..len {
140 L::Job::<MP>::execute_task::<CM, DoubleBufferingGlobalConfig<S>>(
141 &mut loading_job,
142 task_id,
143 &this.tensor_reader,
144 &mut this.stage_memory,
145 mechanism,
146 config,
147 );
148 }
149 }
150
151 pub fn clear_stage(
153 this: &mut Self,
154 #[comptime] stage_ident: StageIdent,
155 #[comptime] config: DoubleBufferingGlobalConfig<S>,
156 ) {
157 this.stage_memory
158 .clear_stage::<DoubleBufferingGlobalConfig<S>>(stage_ident, this.input_ident, config)
159 }
160}