cubecl_matmul/components/global/read/strategy/
async_full_cooperative.rs

1use crate::components::{
2    InvalidConfigError, MatmulElems, MatrixLayout,
3    global::{
4        GlobalReaderConfig,
5        memory::{GlobalIterator, load_window_in_stage},
6        multi_stage::LoadMaxRoundPlaneCount,
7        read::{
8            FullLoadingStrategy, LoadingJob, async_barrier::AsyncBarrier, validate_async_barrier,
9            validate_noswizzle,
10        },
11    },
12    stage::{StridedStageFamily, StridedStageMemory, StridedTilingLayout, TilingValidation},
13};
14use cubecl_core::prelude::{barrier::Barrier, *};
15use cubecl_core::{self as cubecl};
16
17use super::LoadingValidation;
18
19#[derive(CubeType, Clone, Copy)]
20/// Loads global memory into the stage without layout change,  
21/// dividing the stage into the smallest possible contiguous slices.  
22///
23/// Each `memcpy_async` is called with the same arguments for cooperative behaviour
24pub struct AsyncFullCooperativeLoading {}
25
26impl LoadingValidation for AsyncFullCooperativeLoading {
27    fn check<R: Runtime>(
28        client: &ComputeClient<R::Server>,
29        config: &GlobalReaderConfig,
30        _dtypes: &MatmulElems,
31    ) -> Result<(), InvalidConfigError> {
32        StridedTilingLayout::check(config.smem_config)?;
33        validate_async_barrier::<R>(client)?;
34        validate_noswizzle(config.smem_config)?;
35
36        Ok(())
37    }
38}
39
40impl LoadMaxRoundPlaneCount for AsyncFullCooperativeLoading {
41    fn max_round_plane_count(
42        _elements_per_tile: u32,
43        _tiles_per_stage: u32,
44        _line_size: u8,
45        _plane_dim: u32,
46    ) -> u32 {
47        // Not sure what's ideal here, the current specialization isn't great anyways so can deal
48        // with it later
49        4
50    }
51}
52
53#[cube]
54impl FullLoadingStrategy for AsyncFullCooperativeLoading {
55    type TilingLayout = StridedTilingLayout;
56    type SyncStrategy = AsyncBarrier;
57    type Job<EG: Numeric, ES: Numeric> = AsyncFullCooperativeJob;
58
59    const SHOULD_CLEAR: bool = true;
60
61    fn new_job<EG: Numeric, ES: Numeric>(
62        #[comptime] _line_size: u32,
63        #[comptime] config: GlobalReaderConfig,
64    ) -> AsyncFullCooperativeJob {
65        let matrix_layout = config.gmem_config.matrix_layout;
66
67        let num_slices = match matrix_layout {
68            MatrixLayout::RowMajor => config.smem_config.elements_per_stage_along_row(),
69            MatrixLayout::ColMajor => config.smem_config.elements_per_stage_along_col(),
70        };
71
72        AsyncFullCooperativeJob { num_slices }
73    }
74}
75
76#[derive(CubeType, Clone, Copy)]
77pub struct AsyncFullCooperativeJob {
78    #[cube(comptime)]
79    num_slices: u32,
80}
81
82#[cube]
83impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, StridedTilingLayout, AsyncBarrier>
84    for AsyncFullCooperativeJob
85{
86    type Stage = StridedStageFamily;
87
88    fn execute_task(
89        _this: &mut Self,
90        #[comptime] task_id: u32,
91        global_iter: &GlobalIterator<Line<EG>>,
92        stage: &mut StridedStageMemory<ES, StridedTilingLayout>,
93        barrier: &mut Barrier,
94        #[comptime] config: GlobalReaderConfig,
95    ) {
96        let window = load_window_in_stage(
97            &global_iter.view(),
98            task_id,
99            config.smem_config,
100            config.gmem_config,
101        );
102
103        let mut destination: SliceMut<Line<ES>> =
104            StridedTilingLayout::nth_slice::<ES>(stage, task_id, comptime!(config.smem_config));
105
106        barrier.memcpy_async_cooperative(&window.try_cast_unchecked(), &mut destination);
107    }
108
109    fn task_count(this: &Self) -> comptime_type!(u32) {
110        this.num_slices
111    }
112}