cubecl_convolution/components/global/read/reader/
weight_tma.rs1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3use cubecl_matmul::components::{
4 MatrixPrecision, StageIdent,
5 global::memory::{GlobalIterator, ViewDirection},
6 stage::StageMemoryConfig,
7};
8use cubecl_std::tensor::{View, layout::Coords2d};
9
10use cubecl_matmul::components::stage::RowMajorTilingOrder;
11use cubecl_matmul::components::stage::{ContiguousTilingLayout, StridedStage};
12
13pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
14pub type TmaWeightStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaWeightTiling>;
15
16#[derive(CubeType)]
17pub struct TmaWeightGlobalReader<IP: MatrixPrecision> {
18 pub global_iter: GlobalIterator<Line<IP::Global>>,
19 pub stages: Sequence<StridedStage<IP::Stage, TmaWeightTiling>>,
20 #[cube(comptime)]
21 config: StageMemoryConfig,
22}
23
24#[cube]
25impl<IP: MatrixPrecision> TmaWeightGlobalReader<IP> {
26 pub fn new(
27 global_view: View<Line<IP::Global>, Coords2d>,
28 k_step: u32,
29 #[comptime] num_stages: u32,
30 #[comptime] config: StageMemoryConfig,
31 ) -> Self {
32 let mut stages = Sequence::new();
33
34 #[unroll]
35 for _ in 0..num_stages {
36 stages.push(StridedStage::new_aligned(StageIdent::Rhs, 128u32, config));
37 }
38
39 let global_iter = GlobalIterator::new(global_view, k_step, ViewDirection::Row, false);
40
41 TmaWeightGlobalReader::<IP> {
42 global_iter,
43 stages,
44 config,
45 }
46 }
47
48 pub fn fill_stage(&mut self, barrier: &Barrier, #[comptime] stage_idx: u32) {
49 let stage = self.stages.index_mut(stage_idx);
50 let config = comptime![self.config];
51
52 if UNIT_POS == 0 {
53 let global_view = self.global_iter.view();
54
55 let mut stage = stage.as_slice_mut(1u32);
56 let slice_size = config.elements_in_stage_col() * config.elements_in_tile_row;
57
58 #[unroll]
59 for tile_k in 0..config.tiles_in_stage_row {
60 let slice_start = slice_size * tile_k;
61 let slice = stage.slice_mut(slice_start, slice_size);
62
63 let k = tile_k * config.elements_in_tile_row;
64 global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), (k, 0));
65 }
66 }
67 }
68
69 pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaWeightStage<IP> {
70 *self.stages.index(stage_idx)
71 }
72
73 pub fn advance_view(&mut self) {
74 self.global_iter.advance();
75 }
76}