cubecl_matmul/components/global/read/reader/async_full_reader.rs

use std::marker::PhantomData;

use crate::components::global::memory::GlobalIterator;
use crate::components::global::read::{AsyncLoadingJob, LoadingValidation};
use crate::components::global::{CopyMechanism, GlobalConfig};
use crate::components::stage::TilingLayout;
use crate::components::stage::{self, StridedStage};
use crate::components::{MatmulIdent, MatrixPrecision};
use cubecl_core as cubecl;
use cubecl_core::prelude::barrier::BarrierLevel;
use cubecl_core::prelude::*;
use cubecl_std::{
    CubeOption, CubeOptionExpand,
    tensor::{View, layout::Coords2d},
};

#[cube]
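/// A strategy for fully and asynchronously loading a stage from global memory.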
pub trait AsyncFullLoadingStrategy: 'static + Send + Sync + Clone + LoadingValidation {
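    /// The layout into which data is loaded in the stage.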
    type TilingLayout: TilingLayout;

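    /// The loading job that fills the stage for this strategy.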
    type Job<IP: MatrixPrecision>: AsyncLoadingJob<IP, Self::TilingLayout>;

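    /// Returns the job, with any preliminary computations already done.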
    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self::Job<IP>;

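    /// The barrier level at which the copy mechanism operates.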
    fn barrier_level() -> BarrierLevel;
}

#[derive(CubeType)]
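/// Reads the full stage from global memory using asynchronous data movement
/// operations, via copy mechanism `CM` and loading strategy `L`.
///
/// A complete load is a `Job`, divided into `Task`s executed one by one.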
pub struct AsyncFullStageGlobalReader<
    IP: MatrixPrecision,
    CM: CopyMechanism,
    S: stage::StageConfig,
    L: AsyncFullLoadingStrategy,
    G: GlobalConfig,
> {
    tensor_reader: GlobalIterator<Line<IP::Global>>,
    stage_memory: StridedStage<IP::Stage, L::TilingLayout>,
    loading_job: CubeOption<L::Job<IP>>,
    #[cube(comptime)]
    ident: MatmulIdent,
    #[cube(comptime)]
    _phantom: PhantomData<(S, L, CM, G)>,
}

#[cube]
impl<
    IP: MatrixPrecision,
    CM: CopyMechanism,
    S: stage::StageConfig,
    L: AsyncFullLoadingStrategy,
    G: GlobalConfig,
> AsyncFullStageGlobalReader<IP, CM, S, L, G>
{
    pub fn new(
        view: View<Line<IP::Global>, Coords2d>,
        k_step: u32,
        #[comptime] ident: MatmulIdent,
        #[comptime] config: G,
    ) -> Self {
        let mut stage_memory = StridedStage::new(
            comptime!(ident.into_stage()),
            config.stage_memory_config(ident),
        );
        let (shape_row, shape_col) = view.shape();
        let tensor_reader = GlobalIterator::new(view, k_step, ident.view_direction(), true);

        // Optionally build the loading job once at comptime, so repeated
        // `load_stage` calls skip the setup work.
        let loading_job = match config.precompute_job() {
            true => CubeOption::new_Some(L::new_job::<IP, G>(ident, config)),
            false => CubeOption::new_None(),
        };

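        // If bounds checking is enabled and the tensor does not cover the
        // full stage, zero the stage so out-of-bounds tiles read zeros
        // instead of stale values.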
        match ident {
            MatmulIdent::Lhs =>
            {
                #[allow(clippy::collapsible_if)]
                if config.check_row_bounds(ident) {
                    if shape_row < config.tiling_scheme().elements_in_stage_m() {
                        stage_memory.clear_all::<G>(ident, config);
                    }
                }
            }
            MatmulIdent::Rhs =>
            {
                #[allow(clippy::collapsible_if)]
                if config.check_col_bounds(ident) {
                    if shape_col < config.tiling_scheme().elements_in_stage_n() {
                        stage_memory.clear_all::<G>(ident, config);
                    }
                }
            }
            MatmulIdent::Out => comptime!(unreachable!()),
        }

        AsyncFullStageGlobalReader::<IP, CM, S, L, G> {
            tensor_reader,
            stage_memory,
            loading_job,
            ident,
            _phantom: PhantomData,
        }
    }

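    /// Accomplishes the full loading job, executing every task with the copy mechanism.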
    pub fn load_stage(&mut self, mechanism: &CM, #[comptime] config: G) {
        // Reuse the precomputed job when available; otherwise build it now.
        let mut loading_job = match self.loading_job {
            CubeOption::Some(loading_job) => loading_job,
            CubeOption::None => L::new_job::<IP, G>(self.ident, config),
        };

        let len = L::Job::task_count(&loading_job);
        for task_id in 0..len {
            L::Job::<IP>::execute_task::<CM, G>(
                &mut loading_job,
                task_id,
                &self.tensor_reader,
                &mut self.stage_memory,
                mechanism,
                config,
            );
        }
    }

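    /// Zeroes out the entire stage memory.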
    pub fn clear_stage(&mut self, #[comptime] config: G) {
        self.stage_memory.clear_all::<G>(self.ident, config)
    }

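    /// Returns the stage this reader fills.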
    pub fn stage(&self) -> StridedStage<IP::Stage, L::TilingLayout> {
        self.stage_memory
    }

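    /// Advances the global memory view by one step along the iteration direction.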
    pub fn advance_view(&mut self) {
        self.tensor_reader.advance();
    }
}
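
// Illustrative lifecycle (a sketch only; `copy_mechanism`, `num_k_iterations`,
// and the surrounding kernel scope are assumed, not defined in this file):
//
//     let mut lhs_reader = AsyncFullStageGlobalReader::<IP, CM, S, L, G>::new(
//         lhs_view, k_step, MatmulIdent::Lhs, config,
//     );
//     for _ in 0..num_k_iterations {
//         lhs_reader.load_stage(&copy_mechanism, config);
//         // ...wait on the barrier, run the stage matmul on lhs_reader.stage()...
//         lhs_reader.advance_view();
//     }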