cubecl_matmul/components/global/load/loader/
async_full_loader.rs1use std::marker::PhantomData;
2
3use crate::components::global::Quantization;
4use crate::components::global::global_memory::TensorReader;
5use crate::components::global::load::{AsyncLoadingJob, LoadingValidation};
6use crate::components::global::{CopyMechanism, GlobalConfig};
7use crate::components::stage::FullStageToTileReader;
8use crate::components::stage::TilingLayout;
9use crate::components::stage::{self, StageMemory};
10use crate::components::{InputIdent, MatmulPrecision};
11use cubecl_core as cubecl;
12use cubecl_core::prelude::barrier::BarrierLevel;
13use cubecl_core::prelude::*;
14use cubecl_std::tensor::r#virtual::VirtualTensor;
15use cubecl_std::{CubeOption, CubeOptionExpand};
16
17#[cube]
18pub trait AsyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
20 type TilingLayout: TilingLayout;
22
23 type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;
25
26 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
28 #[comptime] ident: InputIdent,
29 #[comptime] config: G,
30 ) -> Self::Job<MP>;
31
32 fn barrier_level() -> BarrierLevel;
34}
35
36#[derive(CubeType)]
37pub struct AsyncFullLoader<
42 MP: MatmulPrecision,
43 CM: CopyMechanism<MP::ES>,
44 S: stage::StageConfig,
45 L: AsyncFullLoadingStrategy,
46 G: GlobalConfig,
47> {
48 tensor_reader: TensorReader<MP::EI>,
49 stage_memory: StageMemory<MP::ES, L::TilingLayout>,
50 loading_job: CubeOption<L::Job<MP>>,
51 #[cube(comptime)]
52 ident: InputIdent,
53 #[cube(comptime)]
54 _phantom: PhantomData<(S, L, CM, G)>,
55}
56
57#[cube]
58impl<
59 MP: MatmulPrecision,
60 CM: CopyMechanism<MP::ES>,
61 S: stage::StageConfig,
62 L: AsyncFullLoadingStrategy,
63 G: GlobalConfig,
64> AsyncFullLoader<MP, CM, S, L, G>
65{
66 pub fn new(
68 tensor: VirtualTensor<MP::EI>,
69 x_offset: u32,
70 y_offset: u32,
71 batch_offset: u32,
72 quantization: CubeOption<Quantization<MP>>,
73 #[comptime] ident: InputIdent,
74 #[comptime] config: G,
75 ) -> Self {
76 comptime! {
77 if quantization.is_some() {
78 todo!();
79 }
80 }
81
82 let mut stage_memory =
83 StageMemory::new::<G::StageConfig>(1u32, ident.as_ident(), config.stage_config());
84 let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
85
86 let loading_job = match config.precompute_job() {
87 true => CubeOption::new_Some(L::new_job::<MP, G>(ident, config)),
88 false => CubeOption::new_None(),
89 };
90
91 match ident {
92 InputIdent::Lhs =>
93 {
94 #[allow(clippy::collapsible_if)]
95 if config.check_row_bounds(ident) {
96 if tensor_reader.x_offset.read()
97 > tensor_reader.shape_x - config.tiling_scheme().elements_in_stage_m()
98 {
99 stage_memory.clear_all::<G>(ident, config);
100 }
101 }
102 }
103 InputIdent::Rhs =>
104 {
105 #[allow(clippy::collapsible_if)]
106 if config.check_col_bounds(ident) {
107 if tensor_reader.y_offset.read()
108 > tensor_reader.shape_y - config.tiling_scheme().elements_in_stage_n()
109 {
110 stage_memory.clear_all::<G>(ident, config);
111 }
112 }
113 }
114 }
115
116 AsyncFullLoader::<MP, CM, S, L, G> {
117 tensor_reader,
118 stage_memory,
119 loading_job,
120 ident,
121 _phantom: PhantomData,
122 }
123 }
124
125 pub fn fill_stage(this: &mut Self, mechanism: &CM, #[comptime] config: G) {
127 let mut loading_job = match this.loading_job {
128 CubeOption::Some(loading_job) => loading_job,
129 CubeOption::None => L::new_job::<MP, G>(this.ident, config),
130 };
131
132 let len = L::Job::task_count(&loading_job);
133 for task_id in 0..len {
134 L::Job::<MP>::execute_task::<CM, G>(
135 &mut loading_job,
136 task_id,
137 &this.tensor_reader,
138 &mut this.stage_memory,
139 mechanism,
140 config,
141 );
142 }
143 }
144
145 pub fn clear_stage(this: &mut Self, #[comptime] config: G) {
147 this.stage_memory.clear_all::<G>(this.ident, config)
148 }
149
150 pub fn reader(this: &Self) -> FullStageToTileReader<MP::ES, L::TilingLayout> {
152 FullStageToTileReader::new(this.stage_memory, this.ident)
153 }
154
155 pub fn advance_view(this: &mut Self, k_offset: u32) {
157 this.tensor_reader.update_view(k_offset, this.ident);
158 }
159}