cubecl_matmul/components/global/read/strategy/
async_full_cooperative.rs

1use crate::components::{
2    InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3    global::{
4        CopyMechanism, GlobalConfig,
5        memory::{GlobalIterator, load_window_in_stage},
6        read::AsyncFullLoadingStrategy,
7    },
8    stage::{StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
15#[derive(CubeType, Clone, Copy)]
16/// Loads global memory into the stage without layout change,  
17/// dividing the stage into the smallest possible contiguous slices.  
18///
19/// Each `memcpy_async` is called with the same arguments for cooperative behaviour
20pub struct AsyncFullCooperativeLoading {}
21
22impl LoadingValidation for AsyncFullCooperativeLoading {
23    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
24        StridedTilingLayout::check(config.global_memory_config(ident))?;
25
26        Ok(())
27    }
28}
29
30#[cube]
31impl AsyncFullLoadingStrategy for AsyncFullCooperativeLoading {
32    type TilingLayout = StridedTilingLayout;
33    type Job<IP: MatrixPrecision> = AsyncFullCooperativeJob;
34
35    fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
36        #[comptime] ident: MatmulIdent,
37        #[comptime] config: G,
38    ) -> AsyncFullCooperativeJob {
39        let matrix_layout = config.matrix_layout(ident);
40
41        let num_slices = match matrix_layout {
42            MatrixLayout::RowMajor => config.tiling_scheme().elements_in_stage_row(ident),
43            MatrixLayout::ColMajor => config.tiling_scheme().elements_in_stage_col(ident),
44        };
45
46        AsyncFullCooperativeJob { num_slices, ident }
47    }
48
49    fn barrier_level() -> BarrierLevel {
50        BarrierLevel::cube_coop(0u32)
51    }
52}
53
54#[derive(CubeType, Clone, Copy)]
55pub struct AsyncFullCooperativeJob {
56    #[cube(comptime)]
57    num_slices: u32,
58    #[cube(comptime)]
59    ident: MatmulIdent,
60}
61
62#[cube]
63impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout> for AsyncFullCooperativeJob {
64    fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
65        this: &mut Self,
66        task_id: u32,
67        global_iter: &GlobalIterator<Line<IP::Global>>,
68        stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
69        mechanism: &CM,
70        #[comptime] config: G,
71    ) {
72        let window = load_window_in_stage(
73            &global_iter.view(),
74            task_id,
75            comptime!(config.global_memory_config(this.ident)),
76        );
77
78        let mut destination: SliceMut<Line<IP::Stage>> = StridedTilingLayout::nth_slice::<IP::Stage>(
79            stage,
80            task_id,
81            comptime!(config.stage_memory_config(this.ident)),
82        );
83
84        CM::memcpy_async(mechanism, &window.try_cast_unchecked(), &mut destination);
85    }
86
87    fn task_count(this: &Self) -> comptime_type!(u32) {
88        this.num_slices
89    }
90}