cubecl_matmul/components/tile/register/
matmul.rs

1use std::marker::PhantomData;
2
3use crate::components::MatrixLayout;
4use crate::components::tile::register::config::{ProductType, RegisterMatmulConfig};
5use crate::components::tile::{TileMatmul, io::Filled, register::reader::RegisterFragmentReader};
6use crate::components::tile::{io::Strided, register::reader::RegisterStageReader};
7use crate::components::tile::{io::TileKind, tile_data::StridedTile};
8use crate::components::{StageIdent, tile::register::writer::RegisterStageWriter};
9use cubecl_core::prelude::*;
10use cubecl_core::{self as cubecl};
11
12/// Uses one unit to perform a small matmul directly in registers
13pub struct RegisterMatmul<Acc: TileKind = Filled> {
14    _ty: PhantomData<Acc>,
15}
16
/// Flag passed to `#[unroll(...)]` on the register matmul loops.
///
/// Unrolling doesn't impact performance much, but may increase kernel size too much
/// when `true` (often ~6X), so it is currently disabled.
///
/// TODO: make it configurable
pub(super) const UNROLL: bool = false;
21
/// Fragment used by [`RegisterMatmul`] for its Lhs, Rhs and accumulator operands:
/// a flat array of elements together with a compile-time layout tag.
#[derive(CubeType)]
pub struct UnitFragment<E: Numeric> {
    /// Flat element storage. `RegisterMatmul` allocates it with `tile_size.mk()`,
    /// `tile_size.nk()` or `tile_size.mn()` elements depending on the operand.
    pub array: Array<E>,
    /// Layout given at allocation time and kept as a comptime value.
    // NOTE(review): presumably describes how `array` is arranged; the matmul loops in
    // this file do not read it — confirm against the fragment reader/writer.
    #[cube(comptime)]
    pub layout: MatrixLayout,
}
28
29#[cube]
30impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
31    for RegisterMatmul<AccTile>
32where
33    RegisterStageReader<AccTile>: RegisterFragmentReader<TileKind = AccTile>,
34{
35    type Config = RegisterMatmulConfig;
36
37    type LhsFragment = UnitFragment<L>;
38    type RhsFragment = UnitFragment<R>;
39    type AccFragment = UnitFragment<A>;
40
41    type LhsTile = Strided;
42    type RhsTile = Strided;
43    type AccTile = AccTile;
44    type OutTile = Strided;
45
46    fn execute(
47        lhs: &Self::LhsFragment,
48        rhs: &Self::RhsFragment,
49        acc: &mut Self::AccFragment,
50        #[comptime] config: Self::Config,
51    ) {
52        match config.product_type {
53            ProductType::Inner => {
54                Self::inner_product(&lhs.array, &rhs.array, &mut acc.array, config)
55            }
56            ProductType::Outer => {
57                Self::outer_product(&lhs.array, &rhs.array, &mut acc.array, config)
58            }
59        }
60    }
61
62    fn allocate_lhs(
63        #[comptime] layout: MatrixLayout,
64        #[comptime] config: Self::Config,
65    ) -> Self::LhsFragment {
66        UnitFragment::<L> {
67            array: Array::new(config.shared.tile_size.mk()),
68            layout,
69        }
70    }
71
72    fn allocate_rhs(
73        #[comptime] layout: MatrixLayout,
74        #[comptime] config: Self::Config,
75    ) -> Self::RhsFragment {
76        UnitFragment::<R> {
77            array: Array::new(config.shared.tile_size.nk()),
78            layout,
79        }
80    }
81
82    fn allocate_acc(
83        #[comptime] layout: MatrixLayout,
84        #[comptime] config: Self::Config,
85    ) -> Self::AccFragment {
86        UnitFragment::<A> {
87            array: Array::new(config.shared.tile_size.mn()),
88            layout,
89        }
90    }
91
92    fn load_lhs<E: Numeric>(
93        tile: &StridedTile<E>,
94        lhs: &mut Self::LhsFragment,
95        #[comptime] config: Self::Config,
96    ) {
97        RegisterStageReader::<Strided>::load_fragment(tile, lhs, StageIdent::Lhs, config)
98    }
99
100    fn load_rhs<E: Numeric>(
101        tile: &StridedTile<E>,
102        rhs: &mut Self::RhsFragment,
103        #[comptime] config: Self::Config,
104    ) {
105        RegisterStageReader::<Strided>::load_fragment(tile, rhs, StageIdent::Rhs, config)
106    }
107
108    fn load_acc<E: Numeric>(
109        tile: &AccTile::Tile<E>,
110        acc: &mut Self::AccFragment,
111        #[comptime] config: Self::Config,
112    ) {
113        RegisterStageReader::<AccTile>::load_fragment(tile, acc, StageIdent::Acc, config);
114    }
115
116    fn write_results<E: Numeric>(
117        tile: &mut StridedTile<E, ReadWrite>,
118        acc: &Self::AccFragment,
119        #[comptime] config: Self::Config,
120    ) {
121        RegisterStageWriter::store_fragment(tile, acc, config)
122    }
123}
124
125#[cube]
126impl<Acc: TileKind> RegisterMatmul<Acc> {
127    fn inner_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
128        lhs: &Array<Lhs>,
129        rhs: &Array<Rhs>,
130        acc: &mut Array<EA>,
131        #[comptime] config: RegisterMatmulConfig,
132    ) {
133        let (m, n, k) =
134            comptime! {let (m, n, k): (u32, u32, u32) = config.shared.tile_size.into(); (m, n, k)};
135
136        #[unroll(UNROLL)]
137        for m_ in 0..m {
138            #[unroll(UNROLL)]
139            for n_ in 0..n {
140                #[unroll(UNROLL)]
141                for k_ in 0..k {
142                    let lhs_elem = EA::cast_from(lhs[m_ * k + k_]);
143                    let rhs_elem = EA::cast_from(rhs[n_ * k + k_]);
144                    acc[m_ * n + n_] += lhs_elem * rhs_elem;
145                }
146            }
147        }
148    }
149
150    fn outer_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
151        lhs: &Array<Lhs>,
152        rhs: &Array<Rhs>,
153        acc: &mut Array<EA>,
154        #[comptime] config: RegisterMatmulConfig,
155    ) {
156        let (m, n, k) =
157            comptime! {let (m, n, k): (u32, u32, u32) = config.shared.tile_size.into(); (m, n, k)};
158
159        #[unroll(UNROLL)]
160        for k_ in 0..k {
161            #[unroll(UNROLL)]
162            for m_ in 0..m {
163                let lhs_elem = EA::cast_from(lhs[k_ * m + m_]);
164                #[unroll(UNROLL)]
165                for n_ in 0..n {
166                    let rhs_elem = EA::cast_from(rhs[k_ * n + n_]);
167                    acc[m_ * n + n_] += lhs_elem * rhs_elem;
168                }
169            }
170        }
171    }
172
173    pub fn load_plain<ES: Numeric, ER: Numeric>(
174        tile: &StridedTile<ES>,
175        array: &mut Array<ER>,
176        #[comptime] num_segments: u32,
177        #[comptime] segment_size: u32,
178        #[comptime] line_size: u32,
179    ) {
180        let num_lines_per_segment = segment_size / line_size;
181
182        #[unroll(UNROLL)]
183        for segment in 0..num_segments {
184            #[unroll(UNROLL)]
185            for line_within_segment in 0..num_lines_per_segment {
186                let line = tile.get_line(segment, line_within_segment);
187                #[unroll]
188                for pos_within_line in 0..line_size {
189                    array[segment * segment_size
190                        + line_within_segment * line_size
191                        + pos_within_line] = ER::cast_from(line[pos_within_line]);
192                }
193            }
194        }
195    }
196
197    pub fn load_transposed<ES: Numeric, ER: Numeric>(
198        tile: &StridedTile<ES>,
199        array: &mut Array<ER>,
200        #[comptime] num_segments: u32,
201        #[comptime] segment_size: u32,
202        #[comptime] line_size: u32,
203    ) {
204        let num_lines_per_segment = segment_size / line_size;
205
206        #[unroll(UNROLL)]
207        for segment in 0..num_segments {
208            #[unroll(UNROLL)]
209            for line_within_segment in 0..num_lines_per_segment {
210                let line = tile.get_line(segment, line_within_segment);
211                #[unroll]
212                for pos_within_line in 0..line_size {
213                    array[(line_within_segment * line_size + pos_within_line) * num_segments
214                        + segment] = ER::cast_from(line[pos_within_line]);
215                }
216            }
217        }
218    }
219}