cubecl_linalg/matmul/components/global/load/loader/tma_loader.rs

1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_std::CubeOption;
6
7use crate::matmul::components::stage::{FullReader, RowMajorTilingOrder};
8use crate::matmul::components::{
9    Ident, InputIdent, MatmulPrecision, MatrixLayout,
10    global::{Quantization, single_stage},
11};
12use crate::matmul::components::{
13    global::{self, GlobalConfig, tensor_view::MappedTensorReader},
14    stage::{self, ColMajorTilingOrder, ContiguousTilingLayout, Stage, StageConfig, TilingOrder},
15};
16
17pub type TmaTiling = ContiguousTilingLayout<TmaTilingOrder>;
18pub type TmaReader<MP> = FullReader<<MP as MatmulPrecision>::ES, TmaTiling>;
19
20#[derive(CubeType, Clone, Copy)]
21pub struct TmaTilingOrder;
22
23#[cube]
24impl TilingOrder for TmaTilingOrder {
25    fn to_row_col<C: StageConfig>(
26        nth: u32,
27        #[comptime] tile_count_rows: u32,
28        #[comptime] tile_count_cols: u32,
29        #[comptime] ident: Ident,
30        #[comptime] config: C,
31    ) -> (u32, u32) {
32        match config.matrix_layout(ident) {
33            MatrixLayout::RowMajor => ColMajorTilingOrder::to_row_col::<C>(
34                nth,
35                tile_count_rows,
36                tile_count_cols,
37                ident,
38                config,
39            ),
40            MatrixLayout::ColMajor => RowMajorTilingOrder::to_row_col::<C>(
41                nth,
42                tile_count_rows,
43                tile_count_cols,
44                ident,
45                config,
46            ),
47        }
48    }
49    fn to_nth_tile<C: StageConfig>(
50        row: u32,
51        col: u32,
52        #[comptime] tile_count_rows: u32,
53        #[comptime] tile_count_cols: u32,
54        #[comptime] ident: Ident,
55        #[comptime] config: C,
56    ) -> u32 {
57        match config.matrix_layout(ident) {
58            MatrixLayout::RowMajor => ColMajorTilingOrder::to_nth_tile::<C>(
59                row,
60                col,
61                tile_count_rows,
62                tile_count_cols,
63                ident,
64                config,
65            ),
66            MatrixLayout::ColMajor => RowMajorTilingOrder::to_nth_tile::<C>(
67                row,
68                col,
69                tile_count_rows,
70                tile_count_cols,
71                ident,
72                config,
73            ),
74        }
75    }
76}
77
78#[derive(CubeType)]
79pub struct TmaLoader<MP: MatmulPrecision, S: stage::StageConfig> {
80    pub tensor_view: MappedTensorReader<MP::EI>,
81    pub stage: Stage<MP::ES, TmaTiling>,
82    #[cube(comptime)]
83    ident: InputIdent,
84    #[cube(comptime)]
85    _config: PhantomData<S>,
86}
87
88#[cube]
89impl<MP: MatmulPrecision, S: stage::StageConfig> TmaLoader<MP, S> {
90    pub fn new<G: global::GlobalConfig>(
91        tensor: TensorMap<MP::EI>,
92        x: u32,
93        y: u32,
94        batch: u32,
95        quantization: CubeOption<Quantization<MP>>,
96        #[comptime] ident: InputIdent,
97        #[comptime] config: G,
98    ) -> Self {
99        comptime! {
100            if quantization.is_some() {
101                todo!();
102            }
103        }
104
105        let stage = Stage::new_aligned::<G::SmmConfig>(
106            comptime!(ident.as_ident()),
107            128u32,
108            config.to_smm_config(),
109        );
110
111        let tensor_view = MappedTensorReader::new(tensor, x, y, batch);
112
113        TmaLoader::<MP, S> {
114            tensor_view,
115            stage,
116            ident,
117            _config: PhantomData::<S>,
118        }
119    }
120
121    pub fn fill_stage(
122        this: &mut Self,
123        barrier: &Barrier<MP::ES>,
124        #[comptime] config: single_stage::Config<S>,
125    ) {
126        if UNIT_POS == 0 {
127            let ident = comptime!(this.ident.as_ident());
128            // The tensor map is encoded as the transposed shape, so we need to swap coordinates
129            let (row, col) = match config.matrix_layout(ident) {
130                MatrixLayout::RowMajor => (this.tensor_view.tile_x, this.tensor_view.tile_y),
131                MatrixLayout::ColMajor => (this.tensor_view.tile_y, this.tensor_view.tile_x),
132            };
133
134            let tiling_dims = config.tiling_dimensions(ident);
135            let size_row = match config.matrix_layout(ident) {
136                MatrixLayout::RowMajor => tiling_dims.total_row(),
137                MatrixLayout::ColMajor => tiling_dims.total_col(),
138            };
139            let size_col = match config.matrix_layout(ident) {
140                MatrixLayout::RowMajor => tiling_dims.tile_shape_col(),
141                MatrixLayout::ColMajor => tiling_dims.tile_shape_row(),
142            };
143            let tile_count_col = match config.matrix_layout(ident) {
144                MatrixLayout::RowMajor => tiling_dims.tile_count_col(),
145                MatrixLayout::ColMajor => tiling_dims.tile_count_row(),
146            };
147
148            let tensor = this.tensor_view.tensor.try_cast_unchecked();
149            let mut stage = this.stage.as_slice_mut(1u32);
150            let slice_size = size_row * size_col;
151            let batch = this.tensor_view.batch as i32;
152
153            #[unroll]
154            for tile_col in 0..tile_count_col {
155                let slice_start = tile_col * slice_size;
156                let mut slice = stage.slice_mut(slice_start, slice_start + slice_size);
157                let col = col + tile_col * size_col;
158
159                barrier.tma_load_3d(&tensor, &mut slice, batch, row as i32, col as i32);
160            }
161        }
162    }
163
164    pub fn reader(this: &Self) -> TmaReader<MP> {
165        TmaReader::<MP>::new(this.stage, this.ident)
166    }
167
168    pub fn advance_view(this: &mut Self, k_offset: u32) {
169        this.tensor_view
170            .update_view(k_offset, comptime!(this.ident.as_ident()));
171    }
172}