// cubecl_matmul/components/global/read/strategy/async_partial_tma.rs

1use crate::components::MatmulElems;
2use crate::components::global::GlobalReaderConfig;
3use crate::components::global::read::{validate_async_barrier, validate_tma};
4use crate::components::global::{RoleRule, multi_stage::LoadMaxRoundPlaneCount};
5use crate::components::stage::StridedStageFamily;
6use crate::components::stage::TmaTilingLayout;
7use crate::components::stage::{StridedStageMemory, SwizzleMode};
8use crate::components::{InvalidConfigError, StageIdent};
9use crate::components::{
10    MatrixLayout,
11    global::read::{PartialLoadingStrategy, async_tma::AsyncTma},
12};
13use crate::components::{global::memory::GlobalIterator, stage::TilingValidation};
14use cubecl_core::prelude::*;
15use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
16
17use super::{LoadingJob, LoadingValidation};
18
19#[derive(CubeType, Clone, Copy)]
20/// Loads the content of all tiles in the stage using TMA load instructions.
21/// Uses special tiling to minimize the number of loads required. Issues one load for each
22/// tile in the major dimension (i.e. `k` for col-major RHS).
23pub struct AsyncPartialTmaLoading {}
24
25impl LoadingValidation for AsyncPartialTmaLoading {
26    fn check<R: Runtime>(
27        client: &ComputeClient<R::Server>,
28        config: &GlobalReaderConfig,
29        dtypes: &MatmulElems,
30    ) -> Result<(), InvalidConfigError> {
31        TmaTilingLayout::check(config.smem_config)?;
32        validate_tma::<R>(client, config.smem_config, config.stage_ident, dtypes)?;
33
34        validate_async_barrier::<R>(client)?;
35
36        Ok(())
37    }
38}
39
40impl LoadMaxRoundPlaneCount for AsyncPartialTmaLoading {
41    fn max_round_plane_count(
42        _elements_per_tile: u32,
43        _tiles_per_stage: u32,
44        _line_size: u8,
45        _plane_dim: u32,
46    ) -> u32 {
47        4
48    }
49}
50
#[cube]
impl PartialLoadingStrategy for AsyncPartialTmaLoading {
    type TilingLayout = TmaTilingLayout;
    type SyncStrategy = AsyncTma;
    type Stage = StridedStageFamily;

    // The job type is the same regardless of the global (`EG`) and stage (`ES`) element types.
    type Job<EG: Numeric, ES: Numeric> = AsyncPartialTmaJob;

