cubecl_linalg/convolution/loader/
weight_tma.rs1use core::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl, prelude::barrier::Barrier};
5use cubecl_std::{CubeOption, FastDivmod};
6
7use crate::matmul::components::{
8 Ident, InputIdent, MatmulPrecision, global::Quantization, stage::FullReader,
9};
10use crate::matmul::components::{
11 global::{self, tensor_view::MappedTensorReader},
12 stage::{ContiguousTilingLayout, Stage, StageConfig},
13};
14use crate::{convolution::base::RuntimeArgs, matmul::components::stage::RowMajorTilingOrder};
15
16pub type TmaWeightTiling = ContiguousTilingLayout<RowMajorTilingOrder>;
17pub type TmaWeightReader<MP> = FullReader<<MP as MatmulPrecision>::ES, TmaWeightTiling>;
18
19#[derive(CubeType)]
20pub struct TmaWeightLoader<MP: MatmulPrecision, S: StageConfig> {
21 pub tensor_view: MappedTensorReader<MP::EI>,
22 pub stage: Stage<MP::ES, TmaWeightTiling>,
23 padded_channels: FastDivmod,
24 #[cube(comptime)]
25 _config: PhantomData<S>,
26}
27
28#[cube]
29impl<MP: MatmulPrecision, S: StageConfig> TmaWeightLoader<MP, S> {
30 pub fn new<G: global::GlobalConfig>(
31 tensor: TensorMap<MP::EI>,
32 x: u32,
33 y: u32,
34 quantization: CubeOption<Quantization<MP>>,
35 runtime_args: &RuntimeArgs,
36 #[comptime] config: G,
37 ) -> Self {
38 comptime! {
39 if quantization.is_some() {
40 todo!();
41 }
42 }
43
44 let stage = Stage::new_aligned::<G::SmmConfig>(Ident::Rhs, 128u32, config.to_smm_config());
45
46 let tensor_view = MappedTensorReader::new(tensor, x, y, 0);
47
48 TmaWeightLoader::<MP, S> {
49 tensor_view,
50 stage,
51 padded_channels: runtime_args.padded_channels,
52 _config: PhantomData::<S>,
53 }
54 }
55
56 pub fn fill_stage(this: &mut Self, barrier: &Barrier<MP::ES>, #[comptime] config: S) {
57 if UNIT_POS == 0 {
58 let k = this.tensor_view.tile_x;
59 let out_c = this.tensor_view.tile_y;
60 let tiling_dims = config.tiling_dimensions(Ident::Rhs);
61
62 let tensor = this.tensor_view.tensor.try_cast_unchecked();
63 let mut stage = this.stage.as_slice_mut(1u32);
64 let slice_size = tiling_dims.total_col() * tiling_dims.tile_shape_row();
65
66 #[unroll]
67 for tile_k in 0..tiling_dims.tile_count_row() {
68 let slice_start = slice_size * tile_k;
69 let mut slice = stage.slice_mut(slice_start, slice_size);
70
71 let k = k + tile_k * tiling_dims.tile_shape_row();
72 let (k_idx, in_c) = this.padded_channels.div_mod(k);
73
74 barrier.tma_load_3d(&tensor, &mut slice, out_c as i32, k_idx as i32, in_c as i32);
75 }
76 }
77 }
78
79 pub fn reader(this: &Self) -> TmaWeightReader<MP> {
80 TmaWeightReader::<MP>::new(this.stage, InputIdent::Rhs)
81 }
82
83 pub fn advance_view(this: &mut Self, k_offset: u32) {
84 this.tensor_view.update_view(k_offset, Ident::Rhs);
85 }
86}