cubecl_matmul/components/global/read/reader/
async_partial_reader.rs1use super::StageBuffer;
2use crate::components::global::CopyMechanism;
3use crate::components::global::base::GlobalConfig;
4use crate::components::global::memory::GlobalIterator;
5use crate::components::global::multi_stage::double_buffering::DoubleBufferingGlobalConfig;
6use crate::components::global::read::{AsyncLoadingJob, LoadingValidation};
7use crate::components::stage::TilingLayout;
8use crate::components::stage::{self, StridedStage};
9use crate::components::{MatmulIdent, MatrixPrecision};
10use core::marker::PhantomData;
11use cubecl_core as cubecl;
12use cubecl_core::prelude::barrier::BarrierLevel;
13use cubecl_core::prelude::*;
14use cubecl_std::{
15    CubeOption, CubeOptionExpand,
16    tensor::{View, layout::Coords2d},
17};
18
/// Strategy for asynchronously filling one buffer of a stage from global memory.
///
/// An implementor fixes how the stage is tiled ([`Self::TilingLayout`]) and produces, per
/// buffer index, a loading job ([`Self::Job`]) whose tasks perform the copies.
#[cube]
pub trait AsyncPartialLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
    /// Tiling layout of the stage that this strategy's jobs write into.
    type TilingLayout: TilingLayout;

    /// Job that executes the copy tasks for a single buffer, for matrix precision `IP`.
    type Job<IP: MatrixPrecision>: AsyncLoadingJob<IP, Self::TilingLayout>;

    /// Builds the loading job for the buffer identified by `buffer_index`.
    ///
    /// All three arguments are comptime. Callers in this file pass `0` or `1` as
    /// `buffer_index`.
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] buffer_index: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self::Job<IP>;

    /// Barrier level associated with this strategy's asynchronous copies.
    // NOTE(review): presumably used by the caller to create the barrier passed in via the
    // copy mechanism — confirm against the call sites.
    fn barrier_level() -> BarrierLevel;
}
38
/// Global reader that loads a two-buffer stage one buffer at a time, using the
/// asynchronous loading strategy `L` and the copy mechanism `CM`.
///
/// Buffers are addressed through [`StageBuffer`] (`A` / `B`) and configured with a
/// `DoubleBufferingGlobalConfig<S>`.
#[derive(CubeType)]
pub struct AsyncBufferGlobalReader<
    IP: MatrixPrecision,
    S: stage::StageConfig,
    CM: CopyMechanism,
    L: AsyncPartialLoadingStrategy,
> {
    /// Iterator over the global tensor view that loading tasks read from.
    global_iter: GlobalIterator<Line<IP::Global>>,
    /// Stage memory that loading tasks write into; `stage()` selects one buffer of it.
    stage: StridedStage<IP::Stage, L::TilingLayout>,
    /// Jobs for buffer 0 and buffer 1, built once in `new` when the config asks for
    /// precomputed jobs; `None` means `load_stage` builds the job on every call.
    loading_job: CubeOption<(L::Job<IP>, L::Job<IP>)>,
    /// The `MatmulIdent` this reader was created for; forwarded to job creation and to
    /// `clear_stage`.
    #[cube(comptime)]
    ident: MatmulIdent,
    /// Carries the `S` and `CM` type parameters, which no field otherwise uses.
    #[cube(comptime)]
    _phantom: PhantomData<(S, CM)>,
}
58
59#[cube]
60impl<IP: MatrixPrecision, S: stage::StageConfig, CM: CopyMechanism, L: AsyncPartialLoadingStrategy>
61    AsyncBufferGlobalReader<IP, S, CM, L>
62{
63    pub fn new(
65        tensor: View<Line<IP::Global>, Coords2d>,
66        k_step: u32,
67        #[comptime] ident: MatmulIdent,
68        #[comptime] config: DoubleBufferingGlobalConfig<S>,
69    ) -> Self {
70        let stage = StridedStage::new(
71            comptime!(ident.into_stage()),
72            config.stage_memory_config(ident),
73        );
74        let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), true);
75
76        let loading_job = match config.precompute_job() {
77            true => CubeOption::new_Some((
78                L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(0u32, ident, config),
79                L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(1u32, ident, config),
80            )),
81            false => CubeOption::new_None(),
82        };
83
84        AsyncBufferGlobalReader::<IP, S, CM, L> {
85            global_iter,
86            stage,
87            loading_job,
88            ident,
89            _phantom: PhantomData::<(S, CM)>,
90        }
91    }
92
93    pub fn stage(
95        &mut self,
96        #[comptime] stage_buffer: StageBuffer,
97    ) -> StridedStage<IP::Stage, L::TilingLayout> {
98        self.stage.with_buffer_index(stage_buffer.to_index())
99    }
100
101    pub fn advance_view(&mut self) {
103        self.global_iter.advance();
104    }
105
106    pub fn load_stage(
108        &mut self,
109        mechanism: &CM,
110        #[comptime] stage_buffer: StageBuffer,
111        #[comptime] config: DoubleBufferingGlobalConfig<S>,
112    ) {
113        let mut loading_job = match self.loading_job {
114            CubeOption::Some(job) => match stage_buffer {
115                StageBuffer::A => job.0,
116                StageBuffer::B => job.1,
117            },
118            CubeOption::None => match stage_buffer {
119                StageBuffer::A => {
120                    L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(0u32, self.ident, config)
121                }
122                StageBuffer::B => {
123                    L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(1u32, self.ident, config)
124                }
125            },
126        };
127
128        let len = L::Job::task_count(&loading_job);
129        for task_id in 0..len {
130            L::Job::<IP>::execute_task::<CM, DoubleBufferingGlobalConfig<S>>(
131                &mut loading_job,
132                task_id,
133                &self.global_iter,
134                &mut self.stage,
135                mechanism,
136                config,
137            );
138        }
139    }
140
141    pub fn clear_stage(
143        &mut self,
144        #[comptime] stage_buffer: StageBuffer,
145        #[comptime] config: DoubleBufferingGlobalConfig<S>,
146    ) {
147        self.stage
148            .clear_stage::<DoubleBufferingGlobalConfig<S>>(stage_buffer, self.ident, config)
149    }
150}