cubecl_convolution/components/global/read/reader/weight_tma.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
3use cubecl_matmul::components::{
4    MatrixPrecision,
5    global::memory::{GlobalIterator, ViewDirection},
6    stage::StageMemoryConfig,
7};
8use cubecl_std::tensor::{View, layout::Coords2d};
9
10use cubecl_matmul::components::stage::RowMajorTilingOrder;
11use cubecl_matmul::components::stage::{ContiguousTilingLayout, StridedStageMemory};
12
/// Tiling layout for TMA-loaded weight stages: tiles stored contiguously
/// in row-major tile order.
pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
/// Strided stage memory holding a weight stage at precision `IP`'s stage
/// element type, laid out with [`TmaWeightTiling`].
pub type TmaWeightStage<IP> = StridedStageMemory<<IP as MatrixPrecision>::Stage, TmaWeightTiling>;
15
/// Global-memory reader that fills stage memory with convolution weights
/// via TMA (`tensor_map_load`) bulk transfers, double/multi-buffered across
/// `stages`.
#[derive(CubeType)]
pub struct TmaWeightGlobalReader<IP: MatrixPrecision> {
    /// Row-direction iterator over the global weight view; advanced by one
    /// `k_step` per [`advance_view`] call.
    pub global_iter: GlobalIterator<Line<IP::Global>>,
    /// One stage-memory buffer per pipeline stage (multi-buffering).
    pub stages: Sequence<StridedStageMemory<IP::Stage, TmaWeightTiling>>,
    /// Compile-time stage memory configuration (tile/stage extents).
    #[cube(comptime)]
    config: StageMemoryConfig,
}
23
#[cube]
impl<IP: MatrixPrecision> TmaWeightGlobalReader<IP> {
    /// Creates a reader over `global_view`, allocating `num_stages`
    /// 128-byte-aligned stage buffers up front.
    ///
    /// `k_step` is the row advance applied to the global iterator on each
    /// [`advance_view`] call; `num_stages` and `config` are comptime.
    pub fn new(
        global_view: View<Line<IP::Global>, Coords2d>,
        k_step: u32,
        #[comptime] num_stages: u32,
        #[comptime] config: StageMemoryConfig,
    ) -> Self {
        let mut stages = Sequence::new();

        // Allocate every stage eagerly; 128-byte alignment is the TMA
        // requirement for bulk tensor copies.
        #[unroll]
        for _ in 0..num_stages {
            stages.push(StridedStageMemory::new_aligned(128u32, config));
        }

        // Iterate the weight view row-wise in steps of `k_step`.
        // NOTE(review): the trailing `false` presumably disables a
        // checked/looping mode — confirm against `GlobalIterator::new`.
        let global_iter = GlobalIterator::new(global_view, k_step, ViewDirection::Row, false);

        TmaWeightGlobalReader::<IP> {
            global_iter,
            stages,
            config,
        }
    }

    /// Fills stage `stage_idx` from the current position of the global
    /// iterator using TMA loads completed through `barrier`.
    ///
    /// Only unit 0 issues the copies: `tensor_map_load` is a bulk async
    /// transfer, so a single unit enqueues one load per tile along the
    /// k (row) dimension rather than all units copying cooperatively.
    pub fn fill_stage(&mut self, barrier: &Barrier, #[comptime] stage_idx: u32) {
        let stage = self.stages.index_mut(stage_idx);
        let config = comptime![self.config];

        if UNIT_POS == 0 {
            let global_view = self.global_iter.view();

            // Whole stage viewed as one flat mutable slice (line size 1).
            let mut stage = stage.as_slice_mut(1u32);
            // Elements written per tile-row: the stage's full column extent
            // times a single tile's row extent.
            let slice_size =
                config.elements_per_stage_along_col() * config.elements_per_tile_along_row;

            #[unroll]
            for tile_k in 0..config.tiles_per_stage_along_row() {
                // Destination window for this tile-row within the stage.
                let slice_start = slice_size * tile_k;
                let slice = stage.slice_mut(slice_start, slice_size);

                // Global row offset (k coordinate) of this tile-row.
                let k = tile_k * config.elements_per_tile_along_row;
                global_view.tensor_map_load(barrier, &mut slice.try_cast_unchecked(), (k, 0));
            }
        }
    }

    /// Returns the stage handle at comptime index `stage_idx` (copied out
    /// of the sequence) for consumption by the stage-level matmul.
    pub fn stage(&self, #[comptime] stage_idx: u32) -> TmaWeightStage<IP> {
        *self.stages.index(stage_idx)
    }

    /// Advances the global iterator by one `k_step` along the row direction.
    pub fn advance_view(&mut self) {
        self.global_iter.advance();
    }
}
77}