cubecl_linalg/convolution/loader/
weight_tma.rs

1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_std::{CubeOption, FastDivmod};
6
7use crate::matmul::components::{
8    Ident, InputIdent, MatmulPrecision, global::Quantization, stage::FullReader,
9};
10use crate::matmul::components::{
11    global::{self, tensor_view::MappedTensorReader},
12    stage::{ContiguousTilingLayout, Stage, StageConfig},
13};
14use crate::{convolution::base::RuntimeArgs, matmul::components::stage::RowMajorTilingOrder};
15
16pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
17pub type TmaWeightReader<MP> = FullReader<<MP as MatmulPrecision>::ES, TmaWeightTiling>;
18
19#[derive(CubeType)]
20pub struct TmaWeightLoader<MP: MatmulPrecision, S: StageConfig> {
21    pub tensor_view: MappedTensorReader<MP::EI>,
22    pub stage: Stage<MP::ES, TmaWeightTiling>,
23    padded_channels: FastDivmod,
24    #[cube(comptime)]
25    _config: PhantomData<S>,
26}
27
28#[cube]
29impl<MP: MatmulPrecision, S: StageConfig> TmaWeightLoader<MP, S> {
30    pub fn new<G: global::GlobalConfig>(
31        tensor: TensorMap<MP::EI>,
32        x: u32,
33        y: u32,
34        quantization: CubeOption<Quantization<MP>>,
35        runtime_args: &RuntimeArgs,
36        #[comptime] config: G,
37    ) -> Self {
38        comptime! {
39            if quantization.is_some() {
40                todo!();
41            }
42        }
43
44        let stage = Stage::new_aligned::<G::SmmConfig>(Ident::Rhs, 128u32, config.to_smm_config());
45
46        let tensor_view = MappedTensorReader::new(tensor, x, y, 0);
47
48        TmaWeightLoader::<MP, S> {
49            tensor_view,
50            stage,
51            padded_channels: runtime_args.padded_channels,
52            _config: PhantomData::<S>,
53        }
54    }
55
56    pub fn fill_stage(this: &mut Self, barrier: &Barrier<MP::ES>, #[comptime] config: S) {
57        if UNIT_POS == 0 {
58            let k = this.tensor_view.tile_x;
59            let out_c = this.tensor_view.tile_y;
60            let tiling_dims = config.tiling_dimensions(Ident::Rhs);
61
62            let tensor = this.tensor_view.tensor.try_cast_unchecked();
63            let mut stage = this.stage.as_slice_mut(1u32);
64            let slice_size = tiling_dims.total_col() * tiling_dims.tile_shape_row();
65
66            #[unroll]
67            for tile_k in 0..tiling_dims.tile_count_row() {
68                let slice_start = slice_size * tile_k;
69                let mut slice = stage.slice_mut(slice_start, slice_size);
70
71                let k = k + tile_k * tiling_dims.tile_shape_row();
72                let (k_idx, in_c) = this.padded_channels.div_mod(k);
73
74                barrier.tma_load_3d(&tensor, &mut slice, out_c as i32, k_idx as i32, in_c as i32);
75            }
76        }
77    }
78
79    pub fn reader(this: &Self) -> TmaWeightReader<MP> {
80        TmaWeightReader::<MP>::new(this.stage, InputIdent::Rhs)
81    }
82
83    pub fn advance_view(this: &mut Self, k_offset: u32) {
84        this.tensor_view.update_view(k_offset, Ident::Rhs);
85    }
86}