cubecl_linalg/matmul/components/global/load/loader/
async_full_loader.rs1use std::marker::PhantomData;
2
3use crate::matmul::components::global::load::AsyncLoadingJob;
4use crate::matmul::components::global::tensor_view::TensorReader;
5use crate::matmul::components::global::{CopyMechanism, GlobalConfig, LoadingValidation};
6use crate::matmul::components::global::{Quantization, single_stage};
7use crate::matmul::components::stage::FullReader;
8use crate::matmul::components::stage::TilingLayout;
9use crate::matmul::components::stage::{self, Stage};
10use crate::matmul::components::{Ident, InputIdent, MatmulPrecision, global};
11use cubecl_core as cubecl;
12use cubecl_core::prelude::barrier::BarrierLevel;
13use cubecl_core::prelude::*;
14use cubecl_std::tensor::r#virtual::VirtualTensor;
15use cubecl_std::{CubeOption, CubeOptionExpand};
16
17#[cube]
18pub trait AsyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
20 type TilingLayout: TilingLayout;
22
23 type Job<MP: MatmulPrecision>: AsyncLoadingJob<MP, Self::TilingLayout>;
25
26 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
28 #[comptime] ident: InputIdent,
29 #[comptime] config: G,
30 ) -> Self::Job<MP>;
31
32 fn barrier_level() -> BarrierLevel;
34}
35
36#[derive(CubeType)]
37pub struct AsyncLoader<
38 MP: MatmulPrecision,
39 CM: CopyMechanism<MP::ES>,
40 S: stage::StageConfig,
41 L: AsyncFullLoadingStrategy,
42> {
43 tensor_reader: TensorReader<MP::EI>,
44 stage: Stage<MP::ES, L::TilingLayout>,
45 loading_job: CubeOption<L::Job<MP>>,
46 #[cube(comptime)]
47 ident: InputIdent,
48 #[cube(comptime)]
49 _phantom: PhantomData<(S, L, CM)>,
50}
51
52#[cube]
53impl<
54 MP: MatmulPrecision,
55 CM: CopyMechanism<MP::ES>,
56 S: stage::StageConfig,
57 L: AsyncFullLoadingStrategy,
58> AsyncLoader<MP, CM, S, L>
59{
60 pub fn new<G: global::GlobalConfig>(
61 tensor: VirtualTensor<MP::EI>,
62 x_offset: u32,
63 y_offset: u32,
64 batch_offset: u32,
65 quantization: CubeOption<Quantization<MP>>,
66 #[comptime] ident: InputIdent,
67 #[comptime] config: G,
68 ) -> Self {
69 comptime! {
70 if quantization.is_some() {
71 todo!();
72 }
73 }
74
75 let mut stage = Stage::new::<G::SmmConfig>(ident.as_ident(), config.to_smm_config());
76
77 let loading_job = match config.precompute_job() {
78 true => CubeOption::new_Some(L::new_job::<MP, G>(ident, config)),
79 false => CubeOption::new_None(),
80 };
81
82 match ident {
83 InputIdent::Lhs =>
84 {
85 #[allow(clippy::collapsible_if)]
86 if config.check_row_bounds(ident) {
87 if x_offset
88 > tensor.shape(tensor.rank() - 2)
89 - config.tiling_dimensions(Ident::Lhs).total_row()
90 {
91 stage.clear::<G::SmmConfig>(ident, config.to_smm_config());
92 }
93 }
94 }
95 InputIdent::Rhs =>
96 {
97 #[allow(clippy::collapsible_if)]
98 if config.check_col_bounds(ident) {
99 if y_offset
100 > tensor.shape(tensor.rank() - 1)
101 - config.tiling_dimensions(Ident::Rhs).total_col()
102 {
103 stage.clear::<G::SmmConfig>(ident, config.to_smm_config());
104 }
105 }
106 }
107 }
108
109 let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
110
111 AsyncLoader::<MP, CM, S, L> {
112 tensor_reader,
113 stage,
114 loading_job,
115 ident,
116 _phantom: PhantomData::<(S, L, CM)>,
117 }
118 }
119
120 pub fn fill_stage(
121 this: &mut Self,
122 mechanism: &CM,
123 #[comptime] config: single_stage::Config<S>,
124 ) {
125 let mut loading_job = match this.loading_job {
126 CubeOption::Some(loading_job) => loading_job,
127 CubeOption::None => L::new_job::<MP, single_stage::Config<S>>(this.ident, config),
128 };
129
130 let len = L::Job::task_count(&loading_job);
131 for task_id in 0..len {
132 L::Job::<MP>::execute_task::<CM, single_stage::Config<S>>(
133 &mut loading_job,
134 task_id,
135 &this.tensor_reader,
136 &mut this.stage,
137 mechanism,
138 config,
139 );
140 }
141 }
142
143 pub fn clear_stage(this: &mut Self, #[comptime] config: single_stage::Config<S>) {
144 this.stage.clear::<S>(this.ident, config.to_smm_config())
145 }
146
147 pub fn reader(this: &Self) -> FullReader<MP::ES, L::TilingLayout> {
148 FullReader::new(this.stage, this.ident)
149 }
150
151 pub fn advance_view(this: &mut Self, k_offset: u32) {
152 this.tensor_reader.update_view(k_offset, this.ident);
153 }
154}