cubecl_convolution/loader/
weight_tma.rs1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_std::{CubeOption, FastDivmod};
6
7use crate::base::RuntimeArgs;
8use cubecl_matmul::components::stage::RowMajorTilingOrder;
9use cubecl_matmul::components::{
10 Ident, InputIdent, MatmulPrecision, global::Quantization, stage::FullStageToTileReader,
11};
12use cubecl_matmul::components::{
13 global::{self, global_memory::MappedTensorReader},
14 stage::{ContiguousTilingLayout, StageConfig, StageMemory},
15};
16
17pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
18pub type TmaWeightReader<MP> = FullStageToTileReader<<MP as MatmulPrecision>::ES, TmaWeightTiling>;
19
20#[derive(CubeType)]
21pub struct TmaWeightLoader<MP: MatmulPrecision, S: StageConfig> {
22 pub tensor_view: MappedTensorReader<MP::EI>,
23 pub stages: Sequence<StageMemory<MP::ES, TmaWeightTiling>>,
24 padded_channels: FastDivmod,
25 #[cube(comptime)]
26 _config: PhantomData<S>,
27}
28
29#[cube]
30impl<MP: MatmulPrecision, S: StageConfig> TmaWeightLoader<MP, S> {
31 pub fn new<G: global::GlobalConfig>(
32 tensor: TensorMap<MP::EI>,
33 x: u32,
34 y: u32,
35 quantization: CubeOption<Quantization<MP>>,
36 runtime_args: &RuntimeArgs,
37 #[comptime] num_stages: u32,
38 #[comptime] config: G,
39 ) -> Self {
40 comptime! {
41 if quantization.is_some() {
42 todo!();
43 }
44 }
45
46 let mut stages = Sequence::new();
47
48 #[unroll]
49 for _ in 0..num_stages {
50 stages.push(StageMemory::new_aligned::<G::StageConfig>(
51 Ident::Rhs,
52 128u32,
53 config.stage_config(),
54 ));
55 }
56
57 let tensor_view = MappedTensorReader::new(tensor, x, y, 0);
58
59 TmaWeightLoader::<MP, S> {
60 tensor_view,
61 stages,
62 padded_channels: runtime_args.padded_channels,
63 _config: PhantomData::<S>,
64 }
65 }
66
67 pub fn fill_stage(
68 this: &mut Self,
69 barrier: &Barrier<MP::ES>,
70 #[comptime] stage_idx: u32,
71 #[comptime] config: S,
72 ) {
73 let stage = this.stages.index_mut(stage_idx);
74
75 if UNIT_POS == 0 {
76 let k = this.tensor_view.tile_x;
77 let out_c = this.tensor_view.tile_y;
78
79 let tensor = this.tensor_view.tensor.try_cast_unchecked();
80 let mut stage = stage.as_slice_mut(1u32);
81 let slice_size = config.tiling_scheme().elements_in_stage_n()
82 * config.tiling_scheme().elements_in_tile_k();
83
84 #[unroll]
85 for tile_k in 0..config.tiling_scheme().tiles_in_stage_k() {
86 let slice_start = slice_size * tile_k;
87 let mut slice = stage.slice_mut(slice_start, slice_size);
88
89 let k = k + tile_k * config.tiling_scheme().elements_in_tile_k();
90 let (k_idx, in_c) = this.padded_channels.div_mod(k);
91
92 barrier.tma_load_3d(&tensor, &mut slice, out_c as i32, k_idx as i32, in_c as i32);
93 }
94 }
95 }
96
97 pub fn reader(this: &Self, #[comptime] stage_idx: u32) -> TmaWeightReader<MP> {
98 TmaWeightReader::<MP>::new(*this.stages.index(stage_idx), InputIdent::Rhs)
99 }
100
101 pub fn advance_view(this: &mut Self, k_offset: u32) {
102 this.tensor_view.update_view(k_offset, Ident::Rhs);
103 }
104}