cubecl_linalg/matmul/components/global/load/loader/
tma_loader.rs1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_std::CubeOption;
6
7use crate::matmul::components::stage::{FullReader, RowMajorTilingOrder};
8use crate::matmul::components::{
9 Ident, InputIdent, MatmulPrecision, MatrixLayout,
10 global::{Quantization, single_stage},
11};
12use crate::matmul::components::{
13 global::{self, GlobalConfig, tensor_view::MappedTensorReader},
14 stage::{self, ColMajorTilingOrder, ContiguousTilingLayout, Stage, StageConfig, TilingOrder},
15};
16
17pub type TmaTiling = ContiguousTilingLayout<TmaTilingOrder>;
18pub type TmaReader<MP> = FullReader<<MP as MatmulPrecision>::ES, TmaTiling>;
19
20#[derive(CubeType, Clone, Copy)]
21pub struct TmaTilingOrder;
22
23#[cube]
24impl TilingOrder for TmaTilingOrder {
25 fn to_row_col<C: StageConfig>(
26 nth: u32,
27 #[comptime] tile_count_rows: u32,
28 #[comptime] tile_count_cols: u32,
29 #[comptime] ident: Ident,
30 #[comptime] config: C,
31 ) -> (u32, u32) {
32 match config.matrix_layout(ident) {
33 MatrixLayout::RowMajor => ColMajorTilingOrder::to_row_col::<C>(
34 nth,
35 tile_count_rows,
36 tile_count_cols,
37 ident,
38 config,
39 ),
40 MatrixLayout::ColMajor => RowMajorTilingOrder::to_row_col::<C>(
41 nth,
42 tile_count_rows,
43 tile_count_cols,
44 ident,
45 config,
46 ),
47 }
48 }
49 fn to_nth_tile<C: StageConfig>(
50 row: u32,
51 col: u32,
52 #[comptime] tile_count_rows: u32,
53 #[comptime] tile_count_cols: u32,
54 #[comptime] ident: Ident,
55 #[comptime] config: C,
56 ) -> u32 {
57 match config.matrix_layout(ident) {
58 MatrixLayout::RowMajor => ColMajorTilingOrder::to_nth_tile::<C>(
59 row,
60 col,
61 tile_count_rows,
62 tile_count_cols,
63 ident,
64 config,
65 ),
66 MatrixLayout::ColMajor => RowMajorTilingOrder::to_nth_tile::<C>(
67 row,
68 col,
69 tile_count_rows,
70 tile_count_cols,
71 ident,
72 config,
73 ),
74 }
75 }
76}
77
78#[derive(CubeType)]
79pub struct TmaLoader<MP: MatmulPrecision, S: stage::StageConfig> {
80 pub tensor_view: MappedTensorReader<MP::EI>,
81 pub stage: Stage<MP::ES, TmaTiling>,
82 #[cube(comptime)]
83 ident: InputIdent,
84 #[cube(comptime)]
85 _config: PhantomData<S>,
86}
87
88#[cube]
89impl<MP: MatmulPrecision, S: stage::StageConfig> TmaLoader<MP, S> {
90 pub fn new<G: global::GlobalConfig>(
91 tensor: TensorMap<MP::EI>,
92 x: u32,
93 y: u32,
94 batch: u32,
95 quantization: CubeOption<Quantization<MP>>,
96 #[comptime] ident: InputIdent,
97 #[comptime] config: G,
98 ) -> Self {
99 comptime! {
100 if quantization.is_some() {
101 todo!();
102 }
103 }
104
105 let stage = Stage::new_aligned::<G::SmmConfig>(
106 comptime!(ident.as_ident()),
107 128u32,
108 config.to_smm_config(),
109 );
110
111 let tensor_view = MappedTensorReader::new(tensor, x, y, batch);
112
113 TmaLoader::<MP, S> {
114 tensor_view,
115 stage,
116 ident,
117 _config: PhantomData::<S>,
118 }
119 }
120
121 pub fn fill_stage(
122 this: &mut Self,
123 barrier: &Barrier<MP::ES>,
124 #[comptime] config: single_stage::Config<S>,
125 ) {
126 if UNIT_POS == 0 {
127 let ident = comptime!(this.ident.as_ident());
128 let (row, col) = match config.matrix_layout(ident) {
130 MatrixLayout::RowMajor => (this.tensor_view.tile_x, this.tensor_view.tile_y),
131 MatrixLayout::ColMajor => (this.tensor_view.tile_y, this.tensor_view.tile_x),
132 };
133
134 let tiling_dims = config.tiling_dimensions(ident);
135 let size_row = match config.matrix_layout(ident) {
136 MatrixLayout::RowMajor => tiling_dims.total_row(),
137 MatrixLayout::ColMajor => tiling_dims.total_col(),
138 };
139 let size_col = match config.matrix_layout(ident) {
140 MatrixLayout::RowMajor => tiling_dims.tile_shape_col(),
141 MatrixLayout::ColMajor => tiling_dims.tile_shape_row(),
142 };
143 let tile_count_col = match config.matrix_layout(ident) {
144 MatrixLayout::RowMajor => tiling_dims.tile_count_col(),
145 MatrixLayout::ColMajor => tiling_dims.tile_count_row(),
146 };
147
148 let tensor = this.tensor_view.tensor.try_cast_unchecked();
149 let mut stage = this.stage.as_slice_mut(1u32);
150 let slice_size = size_row * size_col;
151 let batch = this.tensor_view.batch as i32;
152
153 #[unroll]
154 for tile_col in 0..tile_count_col {
155 let slice_start = tile_col * slice_size;
156 let mut slice = stage.slice_mut(slice_start, slice_start + slice_size);
157 let col = col + tile_col * size_col;
158
159 barrier.tma_load_3d(&tensor, &mut slice, batch, row as i32, col as i32);
160 }
161 }
162 }
163
164 pub fn reader(this: &Self) -> TmaReader<MP> {
165 TmaReader::<MP>::new(this.stage, this.ident)
166 }
167
168 pub fn advance_view(this: &mut Self, k_offset: u32) {
169 this.tensor_view
170 .update_view(k_offset, comptime!(this.ident.as_ident()));
171 }
172}