// cubecl_linalg/matmul/components/global/load/loader/sync_buffer_loader.rs

use std::marker::PhantomData;

use super::BufferId;
use crate::matmul::components::InputIdent;
use crate::matmul::components::MatmulPrecision;
use crate::matmul::components::global::GlobalConfig;
use crate::matmul::components::global::LoadingValidation;
use crate::matmul::components::global::Quantization;
use crate::matmul::components::global::load::LoadingJob;
use crate::matmul::components::global::tensor_view::TensorReader;
use crate::matmul::components::stage::BufferReader;
use crate::matmul::components::stage::Stage;
use crate::matmul::components::stage::TilingLayout;
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::VirtualTensor;
use cubecl_std::{CubeOption, CubeOptionExpand};

19#[cube]
20/// A strategy for synchronously loading a buffer (partial stage), either eagerly or as a deferred job.
21pub trait SyncBufferLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
22    /// The layout describing how data is tiled across the stage.
23    type TilingLayout: TilingLayout;
24
25    /// The [LoadingJob] for this strategy.
26    type Job<MP: MatmulPrecision>: LoadingJob<MP, Self::TilingLayout>;
27
28    /// Returns the job with preliminary calculations done.
29    fn new_job<MP: MatmulPrecision, G: GlobalConfig>(
30        #[comptime] buffer_index: u32,
31        #[comptime] ident: InputIdent,
32        #[comptime] config: G,
33    ) -> Self::Job<MP>;
34}
35
36#[derive(Clone, CubeType)]
37pub struct SyncBufferLoader<MP: MatmulPrecision, G: GlobalConfig, L: SyncBufferLoadingStrategy> {
38    tensor_reader: TensorReader<MP::EI>,
39    stage: Stage<MP::ES, L::TilingLayout>,
40    loading_job: CubeOption<(L::Job<MP>, L::Job<MP>)>,
41    quantization: CubeOption<Quantization<MP>>,
42    #[cube(comptime)]
43    input_ident: InputIdent,
44    #[cube(comptime)]
45    _config: PhantomData<G>,
46}
47
48#[cube]
49impl<MP: MatmulPrecision, G: GlobalConfig, L: SyncBufferLoadingStrategy>
50    SyncBufferLoader<MP, G, L>
51{
52    pub fn new(
53        tensor: VirtualTensor<MP::EI>,
54        x_offset: u32,
55        y_offset: u32,
56        batch_offset: u32,
57        quantization: CubeOption<Quantization<MP>>,
58        #[comptime] input_ident: InputIdent,
59        #[comptime] config: G,
60    ) -> Self {
61        let stage = Stage::new::<G::SmmConfig>(input_ident.as_ident(), config.to_smm_config());
62        let tensor_reader = TensorReader::new(tensor, x_offset, y_offset, batch_offset);
63
64        let loading_job = match config.precompute_job() {
65            true => CubeOption::new_Some((
66                L::new_job::<MP, G>(0u32, input_ident, config),
67                L::new_job::<MP, G>(1u32, input_ident, config),
68            )),
69            false => CubeOption::new_None(),
70        };
71
72        SyncBufferLoader::<MP, G, L> {
73            tensor_reader,
74            stage,
75            loading_job,
76            quantization,
77            input_ident,
78            _config: PhantomData::<G>,
79        }
80    }
81
82    pub fn reader(
83        this: &Self,
84        #[comptime] buffer_id: BufferId,
85    ) -> BufferReader<MP::ES, L::TilingLayout> {
86        BufferReader::new(this.stage, buffer_id, this.input_ident)
87    }
88
89    pub fn advance_view(this: &mut Self, k_offset: u32) {
90        this.tensor_reader.update_view(k_offset, this.input_ident);
91    }
92
93    pub fn fill_stage(this: &mut Self, #[comptime] buffer_id: BufferId, #[comptime] config: G) {
94        let mut loading_job = match this.loading_job {
95            CubeOption::Some(job) => match buffer_id {
96                BufferId::A => job.0,
97                BufferId::B => job.1,
98            },
99            CubeOption::None => match buffer_id {
100                BufferId::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
101                BufferId::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
102            },
103        };
104
105        let len = L::Job::task_count(&loading_job);
106        for task_id in 0..len {
107            L::Job::<MP>::execute_task::<G>(
108                &mut loading_job,
109                task_id,
110                &this.tensor_reader,
111                &mut this.stage,
112                &this.quantization,
113                config,
114            );
115        }
116    }
117
118    pub fn create_job(
119        this: &Self,
120        #[comptime] buffer_id: BufferId,
121        #[comptime] config: G,
122    ) -> SyncBufferLoaderJob<MP, L> {
123        let loading = match this.loading_job {
124            CubeOption::Some(job) => match buffer_id {
125                BufferId::A => job.0,
126                BufferId::B => job.1,
127            },
128            CubeOption::None => match buffer_id {
129                BufferId::A => L::new_job::<MP, G>(0u32, this.input_ident, config),
130                BufferId::B => L::new_job::<MP, G>(1u32, this.input_ident, config),
131            },
132        };
133
134        let num_tasks = L::Job::task_count(&loading);
135
136        SyncBufferLoaderJob::<MP, L> {
137            loading,
138            num_tasks,
139            current: ComptimeCell::new(TaskCounter { counter: 0u32 }),
140        }
141    }
142
143    pub fn execute_task(
144        this: &mut Self,
145        job: &mut SyncBufferLoaderJob<MP, L>,
146        #[comptime] config: G,
147    ) {
148        let task_id = job.current.read().counter;
149
150        L::Job::<MP>::execute_task::<G>(
151            &mut job.loading,
152            task_id,
153            &this.tensor_reader,
154            &mut this.stage,
155            &this.quantization,
156            config,
157        );
158
159        job.current.store(TaskCounter {
160            counter: comptime!(task_id + 1u32),
161        });
162    }
163}
164
165#[derive(CubeType)]
166pub struct SyncBufferLoaderJob<MP: MatmulPrecision, L: SyncBufferLoadingStrategy> {
167    loading: L::Job<MP>,
168    #[cube(comptime)]
169    pub num_tasks: u32,
170    pub current: ComptimeCell<TaskCounter>,
171}
172
173#[derive(CubeType, Clone)]
174pub struct TaskCounter {
175    #[cube(comptime)]
176    pub counter: u32,
177}