// cubecl_linalg/matmul/components/global/load/loader/sync_buffer_loader.rs
use std::marker::PhantomData;

use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::VirtualTensor;
use cubecl_std::{CubeOption, CubeOptionExpand};

use super::BufferId;
use crate::matmul::components::InputIdent;
use crate::matmul::components::MatmulPrecision;
use crate::matmul::components::global::GlobalConfig;
use crate::matmul::components::global::LoadingValidation;
use crate::matmul::components::global::Quantization;
use crate::matmul::components::global::load::LoadingJob;
use crate::matmul::components::global::tensor_view::TensorReader;
use crate::matmul::components::stage::BufferReader;
use crate::matmul::components::stage::Stage;
use crate::matmul::components::stage::TilingLayout;

19#[cube]
20pub trait SyncBufferLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
22 type TilingLayout: TilingLayout;
24
25 type Job<MP: MatmulPrecision>: LoadingJob<MP, Self::TilingLayout>;
27
28 fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
30 #[comptime] buffer_index: u32,
31 #[comptime] ident: InputIdent,
32 #[comptime] config: G,
33 ) -> Self::Job<MP>;
34}
35
36#[derive(Clone, CubeType)]
37pub struct SyncBufferLoader<MP: MatmulPrecision, G: GlobalConfig, L: SyncBufferLoadingStrategy> {
38 tensor_reader: TensorReader<MP::EI>,
39 stage: Stage<MP::ES, L::TilingLayout>,
40 loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
41 quantization: CubeOption<Quantization<MP>>,
42 #[cube(comptime)]
43 input_ident: InputIdent,
44 #[cube(comptime)]
45 _config: PhantomData<G>,
46}
47
48#[cube]
49impl<MP: MatmulPrecision, G: GlobalConfig, L: SyncBufferLoadingStrategy>
50 SyncBufferLoader<MP, G, L>
51{
52 pub fn new(
53 tensor: VirtualTensor<MP::EI>,
54 x_offset: u32,
55 y_offset: u32,
56 batch_offset: u32,
57 quantization: CubeOption<Quantization<MP>>,
58 #[comptime] input_ident: InputIdent,
59 #[comptime] config: G,
60 ) -> Self {
61 let stage = Stage::new::<G::SmmConfig>(input_ident.as_ident(), config.to_smm_config());
62 let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
63
64 let loading_job = match config.precompute_job() {
65 true => CubeOption::new_Some((
66 L::new_job::<MP, G>(0u32, input_ident, config),
67 L::new_job::<MP, G>(1u32, input_ident, config),
68 )),
69 false => CubeOption::new_None(),
70 };
71
72 SyncBufferLoader::<MP, G, L> {
73 tensor_reader,
74 stage,
75 loading_job,
76 quantization,
77 input_ident,
78 _config: PhantomData::<G>,
79 }
80 }
81
82 pub fn reader(
83 this: &Self,
84 #[comptime] buffer_id: BufferId,
85 ) -> BufferReader<MP::ES, L::TilingLayout> {
86 BufferReader::new(this.stage, buffer_id, this.input_ident)
87 }
88
89 pub fn advance_view(this: &mut Self, k_offset: u32) {
90 this.tensor_reader.update_view(k_offset, this.input_ident);
91 }
92
93 pub fn fill_stage(this: &mut Self, #[comptime] buffer_id: BufferId, #[comptime] config: G) {
94 let mut loading_job = match this.loading_job {
95 CubeOption::Some(job) => match buffer_id {
96 BufferId::A => job.0,
97 BufferId::B => job.1,
98 },
99 CubeOption::None => match buffer_id {
100 BufferId::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
101 BufferId::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
102 },
103 };
104
105 let len = L::Job::task_count(&loading_job);
106 for task_id in 0..len {
107 L::Job::<MP>::execute_task::<G>(
108 &mut loading_job,
109 task_id,
110 &this.tensor_reader,
111 &mut this.stage,
112 &this.quantization,
113 config,
114 );
115 }
116 }
117
118 pub fn create_job(
119 this: &Self,
120 #[comptime] buffer_id: BufferId,
121 #[comptime] config: G,
122 ) -> SyncBufferLoaderJob<MP, L> {
123 let loading = match this.loading_job {
124 CubeOption::Some(job) => match buffer_id {
125 BufferId::A => job.0,
126 BufferId::B => job.1,
127 },
128 CubeOption::None => match buffer_id {
129 BufferId::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
130 BufferId::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
131 },
132 };
133
134 let num_tasks = L::Job::task_count(&loading);
135
136 SyncBufferLoaderJob::<MP, L> {
137 loading,
138 num_tasks,
139 current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
140 }
141 }
142
143 pub fn execute_task(
144 this: &mut Self,
145 job: &mut SyncBufferLoaderJob<MP, L>,
146 #[comptime] config: G,
147 ) {
148 let task_id = job.current.read().counter;
149
150 L::Job::<MP>::execute_task::<G>(
151 &mut job.loading,
152 task_id,
153 &this.tensor_reader,
154 &mut this.stage,
155 &this.quantization,
156 config,
157 );
158
159 job.current.store(TaskCounter {
160 counter: comptime!(task_id + 1u32),
161 });
162 }
163}
164
165#[derive(CubeType)]
166pub struct SyncBufferLoaderJob<MP: MatmulPrecision, L: SyncBufferLoadingStrategy> {
167 loading: L::Job<MP>,
168 #[cube(comptime)]
169 pub num_tasks: u32,
170 pub current: ComptimeCell<TaskCounter>,
171}
172
173#[derive(CubeType, Clone)]
174pub struct TaskCounter {
175 #[cube(comptime)]
176 pub counter: u32,
177}