cubecl_matmul/components/global/read/strategy/
async_full_tma.rs1use crate::components::InvalidConfigError;
2use crate::components::MatmulElems;
3use crate::components::global::GlobalReaderConfig;
4use crate::components::global::read::{validate_async_barrier, validate_tma};
5use crate::components::global::{RoleRule, read::async_tma::AsyncTma};
6use crate::components::stage::StridedStageFamily;
7use crate::components::stage::{StridedStageMemory, SwizzleMode};
8use crate::components::{MatrixLayout, global::read::FullLoadingStrategy};
9use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
10use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TmaTilingLayout};
11use cubecl_core::prelude::*;
12use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
13
14use super::{LoadingJob, LoadingValidation};
15
16#[derive(CubeType, Clone, Copy)]
17pub struct AsyncFullTmaLoading {}
21
22impl LoadingValidation for AsyncFullTmaLoading {
23 fn check<R: Runtime>(
24 client: &ComputeClient<R::Server>,
25 config: &GlobalReaderConfig,
26 dtypes: &MatmulElems,
27 ) -> Result<(), InvalidConfigError> {
28 TmaTilingLayout::check(config.smem_config)?;
29 validate_tma::<R>(client, config.smem_config, config.stage_ident, dtypes)?;
30 validate_async_barrier::<R>(client)?;
31
32 Ok(())
33 }
34}
35
36impl LoadMaxRoundPlaneCount for AsyncFullTmaLoading {
37 fn max_round_plane_count(
38 _elements_per_tile: u32,
39 _tiles_per_stage: u32,
40 _line_size: u8,
41 _plane_dim: u32,
42 ) -> u32 {
43 4
46 }
47}
48
49#[cube]
50impl FullLoadingStrategy for AsyncFullTmaLoading {
51 type TilingLayout = TmaTilingLayout;
52 type SyncStrategy = AsyncTma;
53 type Job<EG: Numeric, ES: Numeric> = AsyncFullTmaJob;
54
55 fn new_job<EG: Numeric, ES: Numeric>(
56 #[comptime] _line_size: u32,
57 #[comptime] config: GlobalReaderConfig,
58 ) -> Self::Job<EG, ES> {
59 let role_rule_config = config.plane_role_config.rule;
60 let config = config.smem_config;
61 let tile_count_col = match config.matrix_layout {
62 MatrixLayout::RowMajor => config.tiles_per_stage_along_col(),
63 MatrixLayout::ColMajor => config.tiles_per_stage_along_row(),
64 };
65 let num_tasks = comptime![match config.swizzle {
68 SwizzleMode::None => tile_count_col,
69 _ => 1u32,
70 }];
71
72 let is_elected = RoleRule::new(role_rule_config).elect_load_leader();
73
74 AsyncFullTmaJob {
75 is_elected,
76 num_tasks,
77 }
78 }
79}
80
81#[derive(CubeType, Clone, Copy)]
82pub struct AsyncFullTmaJob {
83 is_elected: bool,
84
85 #[cube(comptime)]
86 num_tasks: u32,
87}
88
89#[cube]
90impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, TmaTilingLayout, AsyncTma> for AsyncFullTmaJob {
91 type Stage = StridedStageFamily;
92
93 fn execute_task(
94 this: &mut Self,
95 #[comptime] task_id: u32,
96 global_iter: &GlobalIterator<Line<EG>>,
97 stage: &mut StridedStageMemory<ES, TmaTilingLayout>,
98 barrier: &mut Barrier,
99 #[comptime] config: GlobalReaderConfig,
100 ) {
101 if this.is_elected {
102 let config = comptime![config.smem_config];
103
104 let size_row = match config.matrix_layout {
105 MatrixLayout::RowMajor => config.elements_per_stage_along_row(),
106 MatrixLayout::ColMajor => config.elements_per_stage_along_col(),
107 };
108 let size_col = match config.matrix_layout {
109 MatrixLayout::RowMajor => config.elements_per_tile_along_col,
110 MatrixLayout::ColMajor => config.elements_per_tile_along_row,
111 };
112
113 let global_view = global_iter.view();
114 let mut stage = stage.as_slice_mut(1u32);
115 let slice_size = size_row * size_col;
116
117 let slice_start = task_id * slice_size;
118 let slice = stage.slice_mut(slice_start, slice_start + slice_size);
119 let col = task_id * size_col;
120
121 let pos = match config.matrix_layout {
122 MatrixLayout::RowMajor => (0, col),
123 MatrixLayout::ColMajor => (col, 0),
124 };
125
126 global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), pos);
127 }
128 }
129
130 fn task_count(this: &Self) -> comptime_type!(u32) {
131 this.num_tasks
132 }
133}