cubecl_linalg/matmul/components/tile/
base.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::matmul::components::{
5 Ident, InputIdent, MatmulConfigFactory, MatmulPrecision, MatmulSize, MatrixLayout,
6 config::MatmulConfig, stage::shared::StageVectorization,
7};
8
9#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
10pub struct TileMatmulConfigInput {
11 pub vectorization: StageVectorization,
12 pub size: MatmulSize,
13}
14
15pub trait TileMatmulFamily:
16 MatmulConfigFactory<Input = TileMatmulConfigInput, Config: TileConfig>
17{
18 fn tile_shape(config: &Self::Config) -> MatmulSize;
19 fn requires_tensor_cores() -> bool;
20
21 type Matmul<MP: MatmulPrecision>: TileMatmul<MP, Config = Self::Config>;
22}
23
24#[cube]
37pub trait TileMatmul<MP: MatmulPrecision>: 'static + Send + Sync {
38 type Config: TileConfig;
39 type Lhs: CubeType;
41 type Rhs: CubeType;
43 type Accumulator: CubeType + Copy + Clone;
45
46 fn execute(
48 lhs: &Self::Lhs,
49 rhs: &Self::Rhs,
50 out: &mut Self::Accumulator,
51 #[comptime] config: Self::Config,
52 );
53
54 fn allocate_lhs(#[comptime] config: Self::Config) -> Self::Lhs;
61
62 fn allocate_rhs(#[comptime] config: Self::Config) -> Self::Rhs;
69
70 fn fill_lhs(slice: &Tile<MP::ES>, lhs: &mut Self::Lhs, #[comptime] config: Self::Config);
72
73 fn fill_rhs(slice: &Tile<MP::ES>, rhs: &mut Self::Rhs, #[comptime] config: Self::Config);
75
76 fn fill_accumulator(
78 tile: &Tile<MP::EA>,
79 acc: &mut Self::Accumulator,
80 #[comptime] config: Self::Config,
81 );
82
83 fn read_accumulator<C: Numeric>(
85 out: &Self::Accumulator,
87 slice: &mut SliceMut<Line<C>>,
88 #[comptime] config: Self::Config,
89 );
90
91 fn allocate_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
100
101 fn zero_accumulator(acc: &mut Self::Accumulator, #[comptime] config: Self::Config);
103}
104
105pub trait TileConfig: MatmulConfig {
107 fn plane_dim(&self) -> u32;
109
110 fn matrix_layout(&self, ident: Ident) -> MatrixLayout;
112
113 fn stage_line_size(&self, ident: Ident) -> u32;
115
116 fn tile_shape(&self) -> &MatmulSize;
118}
119
120#[derive(CubeType)]
121pub struct Tile<ES: Numeric> {
123 pub slice: Slice<Line<ES>>,
125 pub stride: u32,
127}
128
129#[cube]
130impl<ES: Numeric> Tile<ES> {
131 pub fn new_contiguous<T: TileConfig>(
132 slice: Slice<Line<ES>>,
133 #[comptime] ident: Ident,
134 #[comptime] config: T,
135 ) -> Tile<ES> {
136 let stride = comptime! {
137 (match ident.as_input_ident() {
138 InputIdent::Lhs => match config.matrix_layout(ident) {
139 MatrixLayout::RowMajor => config.tile_shape().k,
140 MatrixLayout::ColMajor => config.tile_shape().m,
141 },
142 InputIdent::Rhs => match config.matrix_layout(ident) {
143 MatrixLayout::RowMajor => config.tile_shape().n,
144 MatrixLayout::ColMajor => config.tile_shape().k,
145 },
146 }) / config.stage_line_size(ident)};
147
148 Tile::<ES> { slice, stride }
149 }
150
151 pub fn new_strided(slice: Slice<Line<ES>>, stride: u32) -> Tile<ES> {
152 Tile::<ES> { slice, stride }
153 }
154
155 pub fn as_unlined<T: TileConfig>(
156 &self,
157 #[comptime] ident: Ident,
158 #[comptime] config: T,
159 ) -> (Slice<ES>, u32) {
160 (
161 self.slice.try_cast_unchecked(),
162 self.stride * config.stage_line_size(ident),
163 )
164 }
165}