cubecl_matmul/components/global/read/strategy/
async_full_tma.rs

1use crate::components::InvalidConfigError;
2use crate::components::MatmulElems;
3use crate::components::global::GlobalReaderConfig;
4use crate::components::global::read::{validate_async_barrier, validate_tma};
5use crate::components::global::{RoleRule, read::async_tma::AsyncTma};
6use crate::components::stage::StridedStageFamily;
7use crate::components::stage::{StridedStageMemory, SwizzleMode};
8use crate::components::{MatrixLayout, global::read::FullLoadingStrategy};
9use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
10use crate::components::{global::multi_stage::LoadMaxRoundPlaneCount, stage::TmaTilingLayout};
11use cubecl_core::prelude::*;
12use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
13
14use super::{LoadingJob, LoadingValidation};
15
16#[derive(CubeType, Clone, Copy)]
17/// Loads the content of all tiles in the stage using TMA load instructions.
18/// Uses special tiling to minimize the number of loads required. Issues one load for each
19/// tile in the major dimension (i.e. `k` for col-major RHS).
20pub struct AsyncFullTmaLoading {}
21
22impl LoadingValidation for AsyncFullTmaLoading {
23    fn check<R: Runtime>(
24        client: &ComputeClient<R::Server>,
25        config: &GlobalReaderConfig,
26        dtypes: &MatmulElems,
27    ) -> Result<(), InvalidConfigError> {
28        TmaTilingLayout::check(config.smem_config)?;
29        validate_tma::<R>(client, config.smem_config, config.stage_ident, dtypes)?;
30        validate_async_barrier::<R>(client)?;
31
32        Ok(())
33    }
34}
35
36impl LoadMaxRoundPlaneCount for AsyncFullTmaLoading {
37    fn max_round_plane_count(
38        _elements_per_tile: u32,
39        _tiles_per_stage: u32,
40        _line_size: u8,
41        _plane_dim: u32,
42    ) -> u32 {
43        // Not sure this is the best value, but TMA is executed per-warpgroup so this is the maximum
44        // number of planes executing one set of TMA loads.
45        4
46    }
47}
48
49#[cube]
50impl FullLoadingStrategy for AsyncFullTmaLoading {
51    type TilingLayout = TmaTilingLayout;
52    type SyncStrategy = AsyncTma;
53    type Job<EG: Numeric, ES: Numeric> = AsyncFullTmaJob;
54
55    fn new_job<EG: Numeric, ES: Numeric>(
56        #[comptime] _line_size: u32,
57        #[comptime] config: GlobalReaderConfig,
58    ) -> Self::Job<EG, ES> {
59        let role_rule_config = config.plane_role_config.rule;
60        let config = config.smem_config;
61        let tile_count_col = match config.matrix_layout {
62            MatrixLayout::RowMajor => config.tiles_per_stage_along_col(),
63            MatrixLayout::ColMajor => config.tiles_per_stage_along_row(),
64        };
65        // Swizzle renders the column format irrelevant, so we load the whole stage at once
66        // The tiling is set on launch for TMA, so no further change is needed here.
67        let num_tasks = comptime![match config.swizzle {
68            SwizzleMode::None => tile_count_col,
69            _ => 1u32,
70        }];
71
72        let is_elected = RoleRule::new(role_rule_config).elect_load_leader();
73
74        AsyncFullTmaJob {
75            is_elected,
76            num_tasks,
77        }
78    }
79}
80
81#[derive(CubeType, Clone, Copy)]
82pub struct AsyncFullTmaJob {
83    is_elected: bool,
84
85    #[cube(comptime)]
86    num_tasks: u32,
87}
88
89#[cube]
90impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, TmaTilingLayout, AsyncTma> for AsyncFullTmaJob {
91    type Stage = StridedStageFamily;
92
93    fn execute_task(
94        this: &mut Self,
95        #[comptime] task_id: u32,
96        global_iter: &GlobalIterator<Line<EG>>,
97        stage: &mut StridedStageMemory<ES, TmaTilingLayout>,
98        barrier: &mut Barrier,
99        #[comptime] config: GlobalReaderConfig,
100    ) {
101        if this.is_elected {
102            let config = comptime![config.smem_config];
103
104            let size_row = match config.matrix_layout {
105                MatrixLayout::RowMajor => config.elements_per_stage_along_row(),
106                MatrixLayout::ColMajor => config.elements_per_stage_along_col(),
107            };
108            let size_col = match config.matrix_layout {
109                MatrixLayout::RowMajor => config.elements_per_tile_along_col,
110                MatrixLayout::ColMajor => config.elements_per_tile_along_row,
111            };
112
113            let global_view = global_iter.view();
114            let mut stage = stage.as_slice_mut(1u32);
115            let slice_size = size_row * size_col;
116
117            let slice_start = task_id * slice_size;
118            let slice = stage.slice_mut(slice_start, slice_start + slice_size);
119            let col = task_id * size_col;
120
121            let pos = match config.matrix_layout {
122                MatrixLayout::RowMajor => (0, col),
123                MatrixLayout::ColMajor => (col, 0),
124            };
125
126            global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), pos);
127        }
128    }
129
130    fn task_count(this: &Self) -> comptime_type!(u32) {
131        this.num_tasks
132    }
133}