cubecl_convolution/components/global/read/reader/
weight_tma.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3use cubecl_matmul::components::{
4    MatrixPrecision, StageIdent,
5    global::memory::{GlobalIterator, ViewDirection},
6    stage::StageMemoryConfig,
7};
8use cubecl_std::tensor::{View, layout::Coords2d};
9
10use cubecl_matmul::components::stage::RowMajorTilingOrder;
11use cubecl_matmul::components::stage::{ContiguousTilingLayout, StridedStage};
12
13pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
14pub type TmaWeightStage<IP> = StridedStage<<IP as MatrixPrecision>::Stage, TmaWeightTiling>;
15
16#[derive(CubeType)]
17pub struct TmaWeightGlobalReader<IP: MatrixPrecision> {
18    pub global_iter: GlobalIterator<Line<IP::Global>>,
19    pub stages: Sequence<StridedStage<IP::Stage, TmaWeightTiling>>,
20    #[cube(comptime)]
21    config: StageMemoryConfig,
22}
23
24#[cube]
25impl<IP: MatrixPrecision> TmaWeightGlobalReader<IP> {
26    pub fn new(
27        global_view: View<Line<IP::Global>, Coords2d>,
28        k_step: u32,
29        #[comptime] num_stages: u32,
30        #[comptime] config: StageMemoryConfig,
31    ) -> Self {
32        let mut stages = Sequence::new();
33
34        #[unroll]
35        for _ in 0..num_stages {
36            stages.push(StridedStage::new_aligned(StageIdent::Rhs, 128u32, config));
37        }
38
39        let global_iter = GlobalIterator::new(global_view, k_step, ViewDirection::Row, false);
40
41        TmaWeightGlobalReader::<IP> {
42            global_iter,
43            stages,
44            config,
45        }
46    }
47
48    pub fn fill_stage(&mut self, barrier: &Barrier, #[comptime] stage_idx: u32) {
49        let stage = self.stages.index_mut(stage_idx);
50        let config = comptime![self.config];
51
52        if UNIT_POS == 0 {
53            let global_view = self.global_iter.view();
54
55            let mut stage = stage.as_slice_mut(1u32);
56            let slice_size = config.elements_in_stage_col() * config.elements_in_tile_row;
57
58            #[unroll]
59            for tile_k in 0..config.tiles_in_stage_row {
60                let slice_start = slice_size * tile_k;
61                let slice = stage.slice_mut(slice_start, slice_size);
62
63                let k = tile_k * config.elements_in_tile_row;
64                global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), (k, 0));
65            }
66        }
67    }
68
69    pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaWeightStage<IP> {
70        *self.stages.index(stage_idx)
71    }
72
73    pub fn advance_view(&mut self) {
74        self.global_iter.advance();
75    }
76}