cubecl_matmul/components/tile/register/
matmul.rs1use std::marker::PhantomData;
2
3use crate::components::MatrixLayout;
4use crate::components::tile::register::config::{ProductType, RegisterMatmulConfig};
5use crate::components::tile::{TileMatmul, io::Filled, register::reader::RegisterFragmentReader};
6use crate::components::tile::{io::Strided, register::reader::RegisterStageReader};
7use crate::components::tile::{io::TileKind, tile_data::StridedTile};
8use crate::components::{StageIdent, tile::register::writer::RegisterStageWriter};
9use cubecl_core::prelude::*;
10use cubecl_core::{self as cubecl};
11
/// Tile matmul implementation whose operand and accumulator fragments are
/// plain [`UnitFragment`] arrays (see the `TileMatmul` impl in this file).
///
/// `Acc` selects the kind of tile the accumulator is loaded from; it defaults
/// to [`Filled`]. The struct carries no runtime data.
pub struct RegisterMatmul<Acc: TileKind = Filled> {
    // Zero-sized marker that only ties the `Acc` type parameter to the struct.
    _ty: PhantomData<Acc>,
}
16
/// Flag handed to every `#[unroll(UNROLL)]` loop attribute in this file.
///
/// It is currently `false`. The innermost per-line loops in `load_plain` and
/// `load_transposed` use a bare `#[unroll]` and do not depend on this flag.
pub(super) const UNROLL: bool = false;
21
/// Fragment type used for the lhs, rhs and accumulator operands of
/// [`RegisterMatmul`].
#[derive(CubeType)]
pub struct UnitFragment<E: Numeric> {
    /// Flat element storage; its length is chosen by the `allocate_*` functions.
    pub array: Array<E>,
    /// Matrix layout recorded when the fragment is allocated (comptime value).
    #[cube(comptime)]
    pub layout: MatrixLayout,
}
28
29#[cube]
30impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
31 for RegisterMatmul<AccTile>
32where
33 RegisterStageReader<AccTile>: RegisterFragmentReader<TileKind = AccTile>,
34{
35 type Config = RegisterMatmulConfig;
36
37 type LhsFragment = UnitFragment<L>;
38 type RhsFragment = UnitFragment<R>;
39 type AccFragment = UnitFragment<A>;
40
41 type LhsTile = Strided;
42 type RhsTile = Strided;
43 type AccTile = AccTile;
44 type OutTile = Strided;
45
46 fn execute(
47 lhs: &Self::LhsFragment,
48 rhs: &Self::RhsFragment,
49 acc: &mut Self::AccFragment,
50 #[comptime] config: Self::Config,
51 ) {
52 match config.product_type {
53 ProductType::Inner => {
54 Self::inner_product(&lhs.array, &rhs.array, &mut acc.array, config)
55 }
56 ProductType::Outer => {
57 Self::outer_product(&lhs.array, &rhs.array, &mut acc.array, config)
58 }
59 }
60 }
61
62 fn allocate_lhs(
63 #[comptime] layout: MatrixLayout,
64 #[comptime] config: Self::Config,
65 ) -> Self::LhsFragment {
66 UnitFragment::<L> {
67 array: Array::new(config.shared.tile_size.mk()),
68 layout,
69 }
70 }
71
72 fn allocate_rhs(
73 #[comptime] layout: MatrixLayout,
74 #[comptime] config: Self::Config,
75 ) -> Self::RhsFragment {
76 UnitFragment::<R> {
77 array: Array::new(config.shared.tile_size.nk()),
78 layout,
79 }
80 }
81
82 fn allocate_acc(
83 #[comptime] layout: MatrixLayout,
84 #[comptime] config: Self::Config,
85 ) -> Self::AccFragment {
86 UnitFragment::<A> {
87 array: Array::new(config.shared.tile_size.mn()),
88 layout,
89 }
90 }
91
92 fn load_lhs<E: Numeric>(
93 tile: &StridedTile<E>,
94 lhs: &mut Self::LhsFragment,
95 #[comptime] config: Self::Config,
96 ) {
97 RegisterStageReader::<Strided>::load_fragment(tile, lhs, StageIdent::Lhs, config)
98 }
99
100 fn load_rhs<E: Numeric>(
101 tile: &StridedTile<E>,
102 rhs: &mut Self::RhsFragment,
103 #[comptime] config: Self::Config,
104 ) {
105 RegisterStageReader::<Strided>::load_fragment(tile, rhs, StageIdent::Rhs, config)
106 }
107
108 fn load_acc<E: Numeric>(
109 tile: &AccTile::Tile<E>,
110 acc: &mut Self::AccFragment,
111 #[comptime] config: Self::Config,
112 ) {
113 RegisterStageReader::<AccTile>::load_fragment(tile, acc, StageIdent::Acc, config);
114 }
115
116 fn write_results<E: Numeric>(
117 tile: &mut StridedTile<E, ReadWrite>,
118 acc: &Self::AccFragment,
119 #[comptime] config: Self::Config,
120 ) {
121 RegisterStageWriter::store_fragment(tile, acc, config)
122 }
123}
124
125#[cube]
126impl<Acc: TileKind> RegisterMatmul<Acc> {
127 fn inner_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
128 lhs: &Array<Lhs>,
129 rhs: &Array<Rhs>,
130 acc: &mut Array<EA>,
131 #[comptime] config: RegisterMatmulConfig,
132 ) {
133 let (m, n, k) =
134 comptime! {let (m, n, k): (u32, u32, u32) = config.shared.tile_size.into(); (m, n, k)};
135
136 #[unroll(UNROLL)]
137 for m_ in 0..m {
138 #[unroll(UNROLL)]
139 for n_ in 0..n {
140 #[unroll(UNROLL)]
141 for k_ in 0..k {
142 let lhs_elem = EA::cast_from(lhs[m_ * k + k_]);
143 let rhs_elem = EA::cast_from(rhs[n_ * k + k_]);
144 acc[m_ * n + n_] += lhs_elem * rhs_elem;
145 }
146 }
147 }
148 }
149
150 fn outer_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
151 lhs: &Array<Lhs>,
152 rhs: &Array<Rhs>,
153 acc: &mut Array<EA>,
154 #[comptime] config: RegisterMatmulConfig,
155 ) {
156 let (m, n, k) =
157 comptime! {let (m, n, k): (u32, u32, u32) = config.shared.tile_size.into(); (m, n, k)};
158
159 #[unroll(UNROLL)]
160 for k_ in 0..k {
161 #[unroll(UNROLL)]
162 for m_ in 0..m {
163 let lhs_elem = EA::cast_from(lhs[k_ * m + m_]);
164 #[unroll(UNROLL)]
165 for n_ in 0..n {
166 let rhs_elem = EA::cast_from(rhs[k_ * n + n_]);
167 acc[m_ * n + n_] += lhs_elem * rhs_elem;
168 }
169 }
170 }
171 }
172
173 pub fn load_plain<ES: Numeric, ER: Numeric>(
174 tile: &StridedTile<ES>,
175 array: &mut Array<ER>,
176 #[comptime] num_segments: u32,
177 #[comptime] segment_size: u32,
178 #[comptime] line_size: u32,
179 ) {
180 let num_lines_per_segment = segment_size / line_size;
181
182 #[unroll(UNROLL)]
183 for segment in 0..num_segments {
184 #[unroll(UNROLL)]
185 for line_within_segment in 0..num_lines_per_segment {
186 let line = tile.get_line(segment, line_within_segment);
187 #[unroll]
188 for pos_within_line in 0..line_size {
189 array[segment * segment_size
190 + line_within_segment * line_size
191 + pos_within_line] = ER::cast_from(line[pos_within_line]);
192 }
193 }
194 }
195 }
196
197 pub fn load_transposed<ES: Numeric, ER: Numeric>(
198 tile: &StridedTile<ES>,
199 array: &mut Array<ER>,
200 #[comptime] num_segments: u32,
201 #[comptime] segment_size: u32,
202 #[comptime] line_size: u32,
203 ) {
204 let num_lines_per_segment = segment_size / line_size;
205
206 #[unroll(UNROLL)]
207 for segment in 0..num_segments {
208 #[unroll(UNROLL)]
209 for line_within_segment in 0..num_lines_per_segment {
210 let line = tile.get_line(segment, line_within_segment);
211 #[unroll]
212 for pos_within_line in 0..line_size {
213 array[(line_within_segment * line_size + pos_within_line) * num_segments
214 + segment] = ER::cast_from(line[pos_within_line]);
215 }
216 }
217 }
218 }
219}