cubecl_matmul/components/global/read/strategy/
async_full_cooperative.rs1use crate::components::{
2 InvalidConfigError, MatmulIdent, MatrixLayout, MatrixPrecision,
3 global::{
4 CopyMechanism, GlobalConfig,
5 memory::{GlobalIterator, load_window_in_stage},
6 read::AsyncFullLoadingStrategy,
7 },
8 stage::{StridedStage, StridedTilingLayout, TilingValidation},
9};
10use cubecl_core::prelude::*;
11use cubecl_core::{self as cubecl, prelude::barrier::BarrierLevel};
12
13use super::{AsyncLoadingJob, LoadingValidation};
14
15#[derive(CubeType, Clone, Copy)]
16pub struct AsyncFullCooperativeLoading {}
21
22impl LoadingValidation for AsyncFullCooperativeLoading {
23 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError> {
24 StridedTilingLayout::check(config.global_memory_config(ident))?;
25
26 Ok(())
27 }
28}
29
30#[cube]
31impl AsyncFullLoadingStrategy for AsyncFullCooperativeLoading {
32 type TilingLayout = StridedTilingLayout;
33 type Job<IP: MatrixPrecision> = AsyncFullCooperativeJob;
34
35 fn new_job<IP: MatrixPrecision, G: GlobalConfig>(
36 #[comptime] ident: MatmulIdent,
37 #[comptime] config: G,
38 ) -> AsyncFullCooperativeJob {
39 let matrix_layout = config.matrix_layout(ident);
40
41 let num_slices = match matrix_layout {
42 MatrixLayout::RowMajor => config.tiling_scheme().elements_in_stage_row(ident),
43 MatrixLayout::ColMajor => config.tiling_scheme().elements_in_stage_col(ident),
44 };
45
46 AsyncFullCooperativeJob { num_slices, ident }
47 }
48
49 fn barrier_level() -> BarrierLevel {
50 BarrierLevel::cube_coop(0u32)
51 }
52}
53
54#[derive(CubeType, Clone, Copy)]
55pub struct AsyncFullCooperativeJob {
56 #[cube(comptime)]
57 num_slices: u32,
58 #[cube(comptime)]
59 ident: MatmulIdent,
60}
61
62#[cube]
63impl<IP: MatrixPrecision> AsyncLoadingJob<IP, StridedTilingLayout> for AsyncFullCooperativeJob {
64 fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
65 this: &mut Self,
66 task_id: u32,
67 global_iter: &GlobalIterator<Line<IP::Global>>,
68 stage: &mut StridedStage<IP::Stage, StridedTilingLayout>,
69 mechanism: &CM,
70 #[comptime] config: G,
71 ) {
72 let window = load_window_in_stage(
73 &global_iter.view(),
74 task_id,
75 comptime!(config.global_memory_config(this.ident)),
76 );
77
78 let mut destination: SliceMut<Line<IP::Stage>> = StridedTilingLayout::nth_slice::<IP::Stage>(
79 stage,
80 task_id,
81 comptime!(config.stage_memory_config(this.ident)),
82 );
83
84 CM::memcpy_async(mechanism, &window.try_cast_unchecked(), &mut destination);
85 }
86
87 fn task_count(this: &Self) -> comptime_type!(u32) {
88 this.num_slices
89 }
90}