// cubecl_matmul/components/global/read/reader/async_partial_reader.rs
use super::StageBuffer;
use crate::components::global::CopyMechanism;
use crate::components::global::base::GlobalConfig;
use crate::components::global::memory::GlobalIterator;
use crate::components::global::multi_stage::double_buffering::DoubleBufferingGlobalConfig;
use crate::components::global::read::{AsyncLoadingJob, LoadingValidation};
use crate::components::stage::TilingLayout;
use crate::components::stage::{self, StridedStage};
use crate::components::{MatmulIdent, MatrixPrecision};
use core::marker::PhantomData;
use cubecl_core as cubecl;
use cubecl_core::prelude::barrier::BarrierLevel;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};
18
19#[cube]
20pub trait AsyncPartialLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
22 type TilingLayout: TilingLayout;
24
25 type Job<IP: MatrixPrecision>: AsyncLoadingJob<IP, Self::TilingLayout>;
27
28 fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
30 #[comptime] buffer_index: u32,
31 #[comptime] ident: MatmulIdent,
32 #[comptime] config: G,
33 ) -> Self::Job<IP>;
34
35 fn barrier_level() -> BarrierLevel;
37}
38
39#[derive(CubeType)]
40pub struct AsyncBufferGlobalReader<
45 IP: MatrixPrecision,
46 S: stage::StageConfig,
47 CM: CopyMechanism,
48 L: AsyncPartialLoadingStrategy,
49> {
50 global_iter: GlobalIterator<Line<IP::Global>>,
51 stage: StridedStage<IP::Stage, L::TilingLayout>,
52 loading_job: CubeOption<(L::Job<IP>, L::Job<IP>)>,
53 #[cube(comptime)]
54 ident: MatmulIdent,
55 #[cube(comptime)]
56 _phantom: PhantomData<(S, CM)>,
57}
58
59#[cube]
60impl<IP: MatrixPrecision, S: stage::StageConfig, CM: CopyMechanism, L: AsyncPartialLoadingStrategy>
61 AsyncBufferGlobalReader<IP, S, CM, L>
62{
63 pub fn new(
65 tensor: View<Line<IP::Global>, Coords2d>,
66 k_step: u32,
67 #[comptime] ident: MatmulIdent,
68 #[comptime] config: DoubleBufferingGlobalConfig<S>,
69 ) -> Self {
70 let stage = StridedStage::new(
71 comptime!(ident.into_stage()),
72 config.stage_memory_config(ident),
73 );
74 let global_iter = GlobalIterator::new(tensor, k_step, ident.view_direction(), true);
75
76 let loading_job = match config.precompute_job() {
77 true => CubeOption::new_Some((
78 L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(0u32, ident, config),
79 L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(1u32, ident, config),
80 )),
81 false => CubeOption::new_None(),
82 };
83
84 AsyncBufferGlobalReader::<IP, S, CM, L> {
85 global_iter,
86 stage,
87 loading_job,
88 ident,
89 _phantom: PhantomData::<(S, CM)>,
90 }
91 }
92
93 pub fn stage(
95 &mut self,
96 #[comptime] stage_buffer: StageBuffer,
97 ) -> StridedStage<IP::Stage, L::TilingLayout> {
98 self.stage.with_buffer_index(stage_buffer.to_index())
99 }
100
101 pub fn advance_view(&mut self) {
103 self.global_iter.advance();
104 }
105
106 pub fn load_stage(
108 &mut self,
109 mechanism: &CM,
110 #[comptime] stage_buffer: StageBuffer,
111 #[comptime] config: DoubleBufferingGlobalConfig<S>,
112 ) {
113 let mut loading_job = match self.loading_job {
114 CubeOption::Some(job) => match stage_buffer {
115 StageBuffer::A => job.0,
116 StageBuffer::B => job.1,
117 },
118 CubeOption::None => match stage_buffer {
119 StageBuffer::A => {
120 L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(0u32, self.ident, config)
121 }
122 StageBuffer::B => {
123 L::new_job::<IP, DoubleBufferingGlobalConfig<S>>(1u32, self.ident, config)
124 }
125 },
126 };
127
128 let len = L::Job::task_count(&loading_job);
129 for task_id in 0..len {
130 L::Job::<IP>::execute_task::<CM, DoubleBufferingGlobalConfig<S>>(
131 &mut loading_job,
132 task_id,
133 &self.global_iter,
134 &mut self.stage,
135 mechanism,
136 config,
137 );
138 }
139 }
140
141 pub fn clear_stage(
143 &mut self,
144 #[comptime] stage_buffer: StageBuffer,
145 #[comptime] config: DoubleBufferingGlobalConfig<S>,
146 ) {
147 self.stage
148 .clear_stage::<DoubleBufferingGlobalConfig<S>>(stage_buffer, self.ident, config)
149 }
150}