// cubecl_matmul/components/global/read/strategy/async_partial_tma.rs
use crate::components::MatmulElems;
2use crate::components::global::GlobalReaderConfig;
3use crate::components::global::read::{validate_async_barrier, validate_tma};
4use crate::components::global::{RoleRule, multi_stage::LoadMaxRoundPlaneCount};
5use crate::components::stage::StridedStageFamily;
6use crate::components::stage::TmaTilingLayout;
7use crate::components::stage::{StridedStageMemory, SwizzleMode};
8use crate::components::{InvalidConfigError, StageIdent};
9use crate::components::{
10 MatrixLayout,
11 global::read::{PartialLoadingStrategy, async_tma::AsyncTma},
12};
13use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
14use cubecl_core::prelude::*;
15use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
16
17use super::{LoadingJob, LoadingValidation};
18
19#[derive(CubeType, Clone, Copy)]
20pub struct AsyncPartialTmaLoading {}
24
25impl LoadingValidation for AsyncPartialTmaLoading {
26 fn check<R: Runtime>(
27 client: &ComputeClient<R::Server>,
28 config: &GlobalReaderConfig,
29 dtypes: &MatmulElems,
30 ) -> Result<(), InvalidConfigError> {
31 TmaTilingLayout::check(config.smem_config)?;
32 validate_tma::<R>(client, config.smem_config, config.stage_ident, dtypes)?;
33
34 validate_async_barrier::<R>(client)?;
35
36 Ok(())
37 }
38}
39
40impl LoadMaxRoundPlaneCount for AsyncPartialTmaLoading {
41 fn max_round_plane_count(
42 _elements_per_tile: u32,
43 _tiles_per_stage: u32,
44 _line_size: u8,
45 _plane_dim: u32,
46 ) -> u32 {
47 4
48 }
49}
50
51#[cube]
52impl PartialLoadingStrategy for AsyncPartialTmaLoading {
53 type TilingLayout = TmaTilingLayout;
54 type SyncStrategy = AsyncTma;
55 type Stage = StridedStageFamily;
56
57 type Job<EG: Numeric, ES: Numeric> = AsyncPartialTmaJob;
58
59 fn new_job<EG: Numeric, ES: Numeric>(
60 #[comptime] stage_index: u32,
61 #[comptime] _line_size: u32,
62 #[comptime] config: GlobalReaderConfig,
63 ) -> Self::Job<EG, ES> {
64 let role_rule_config = config.plane_role_config.rule;
65 let config = config.smem_config;
66 let tile_count_col = match config.matrix_layout {
67 MatrixLayout::RowMajor => config.tiles_per_stage_along_col(),
68 MatrixLayout::ColMajor => config.tiles_per_stage_along_row(),
69 };
70 let num_tasks = comptime![match config.swizzle {
73 SwizzleMode::None => tile_count_col,
74 _ => 1u32,
75 }];
76
77 let is_elected = RoleRule::new(role_rule_config).elect_load_leader();
78
79 AsyncPartialTmaJob {
80 is_elected,
81 num_tasks,
82 stage_index,
83 }
84 }
85}
86
87#[derive(CubeType, Clone, Copy)]
88pub struct AsyncPartialTmaJob {
89 is_elected: bool,
90
91 #[cube(comptime)]
92 num_tasks: u32,
93 #[cube(comptime)]
94 stage_index: u32,
95}
96
97#[cube]
98impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, TmaTilingLayout, AsyncTma>
99 for AsyncPartialTmaJob
100{
101 type Stage = StridedStageFamily;
102
103 fn execute_task(
104 this: &mut Self,
105 #[comptime] task_id: u32,
106 global_iter: &GlobalIterator<Line<EG>>,
107 stage: &mut StridedStageMemory<ES, TmaTilingLayout>,
108 barrier: &mut Barrier,
109 #[comptime] config: GlobalReaderConfig,
110 ) {
111 let mut stage = stage.with_buffer_index(this.stage_index);
112 if this.is_elected {
113 let size_row = match config.smem_config.matrix_layout {
114 MatrixLayout::RowMajor => config.smem_config.elements_per_stage_along_row(),
115 MatrixLayout::ColMajor => config.smem_config.elements_per_stage_along_col(),
116 };
117 let size_col = match config.smem_config.matrix_layout {
118 MatrixLayout::RowMajor => config.smem_config.elements_per_tile_along_col,
119 MatrixLayout::ColMajor => config.smem_config.elements_per_tile_along_row,
120 };
121
122 let (offs_row, offs_col) = comptime![match config.stage_ident {
123 StageIdent::Lhs => (
124 0,
125 this.stage_index * config.smem_config.elements_per_stage_along_col()
126 ),
127 StageIdent::Rhs => (
128 this.stage_index * config.smem_config.elements_per_stage_along_row(),
129 0
130 ),
131 _ => (0, 0),
132 }]
133 .runtime();
134
135 let global_view = global_iter.view();
136 let mut stage = stage.as_slice_mut(1u32);
137 let slice_size = size_row * size_col;
138
139 let slice_start = task_id * slice_size;
140 let slice = stage.slice_mut(slice_start, slice_start + slice_size);
141 let load_col = task_id * size_col;
143
144 let pos = match config.smem_config.matrix_layout {
145 MatrixLayout::RowMajor => (offs_row, load_col + offs_col),
146 MatrixLayout::ColMajor => (load_col + offs_row, offs_col),
147 };
148
149 global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), pos);
150 }
151 }
152
153 fn task_count(this: &Self) -> comptime_type!(u32) {
154 this.num_tasks
155 }
156}