// cubecl_matmul/components/global/read/strategy/async_full_cooperative.rs
use crate::components::{
2 InvalidConfigError, MatmulElems, MatrixLayout,
3 global::{
4 GlobalReaderConfig,
5 memory::{GlobalIterator, load_window_in_stage},
6 multi_stage::LoadMaxRoundPlaneCount,
7 read::{
8 FullLoadingStrategy, LoadingJob, async_barrier::AsyncBarrier, validate_async_barrier,
9 validate_noswizzle,
10 },
11 },
12 stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout, TilingValidation},
13};
14use cubecl_core::prelude::{barrier::Barrier, *};
15use cubecl_core::{self as cubecl};
16
17use super::LoadingValidation;
18
19#[derive(CubeType, Clone, Copy)]
20pub struct AsyncFullCooperativeLoading {}
25
26impl LoadingValidation for AsyncFullCooperativeLoading {
27 fn check<R: Runtime>(
28 client: &ComputeClient<R::Server>,
29 config: &GlobalReaderConfig,
30 _dtypes: &MatmulElems,
31 ) -> Result<(), InvalidConfigError> {
32 StridedTilingLayout::check(config.smem_config)?;
33 validate_async_barrier::<R>(client)?;
34 validate_noswizzle(config.smem_config)?;
35
36 Ok(())
37 }
38}
39
40impl LoadMaxRoundPlaneCount for AsyncFullCooperativeLoading {
41 fn max_round_plane_count(
42 _elements_per_tile: u32,
43 _tiles_per_stage: u32,
44 _line_size: u8,
45 _plane_dim: u32,
46 ) -> u32 {
47 4
50 }
51}
52
53#[cube]
54impl FullLoadingStrategy for AsyncFullCooperativeLoading {
55 type TilingLayout = StridedTilingLayout;
56 type SyncStrategy = AsyncBarrier;
57 type Job<EG: Numeric, ES: Numeric> = AsyncFullCooperativeJob;
58
59 const SHOULD_CLEAR: bool = true;
60
61 fn new_job<EG: Numeric, ES: Numeric>(
62 #[comptime] _line_size: u32,
63 #[comptime] config: GlobalReaderConfig,
64 ) -> AsyncFullCooperativeJob {
65 let matrix_layout = config.gmem_config.matrix_layout;
66
67 let num_slices = match matrix_layout {
68 MatrixLayout::RowMajor => config.smem_config.elements_per_stage_along_row(),
69 MatrixLayout::ColMajor => config.smem_config.elements_per_stage_along_col(),
70 };
71
72 AsyncFullCooperativeJob { num_slices }
73 }
74}
75
76#[derive(CubeType, Clone, Copy)]
77pub struct AsyncFullCooperativeJob {
78 #[cube(comptime)]
79 num_slices: u32,
80}
81
82#[cube]
83impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, StridedTilingLayout, AsyncBarrier>
84 for AsyncFullCooperativeJob
85{
86 type Stage = StridedStageFamily;
87
88 fn execute_task(
89 _this: &mut Self,
90 #[comptime] task_id: u32,
91 global_iter: &GlobalIterator<Line<EG>>,
92 stage: &mut StridedStageMemory<ES, StridedTilingLayout>,
93 barrier: &mut Barrier,
94 #[comptime] config: GlobalReaderConfig,
95 ) {
96 let window = load_window_in_stage(
97 &global_iter.view(),
98 task_id,
99 config.smem_config,
100 config.gmem_config,
101 );
102
103 let mut destination: SliceMut<Line<ES>> =
104 StridedTilingLayout::nth_slice::<ES>(stage, task_id, comptime!(config.smem_config));
105
106 barrier.memcpy_async_cooperative(&window.try_cast_unchecked(), &mut destination);
107 }
108
109 fn task_count(this: &Self) -> comptime_type!(u32) {
110 this.num_slices
111 }
112}