cubecl_convolution/loader/
weight_tma.rs

1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_std::{CubeOption, FastDivmod};
6
7use crate::base::RuntimeArgs;
8use cubecl_matmul::components::stage::RowMajorTilingOrder;
9use cubecl_matmul::components::{
10    Ident, InputIdent, MatmulPrecision, global::Quantization, stage::FullStageToTileReader,
11};
12use cubecl_matmul::components::{
13    global::{self, global_memory::MappedTensorReader},
14    stage::{ContiguousTilingLayout, StageConfig, StageMemory},
15};
16
17pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
18pub type TmaWeightReader<MP> = FullStageToTileReader<<MP as MatmulPrecision>::ES, TmaWeightTiling>;
19
20#[derive(CubeType)]
21pub struct TmaWeightLoader<MP: MatmulPrecision, S: StageConfig> {
22    pub tensor_view: MappedTensorReader<MP::EI>,
23    pub stages: Sequence<StageMemory<MP::ES, TmaWeightTiling>>,
24    padded_channels: FastDivmod,
25    #[cube(comptime)]
26    _config: PhantomData<S>,
27}
28
29#[cube]
30impl<MP: MatmulPrecision, S: StageConfig> TmaWeightLoader<MP, S> {
31    pub fn new<G: global::GlobalConfig>(
32        tensor: TensorMap<MP::EI>,
33        x: u32,
34        y: u32,
35        quantization: CubeOption<Quantization<MP>>,
36        runtime_args: &RuntimeArgs,
37        #[comptime] num_stages: u32,
38        #[comptime] config: G,
39    ) -> Self {
40        comptime! {
41            if quantization.is_some() {
42                todo!();
43            }
44        }
45
46        let mut stages = Sequence::new();
47
48        #[unroll]
49        for _ in 0..num_stages {
50            stages.push(StageMemory::new_aligned::<G::StageConfig>(
51                Ident::Rhs,
52                128u32,
53                config.stage_config(),
54            ));
55        }
56
57        let tensor_view = MappedTensorReader::new(tensor, x, y, 0);
58
59        TmaWeightLoader::<MP, S> {
60            tensor_view,
61            stages,
62            padded_channels: runtime_args.padded_channels,
63            _config: PhantomData::<S>,
64        }
65    }
66
67    pub fn fill_stage(
68        this: &mut Self,
69        barrier: &Barrier<MP::ES>,
70        #[comptime] stage_idx: u32,
71        #[comptime] config: S,
72    ) {
73        let stage = this.stages.index_mut(stage_idx);
74
75        if UNIT_POS == 0 {
76            let k = this.tensor_view.tile_x;
77            let out_c = this.tensor_view.tile_y;
78
79            let tensor = this.tensor_view.tensor.try_cast_unchecked();
80            let mut stage = stage.as_slice_mut(1u32);
81            let slice_size = config.tiling_scheme().elements_in_stage_n()
82                * config.tiling_scheme().elements_in_tile_k();
83
84            #[unroll]
85            for tile_k in 0..config.tiling_scheme().tiles_in_stage_k() {
86                let slice_start = slice_size * tile_k;
87                let mut slice = stage.slice_mut(slice_start, slice_size);
88
89                let k = k + tile_k * config.tiling_scheme().elements_in_tile_k();
90                let (k_idx, in_c) = this.padded_channels.div_mod(k);
91
92                barrier.tma_load_3d(&tensor, &mut slice, out_c as i32, k_idx as i32, in_c as i32);
93            }
94        }
95    }
96
97    pub fn reader(this: &Self, #[comptime] stage_idx: u32) -> TmaWeightReader<MP> {
98        TmaWeightReader::<MP>::new(*this.stages.index(stage_idx), InputIdent::Rhs)
99    }
100
101    pub fn advance_view(this: &mut Self, k_offset: u32) {
102        this.tensor_view.update_view(k_offset, Ident::Rhs);
103    }
104}