cubecl_matmul/components/tile/register/
matmul.rs1use std::marker::PhantomData;
2
3use crate::components::tile::{TileConfig, TileMatmul, register::reader::RegisterFragmentReader};
4use crate::components::tile::{
5 io::Strided,
6 register::{
7 config::{ProductType, RegisterConfig},
8 reader::RegisterStageReader,
9 },
10};
11use crate::components::tile::{io::TileKind, tile_data::StridedTile};
12use crate::components::{StageIdent, tile::register::writer::RegisterStageWriter};
13use cubecl_core::prelude::*;
14use cubecl_core::{self as cubecl};
15
16pub struct RegisterMatmul<Acc: TileKind> {
18 _ty: PhantomData<Acc>,
19}
20
21pub(super) const UNROLL: bool = false;
25
26#[cube]
27impl<L: Numeric, R: Numeric, A: Numeric, AccTile: TileKind> TileMatmul<L, R, A>
28 for RegisterMatmul<AccTile>
29where
30 RegisterStageReader<AccTile>: RegisterFragmentReader<TileKind = AccTile>,
31{
32 type Config = RegisterConfig;
33 type LhsFragment = Array<L>;
34 type RhsFragment = Array<R>;
35 type AccFragment = Array<A>;
36
37 type LhsTile = Strided;
38 type RhsTile = Strided;
39 type AccTile = AccTile;
40 type OutTile = Strided;
41
42 fn execute(
43 lhs: &Self::LhsFragment,
44 rhs: &Self::RhsFragment,
45 acc: &mut Self::AccFragment,
46 #[comptime] config: Self::Config,
47 ) {
48 match config.product_type() {
49 ProductType::Inner => Self::inner_product(lhs, rhs, acc, config),
50 ProductType::Outer => Self::outer_product(lhs, rhs, acc, config),
51 }
52 }
53
54 fn allocate_lhs(#[comptime] config: Self::Config) -> Self::LhsFragment {
55 Array::new(config.tile_size().mk())
56 }
57
58 fn allocate_rhs(#[comptime] config: Self::Config) -> Self::RhsFragment {
59 Array::new(config.tile_size().nk())
60 }
61
62 fn allocate_acc(#[comptime] config: Self::Config) -> Self::AccFragment {
63 Array::new(config.tile_size().mn())
64 }
65
66 fn load_lhs<E: Numeric>(
67 tile: &StridedTile<E>,
68 lhs: &mut Self::LhsFragment,
69 #[comptime] config: Self::Config,
70 ) {
71 RegisterStageReader::<Strided>::load_fragment(tile, lhs, StageIdent::Lhs, config)
72 }
73
74 fn load_rhs<E: Numeric>(
75 tile: &StridedTile<E>,
76 rhs: &mut Self::RhsFragment,
77 #[comptime] config: Self::Config,
78 ) {
79 RegisterStageReader::<Strided>::load_fragment(tile, rhs, StageIdent::Rhs, config)
80 }
81
82 fn load_acc<E: Numeric>(
83 tile: &AccTile::Tile<E>,
84 acc: &mut Self::AccFragment,
85 #[comptime] config: Self::Config,
86 ) {
87 RegisterStageReader::<AccTile>::load_fragment(tile, acc, StageIdent::Acc, config);
88 }
89
90 fn write_results<E: Numeric>(
91 tile: &mut StridedTile<E, ReadWrite>,
92 acc: &Self::AccFragment,
93 #[comptime] config: Self::Config,
94 ) {
95 RegisterStageWriter::store_fragment(tile, acc, config)
96 }
97}
98
99#[cube]
100impl<Acc: TileKind> RegisterMatmul<Acc> {
101 fn inner_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
102 lhs: &Array<Lhs>,
103 rhs: &Array<Rhs>,
104 acc: &mut Array<EA>,
105 #[comptime] config: RegisterConfig,
106 ) {
107 let (m, n, k) =
108 comptime! {let (m, n, k): (u32, u32, u32) = (*config.tile_size()).into(); (m, n, k)};
109
110 #[unroll(UNROLL)]
111 for m_ in 0..m {
112 #[unroll(UNROLL)]
113 for n_ in 0..n {
114 #[unroll(UNROLL)]
115 for k_ in 0..k {
116 let lhs_elem = EA::cast_from(lhs[m_ * k + k_]);
117 let rhs_elem = EA::cast_from(rhs[n_ * k + k_]);
118 acc[m_ * n + n_] += lhs_elem * rhs_elem;
119 }
120 }
121 }
122 }
123
124 fn outer_product<Lhs: Numeric, Rhs: Numeric, EA: Numeric>(
125 lhs: &Array<Lhs>,
126 rhs: &Array<Rhs>,
127 acc: &mut Array<EA>,
128 #[comptime] config: RegisterConfig,
129 ) {
130 let (m, n, k) =
131 comptime! {let (m, n, k): (u32, u32, u32) = (*config.tile_size()).into(); (m, n, k)};
132
133 #[unroll(UNROLL)]
134 for k_ in 0..k {
135 #[unroll(UNROLL)]
136 for m_ in 0..m {
137 let lhs_elem = EA::cast_from(lhs[k_ * m + m_]);
138 #[unroll(UNROLL)]
139 for n_ in 0..n {
140 let rhs_elem = EA::cast_from(rhs[k_ * n + n_]);
141 acc[m_ * n + n_] += lhs_elem * rhs_elem;
142 }
143 }
144 }
145 }
146
147 pub fn load_plain<ES: Numeric, ER: Numeric>(
148 tile: &StridedTile<ES>,
149 array: &mut Array<ER>,
150 #[comptime] num_segments: u32,
151 #[comptime] segment_size: u32,
152 #[comptime] line_size: u32,
153 ) {
154 let num_lines_per_segment = segment_size / line_size;
155
156 #[unroll(UNROLL)]
157 for segment in 0..num_segments {
158 #[unroll(UNROLL)]
159 for line_within_segment in 0..num_lines_per_segment {
160 let line = tile.get_line(segment, line_within_segment);
161 #[unroll]
162 for pos_within_line in 0..line_size {
163 array[segment * segment_size
164 + line_within_segment * line_size
165 + pos_within_line] = ER::cast_from(line[pos_within_line]);
166 }
167 }
168 }
169 }
170
171 pub fn load_transposed<ES: Numeric, ER: Numeric>(
172 tile: &StridedTile<ES>,
173 array: &mut Array<ER>,
174 #[comptime] num_segments: u32,
175 #[comptime] segment_size: u32,
176 #[comptime] line_size: u32,
177 ) {
178 let num_lines_per_segment = segment_size / line_size;
179
180 #[unroll(UNROLL)]
181 for segment in 0..num_segments {
182 #[unroll(UNROLL)]
183 for line_within_segment in 0..num_lines_per_segment {
184 let line = tile.get_line(segment, line_within_segment);
185 #[unroll]
186 for pos_within_line in 0..line_size {
187 array[(line_within_segment * line_size + pos_within_line) * num_segments
188 + segment] = ER::cast_from(line[pos_within_line]);
189 }
190 }
191 }
192 }
193}