// cubecl_matmul/components/global/read/reader/tma_reader.rs
use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
use cubecl_std::tensor::{View, layout::Coords2d};

use crate::components::stage::{
    ColMajorTilingOrder, ContiguousTilingLayout, StridedStage, TilingOrder,
};
use crate::components::stage::{RowMajorTilingOrder, StageMemoryConfig, TilingOrderEnum};
use crate::components::{MatmulIdent, MatrixPrecision};
use crate::components::{MatrixLayout, global::memory::GlobalIterator};

12pub type TmaTiling = ContiguousTilingLayout<TmaTilingOrder>;
14pub type TmaStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaTiling>;
16
17#[derive(CubeType, Clone, Copy)]
18pub struct TmaTilingOrder;
22
23#[cube]
24impl TilingOrder for TmaTilingOrder {
25 fn to_row_col(
26 nth: u32,
27 tile_count_rows: u32,
28 tile_count_cols: u32,
29 #[comptime] config: StageMemoryConfig,
30 ) -> Coords2d {
31 match config.matrix_layout {
32 MatrixLayout::RowMajor => {
33 ColMajorTilingOrder::to_row_col(nth, tile_count_rows, tile_count_cols, config)
34 }
35 MatrixLayout::ColMajor => {
36 RowMajorTilingOrder::to_row_col(nth, tile_count_rows, tile_count_cols, config)
37 }
38 }
39 }
40
41 fn to_nth_tile(
42 tile: Coords2d,
43 tile_count_rows: u32,
44 tile_count_cols: u32,
45 #[comptime] config: StageMemoryConfig,
46 ) -> u32 {
47 match config.matrix_layout {
48 MatrixLayout::RowMajor => {
49 ColMajorTilingOrder::to_nth_tile(tile, tile_count_rows, tile_count_cols, config)
50 }
51 MatrixLayout::ColMajor => {
52 RowMajorTilingOrder::to_nth_tile(tile, tile_count_rows, tile_count_cols, config)
53 }
54 }
55 }
56
57 fn to_enum() -> comptime_type!(TilingOrderEnum) {
58 TilingOrderEnum::Tma
59 }
60}
61
62#[derive(CubeType)]
63pub struct TmaGlobalReader<IP: MatrixPrecision> {
65 global_iter: GlobalIterator<Line<IP::Global>>,
66 stage: StridedStage<IP::Stage, TmaTiling>,
67 #[cube(comptime)]
68 config: StageMemoryConfig,
69}
70
71#[cube]
72impl<IP: MatrixPrecision> TmaGlobalReader<IP> {
73 pub fn new(
75 global_view: View<Line<IP::Global>, Coords2d>,
76 k_step: u32,
77 #[comptime] ident: MatmulIdent,
78 #[comptime] config: StageMemoryConfig,
79 ) -> Self {
80 let global_iter = GlobalIterator::new(global_view, k_step, ident.view_direction(), false);
81 let stage = StridedStage::new_aligned(comptime!(ident.into_stage()), 128u32, config);
82
83 TmaGlobalReader::<IP> {
84 global_iter,
85 stage,
86 config,
87 }
88 }
89
90 pub fn load_stage(&mut self, barrier: &Barrier) {
94 if UNIT_POS == 0 {
95 let config = comptime![self.config];
96
97 let size_row = match config.matrix_layout {
98 MatrixLayout::RowMajor => config.elements_in_stage_row(),
99 MatrixLayout::ColMajor => config.elements_in_stage_col(),
100 };
101 let size_col = match config.matrix_layout {
102 MatrixLayout::RowMajor => config.elements_in_tile_col,
103 MatrixLayout::ColMajor => config.elements_in_tile_row,
104 };
105 let tile_count_col = match config.matrix_layout {
106 MatrixLayout::RowMajor => config.tiles_in_stage_col,
107 MatrixLayout::ColMajor => config.tiles_in_stage_row,
108 };
109
110 let global_view = self.global_iter.view();
111 let mut stage = self.stage.as_slice_mut(1u32);
112 let slice_size = size_row * size_col;
113
114 #[unroll]
115 for tile_col in 0..tile_count_col {
116 let slice_start = tile_col * slice_size;
117 let slice = stage.slice_mut(slice_start, slice_start + slice_size);
118 let col = tile_col * size_col;
119
120 let pos = match config.matrix_layout {
121 MatrixLayout::RowMajor => (0, col),
122 MatrixLayout::ColMajor => (col, 0),
123 };
124
125 global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), pos);
126 }
127 }
128 }
129
130 pub fn stage(&self) -> TmaStage<IP> {
132 self.stage
133 }
134
135 pub fn advance_view(&mut self) {
137 self.global_iter.advance();
138 }
139}
140
141#[cube]
142pub fn arrive_tma(barrier: &Barrier, #[comptime] num_bytes: u32) {
144 if UNIT_POS == 0 {
145 barrier.arrive_tx(1, num_bytes);
146 } else {
147 barrier.arrive();
148 }
149}