    /// Builds the TMA loading job for partial stage buffer `stage_index`.
    ///
    /// The job has one task per tile along the contiguous dimension of the stage (columns for
    /// row-major, rows for col-major) when no swizzle is used. Under any swizzle mode it has a
    /// single task. Whether the current unit issues the loads is decided here through
    /// `RoleRule::elect_load_leader`. `_line_size` is unused.
    fn new_job<EG: Numeric, ES: Numeric>(
        #[comptime] stage_index: u32,
        #[comptime] _line_size: u32,
        #[comptime] config: GlobalReaderConfig,
    ) -> Self::Job<EG, ES> {
        // Read the role rule before `config` is shadowed by its shared-memory part below.
        let role_rule_config = config.plane_role_config.rule;
        let config = config.smem_config;
        // Number of tiles along the contiguous axis; one TMA load is issued per such tile.
        let tile_count_col = match config.matrix_layout {
            MatrixLayout::RowMajor => config.tiles_per_stage_along_col(),
            MatrixLayout::ColMajor => config.tiles_per_stage_along_row(),
        };
        // With swizzling the column format no longer matters, so the whole stage is loaded at
        // once. The TMA tiling is configured at launch, so nothing else needs to change here.
        let num_tasks = comptime![match config.swizzle {
            SwizzleMode::None => tile_count_col,
            _ => 1u32,
        }];

        // Only units elected here will issue the TMA loads in `execute_task`.
        let is_elected = RoleRule::new(role_rule_config).elect_load_leader();

        AsyncPartialTmaJob {
            is_elected,
            num_tasks,
            stage_index,
        }
    }
}
86
87#[derive(CubeType, Clone, Copy)]
88pub struct AsyncPartialTmaJob {
89    is_elected: bool,
90
91    #[cube(comptime)]
92    num_tasks: u32,
93    #[cube(comptime)]
94    stage_index: u32,
95}
96
#[cube]
impl<EG: Numeric, ES: Numeric> LoadingJob<EG, ES, TmaTilingLayout, AsyncTma>
    for AsyncPartialTmaJob
{
    type Stage = StridedStageFamily;

    /// Issues the TMA load for task `task_id` into stage buffer `this.stage_index`.
    ///
    /// Only units with `this.is_elected` set issue anything. All other units return without
    /// loading. Each task covers the whole stage extent along the non-contiguous axis and the
    /// `task_id`-th tile along the contiguous axis. The load is issued against `barrier`, and
    /// this function does not wait on it.
    fn execute_task(
        this: &mut Self,
        #[comptime] task_id: u32,
        global_iter: &GlobalIterator<Line<EG>>,
        stage: &mut StridedStageMemory<ES, TmaTilingLayout>,
        barrier: &mut Barrier,
        #[comptime] config: GlobalReaderConfig,
    ) {
        // Select the partial stage buffer this job targets.
        let mut stage = stage.with_buffer_index(this.stage_index);
        if this.is_elected {
            // Extent of one load along the non-contiguous axis: the full stage
            // (logical rows for row-major, logical columns for col-major).
            let size_row = match config.smem_config.matrix_layout {
                MatrixLayout::RowMajor => config.smem_config.elements_per_stage_along_row(),
                MatrixLayout::ColMajor => config.smem_config.elements_per_stage_along_col(),
            };
            // Extent of one load along the contiguous axis: a single tile.
            let size_col = match config.smem_config.matrix_layout {
                MatrixLayout::RowMajor => config.smem_config.elements_per_tile_along_col,
                MatrixLayout::ColMajor => config.smem_config.elements_per_tile_along_row,
            };

            // Global offset of this stage buffer. Lhs buffers are shifted by whole stages along
            // columns and Rhs buffers along rows (presumably the `k` axis in both cases; confirm).
            // Any other ident gets no offset. The comptime pair is then turned into runtime values.
            let (offs_row, offs_col) = comptime![match config.stage_ident {
                StageIdent::Lhs => (
                    0,
                    this.stage_index * config.smem_config.elements_per_stage_along_col()
                ),
                StageIdent::Rhs => (
                    this.stage_index * config.smem_config.elements_per_stage_along_row(),
                    0
                ),
                _ => (0, 0),
            }]
            .runtime();

            let global_view = global_iter.view();
            // Flat view over the selected stage buffer. The argument is presumably the line
            // size (1); confirm against `as_slice_mut`.
            let mut stage = stage.as_slice_mut(1u32);
            // Number of elements covered by one task.
            let slice_size = size_row * size_col;

            // Tasks write consecutive, non-overlapping regions of the stage.
            // NOTE(review): under swizzle `num_tasks` is 1, but `slice_size` still spans only one
            // tile along the contiguous axis. The real load extent appears to come from the
            // tensor map configured at launch; confirm the slice bounds are not relied upon.
            let slice_start = task_id * slice_size;
            let slice = stage.slice_mut(slice_start, slice_start + slice_size);
            // Offset of this task along the contiguous axis: a column offset for row-major,
            // a row offset for col-major.
            let load_col = task_id * size_col;

            // Coordinate of the load's origin in the global view, in (row, col) order.
            let pos = match config.smem_config.matrix_layout {
                MatrixLayout::RowMajor => (offs_row, load_col + offs_col),
                MatrixLayout::ColMajor => (load_col + offs_row, offs_col),
            };

            // The destination slice is reinterpreted without a check, presumably to match the
            // element type expected by the tensor-map load; confirm.
            global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), pos);
        }
    }

    /// Number of tasks this job issues, fixed at comptime when the job is built.
    fn task_count(this: &Self) -> comptime_type!(u32) {
        this.num_tasks
    }
}