cubecl_matmul/components/global/read/reader/
tma_reader.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5use crate::components::stage::{
6    ColMajorTilingOrder, ContiguousTilingLayout, StridedStage, TilingOrder,
7};
8use crate::components::stage::{RowMajorTilingOrder, StageMemoryConfig, TilingOrderEnum};
9use crate::components::{MatmulIdent, MatrixPrecision};
10use crate::components::{MatrixLayout, global::memory::GlobalIterator};
11
12pub type TmaTiling = ContiguousTilingLayout<TmaTilingOrder>;
14pub type TmaStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaTiling>;
16
17#[derive(CubeType, Clone, Copy)]
18pub struct TmaTilingOrder;
22
23#[cube]
24impl TilingOrder for TmaTilingOrder {
25    fn to_row_col(
26        nth: u32,
27        tile_count_rows: u32,
28        tile_count_cols: u32,
29        #[comptime] config: StageMemoryConfig,
30    ) -> Coords2d {
31        match config.matrix_layout {
32            MatrixLayout::RowMajor => {
33                ColMajorTilingOrder::to_row_col(nth, tile_count_rows, tile_count_cols, config)
34            }
35            MatrixLayout::ColMajor => {
36                RowMajorTilingOrder::to_row_col(nth, tile_count_rows, tile_count_cols, config)
37            }
38        }
39    }
40
41    fn to_nth_tile(
42        tile: Coords2d,
43        tile_count_rows: u32,
44        tile_count_cols: u32,
45        #[comptime] config: StageMemoryConfig,
46    ) -> u32 {
47        match config.matrix_layout {
48            MatrixLayout::RowMajor => {
49                ColMajorTilingOrder::to_nth_tile(tile, tile_count_rows, tile_count_cols, config)
50            }
51            MatrixLayout::ColMajor => {
52                RowMajorTilingOrder::to_nth_tile(tile, tile_count_rows, tile_count_cols, config)
53            }
54        }
55    }
56
57    fn to_enum() -> comptime_type!(TilingOrderEnum) {
58        TilingOrderEnum::Tma
59    }
60}
61
62#[derive(CubeType)]
63pub struct TmaGlobalReader<IP: MatrixPrecision> {
65    global_iter: GlobalIterator<Line<IP::Global>>,
66    stage: StridedStage<IP::Stage, TmaTiling>,
67    #[cube(comptime)]
68    config: StageMemoryConfig,
69}
70
71#[cube]
72impl<IP: MatrixPrecision> TmaGlobalReader<IP> {
73    pub fn new(
75        global_view: View<Line<IP::Global>, Coords2d>,
76        k_step: u32,
77        #[comptime] ident: MatmulIdent,
78        #[comptime] config: StageMemoryConfig,
79    ) -> Self {
80        let global_iter = GlobalIterator::new(global_view, k_step, ident.view_direction(), false);
81        let stage = StridedStage::new_aligned(comptime!(ident.into_stage()), 128u32, config);
82
83        TmaGlobalReader::<IP> {
84            global_iter,
85            stage,
86            config,
87        }
88    }
89
90    pub fn load_stage(&mut self, barrier: &Barrier) {
94        if UNIT_POS == 0 {
95            let config = comptime![self.config];
96
97            let size_row = match config.matrix_layout {
98                MatrixLayout::RowMajor => config.elements_in_stage_row(),
99                MatrixLayout::ColMajor => config.elements_in_stage_col(),
100            };
101            let size_col = match config.matrix_layout {
102                MatrixLayout::RowMajor => config.elements_in_tile_col,
103                MatrixLayout::ColMajor => config.elements_in_tile_row,
104            };
105            let tile_count_col = match config.matrix_layout {
106                MatrixLayout::RowMajor => config.tiles_in_stage_col,
107                MatrixLayout::ColMajor => config.tiles_in_stage_row,
108            };
109
110            let global_view = self.global_iter.view();
111            let mut stage = self.stage.as_slice_mut(1u32);
112            let slice_size = size_row * size_col;
113
114            #[unroll]
115            for tile_col in 0..tile_count_col {
116                let slice_start = tile_col * slice_size;
117                let slice = stage.slice_mut(slice_start, slice_start + slice_size);
118                let col = tile_col * size_col;
119
120                let pos = match config.matrix_layout {
121                    MatrixLayout::RowMajor => (0, col),
122                    MatrixLayout::ColMajor => (col, 0),
123                };
124
125                global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), pos);
126            }
127        }
128    }
129
130    pub fn stage(&self) -> TmaStage<IP> {
132        self.stage
133    }
134
135    pub fn advance_view(&mut self) {
137        self.global_iter.advance();
138    }
139}
140
141#[cube]
142pub fn arrive_tma(barrier: &Barrier, #[comptime] num_bytes: u32) {
144    if UNIT_POS == 0 {
145        barrier.arrive_tx(1, num_bytes);
146    } else {
147        barrier.arrive();
148    }
149}