cubecl_linalg/matmul/components/global/load/loader/
async_buffer_loader.rs1use super::BufferId;
2use crate::matmul::components::global::base::GlobalConfig;
3use crate::matmul::components::global::load::AsyncLoadingJob;
4use crate::matmul::components::global::tensor_view::TensorReader;
5use crate::matmul::components::global::{
6 CommonGlobalConfig, CopyMechanism, LoadingValidation, Quantization,
7};
8use crate::matmul::components::stage::BufferReader;
9use crate::matmul::components::stage::TilingLayout;
10use crate::matmul::components::stage::{self, Stage};
11use crate::matmul::components::{InputIdent, MatmulPrecision};
12use core::marker::PhantomData;
13use cubecl_core as cubecl;
14use cubecl_core::prelude::barrier::BarrierLevel;
15use cubecl_core::prelude::*;
16use cubecl_std::tensor::r#virtual::VirtualTensor;
17use cubecl_std::{CubeOption, CubeOptionExpand};
18
19#[cube]
20pub trait AsyncBufferLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
22 type TilingLayout: TilingLayout;
24
25 type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;
27
28 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
30 #[comptime] buffer_index: u32,
31 #[comptime] ident: InputIdent,
32 #[comptime] config: G,
33 ) -> Self::Job<MP>;
34
35 fn barrier_level() -> BarrierLevel;
37}
38
39#[derive(CubeType)]
40pub struct AsyncBufferLoader<
41 MP: MatmulPrecision,
42 S: stage::StageConfig,
43 CM: CopyMechanism<MP::ES>,
44 L: AsyncBufferLoadingStrategy,
45> {
46 tensor_reader: TensorReader<MP::EI>,
47 stage: Stage<MP::ES, L::TilingLayout>,
48 loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
49 #[cube(comptime)]
50 input_ident: InputIdent,
51 #[cube(comptime)]
52 _phantom: PhantomData<(S, CM)>,
53}
54
55#[cube]
56impl<
57 MP: MatmulPrecision,
58 S: stage::StageConfig,
59 CM: CopyMechanism<MP::ES>,
60 L: AsyncBufferLoadingStrategy,
61> AsyncBufferLoader<MP, S, CM, L>
62{
63 pub fn new(
64 tensor: VirtualTensor<MP::EI>,
65 x_offset: u32,
66 y_offset: u32,
67 batch_offset: u32,
68 quantization: CubeOption<Quantization<MP>>,
69 #[comptime] input_ident: InputIdent,
70 #[comptime] config: CommonGlobalConfig<S>,
71 ) -> Self {
72 comptime! {
73 if quantization.is_some() {
74 todo!();
75 }
76 }
77
78 let stage = Stage::new::<S>(input_ident.as_ident(), config.to_smm_config());
79 let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
80 let loading_job = match config.precompute_job() {
81 true => CubeOption::new_Some((
82 L::new_job::<MP, CommonGlobalConfig<S>>(0u32, input_ident, config),
83 L::new_job::<MP, CommonGlobalConfig<S>>(1u32, input_ident, config),
84 )),
85 false => CubeOption::new_None(),
86 };
87
88 AsyncBufferLoader::<MP, S, CM, L> {
89 tensor_reader,
90 stage,
91 loading_job,
92 input_ident,
93 _phantom: PhantomData::<(S, CM)>,
94 }
95 }
96
97 pub fn reader(
98 this: &Self,
99 #[comptime] buffer_id: BufferId,
100 ) -> BufferReader<MP::ES, L::TilingLayout> {
101 BufferReader::new(this.stage, buffer_id, this.input_ident)
102 }
103
104 pub fn advance_view(this: &mut Self, k_offset: u32) {
105 this.tensor_reader.update_view(k_offset, this.input_ident);
106 }
107
108 pub fn fill_stage(
109 this: &mut Self,
110 mechanism: &CM,
111 #[comptime] buffer_id: BufferId,
112 #[comptime] config: CommonGlobalConfig<S>,
113 ) {
114 let mut loading_job = match this.loading_job {
115 CubeOption::Some(job) => match buffer_id {
116 BufferId::A => job.0,
117 BufferId::B => job.1,
118 },
119 CubeOption::None => match buffer_id {
120 BufferId::A => {
121 L::new_job::<MP, CommonGlobalConfig<S>>(0u32, this.input_ident, config)
122 }
123 BufferId::B => {
124 L::new_job::<MP, CommonGlobalConfig<S>>(1u32, this.input_ident, config)
125 }
126 },
127 };
128
129 let len = L::Job::task_count(&loading_job);
130 for task_id in 0..len {
131 L::Job::<MP>::execute_task::<CM, CommonGlobalConfig<S>>(
132 &mut loading_job,
133 task_id,
134 &this.tensor_reader,
135 &mut this.stage,
136 mechanism,
137 config,
138 );
139 }
140 }
141
142 pub fn clear_stage(
143 this: &mut Self,
144 #[comptime] buffer_id: BufferId,
145 #[comptime] config: CommonGlobalConfig<S>,
146 ) {
147 this.stage
148 .clear_buffer::<S>(buffer_id, this.input_ident, config.to_smm_config())
149 }
150}