cubecl_matmul/components/tile/register/matmul.rs

1use std::marker::PhantomData;
2
3use crate::components::tile::{TileConfig, TileMatmul, register::reader::RegisterFragmentReader};
4use crate::components::tile::{
5    io::Strided,
6    register::{
7        config::{ProductType, RegisterConfig},
8        reader::RegisterStageReader,
9    },
10};
11use crate::components::tile::{io::TileKind, tile_data::StridedTile};
12use crate::components::{StageIdent, tile::register::writer::RegisterStageWriter};
13use cubecl_core::prelude::*;
14use cubecl_core::{self as cubecl};
15
/// Uses one unit to perform a small matmul directly in registers.
///
/// Lhs, rhs and accumulator fragments are flat `Array`s. The loop nest used to multiply them
/// (inner product or outer product) is chosen from the comptime `RegisterConfig`.
///
/// `Acc` is the `TileKind` of the accumulator tile read by `load_acc`. The lhs, rhs and output
/// tiles are always `Strided`.
pub struct RegisterMatmul<Acc: TileKind> {
    // Marker only: ties the type to `Acc` without storing a value of that type.
    _ty: PhantomData<Acc>,
}
20
/// Unroll flag passed to every `#[unroll(UNROLL)]` loop in this file. Because the const is
/// `pub(super)`, sibling modules may use it as well.
///
/// Doesn't impact performance much, but may increase kernel size too much when true (often ~6X).
///
/// The innermost per-line loops in `load_plain` and `load_transposed` use a bare `#[unroll]`.
/// They do not consult this flag.
///
/// TODO: make it configurable
pub(super) const UNROLL: bool = false;
25
26#[cube]
27impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
28    for RegisterMatmul<AccTile>
29where
30    RegisterStageReader<AccTile>: RegisterFragmentReader<TileKind = AccTile>,
31{
32    type Config = RegisterConfig;
33    type LhsFragment = Array<L>;
34    type RhsFragment = Array<R>;
35    type AccFragment = Array<A>;
36
37    type LhsTile = Strided;
38    type RhsTile = Strided;
39    type AccTile = AccTile;
40    type OutTile = Strided;
41
42    fn execute(
43        lhs: &Self::LhsFragment,
44        rhs: &Self::RhsFragment,
45        acc: &mut Self::AccFragment,
46        #[comptime] config: Self::Config,
47    ) {
48        match config.product_type() {
49            ProductType::Inner => Self::inner_product(lhs, rhs, acc, config),
50            ProductType::Outer => Self::outer_product(lhs, rhs, acc, config),
51        }
52    }
53
54    fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment {
55        Array::new(config.tile_size().mk())
56    }
57
58    fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment {
59        Array::new(config.tile_size().nk())
60    }
61
62    fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {
63        Array::new(config.tile_size().mn())
64    }
65
66    fn load_lhs<E: Numeric>(
67        tile: &StridedTile<E>,
68        lhs: &mut Self::LhsFragment,
69        #[comptime] config: Self::Config,
70    ) {
71        RegisterStageReader::<Strided>::load_fragment(tile, lhs, StageIdent::Lhs, config)
72    }
73
74    fn load_rhs<E: Numeric>(
75        tile: &StridedTile<E>,
76        rhs: &mut Self::RhsFragment,
77        #[comptime] config: Self::Config,
78    ) {
79        RegisterStageReader::<Strided>::load_fragment(tile, rhs, StageIdent::Rhs, config)
80    }
81
82    fn load_acc<E: Numeric>(
83        tile: &AccTile::Tile<E>,
84        acc: &mut Self::AccFragment,
85        #[comptime] config: Self::Config,
86    ) {
87        RegisterStageReader::<AccTile>::load_fragment(tile, acc, StageIdent::Acc, config);
88    }
89
90    fn write_results<E: Numeric>(
91        tile: &mut StridedTile<E, ReadWrite>,
92        acc: &Self::AccFragment,
93        #[comptime] config: Self::Config,
94    ) {
95        RegisterStageWriter::store_fragment(tile, acc, config)
96    }
97}
98
99#[cube]
100impl<Acc: TileKind> RegisterMatmul<Acc> {
101    fn inner_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
102        lhs: &Array<Lhs>,
103        rhs: &Array<Rhs>,
104        acc: &mut Array<EA>,
105        #[comptime] config: RegisterConfig,
106    ) {
107        let (m, n, k) =
108            comptime! {let (m, n, k): (u32, u32, u32) = (*config.tile_size()).into(); (m, n, k)};
109
110        #[unroll(UNROLL)]
111        for m_ in 0..m {
112            #[unroll(UNROLL)]
113            for n_ in 0..n {
114                #[unroll(UNROLL)]
115                for k_ in 0..k {
116                    let lhs_elem = EA::cast_from(lhs[m_ * k + k_]);
117                    let rhs_elem = EA::cast_from(rhs[n_ * k + k_]);
118                    acc[m_ * n + n_] += lhs_elem * rhs_elem;
119                }
120            }
121        }
122    }
123
124    fn outer_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
125        lhs: &Array<Lhs>,
126        rhs: &Array<Rhs>,
127        acc: &mut Array<EA>,
128        #[comptime] config: RegisterConfig,
129    ) {
130        let (m, n, k) =
131            comptime! {let (m, n, k): (u32, u32, u32) = (*config.tile_size()).into(); (m, n, k)};
132
133        #[unroll(UNROLL)]
134        for k_ in 0..k {
135            #[unroll(UNROLL)]
136            for m_ in 0..m {
137                let lhs_elem = EA::cast_from(lhs[k_ * m + m_]);
138                #[unroll(UNROLL)]
139                for n_ in 0..n {
140                    let rhs_elem = EA::cast_from(rhs[k_ * n + n_]);
141                    acc[m_ * n + n_] += lhs_elem * rhs_elem;
142                }
143            }
144        }
145    }
146
147    pub fn load_plain<ES: Numeric, ER: Numeric>(
148        tile: &StridedTile<ES>,
149        array: &mut Array<ER>,
150        #[comptime] num_segments: u32,
151        #[comptime] segment_size: u32,
152        #[comptime] line_size: u32,
153    ) {
154        let num_lines_per_segment = segment_size / line_size;
155
156        #[unroll(UNROLL)]
157        for segment in 0..num_segments {
158            #[unroll(UNROLL)]
159            for line_within_segment in 0..num_lines_per_segment {
160                let line = tile.get_line(segment, line_within_segment);
161                #[unroll]
162                for pos_within_line in 0..line_size {
163                    array[segment * segment_size
164                        + line_within_segment * line_size
165                        + pos_within_line] = ER::cast_from(line[pos_within_line]);
166                }
167            }
168        }
169    }
170
171    pub fn load_transposed<ES: Numeric, ER: Numeric>(
172        tile: &StridedTile<ES>,
173        array: &mut Array<ER>,
174        #[comptime] num_segments: u32,
175        #[comptime] segment_size: u32,
176        #[comptime] line_size: u32,
177    ) {
178        let num_lines_per_segment = segment_size / line_size;
179
180        #[unroll(UNROLL)]
181        for segment in 0..num_segments {
182            #[unroll(UNROLL)]
183            for line_within_segment in 0..num_lines_per_segment {
184                let line = tile.get_line(segment, line_within_segment);
185                #[unroll]
186                for pos_within_line in 0..line_size {
187                    array[(line_within_segment * line_size + pos_within_line) * num_segments
188                        + segment] = ER::cast_from(line[pos_within_line]);
189                }
190            }
191        }
192    }
193}