1use cubecl;
2use cubecl::prelude::*;
3
4use crate::{
5 StageIdent,
6 tile::{
7 MmaFragment, MmaFragmentExpand, Tile, TileExpand, TileScope,
8 compute::matmul::{
9 cmma::{cmma_load_from_shared, cmma_load_zeros, cmma_write_to_shared},
10 interleaved::{
11 interleaved_load_from_shared, interleaved_load_zeros, interleaved_write_to_shared,
12 },
13 mma::{
14 mma_load_acc_from_shared, mma_load_acc_zeros, mma_load_lhs_from_shared,
15 mma_load_rhs_from_shared, mma_write_to_shared,
16 },
17 plane_vec::{planevec_load_from_shared, planevec_load_zeros, planevec_write_to_shared},
18 register::{register_load_from_shared, register_load_zeros, register_write_to_shared},
19 },
20 data::BounceTile,
21 },
22};
23
24#[cube]
29pub(crate) fn cmma_to_whitebox_fragment<E: Float>(b: &mut BounceTile<E>) {
30 let stride = comptime!(b.cmma.tile_size.n());
31 cubecl::cmma::store(
32 &mut b.smem,
33 &b.cmma.matrix,
34 stride,
35 cubecl::cmma::MatrixLayout::RowMajor,
36 );
37 sync_cube();
38 b.fragment.load_from_slice(&b.smem.to_slice());
39 sync_cube();
40}
41
42#[cube]
46pub(crate) fn whitebox_fragment_to_cmma<E: Float>(b: &mut BounceTile<E>) {
47 let stride = comptime!(b.cmma.tile_size.n());
48 b.fragment.store_to(&mut b.smem);
49 sync_cube();
50 cubecl::cmma::load_with_layout(
51 &b.cmma.matrix,
52 &b.smem.to_slice(),
53 stride,
54 cubecl::cmma::MatrixLayout::RowMajor,
55 );
56}
57
58#[cube]
59impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
60 pub fn copy_from<
67 SE: Numeric,
68 SS: Size,
69 L: Numeric,
70 R: Numeric,
71 A: Numeric,
72 SIO: SliceVisibility,
73 >(
74 &mut self,
75 source: &Tile<SE, Sc, SIO>,
76 #[comptime] ident: StageIdent,
77 ) {
78 match (source, self) {
79 (Tile::SharedMemory(shared), Tile::Cmma(t)) => {
81 let shared = shared.view::<SS>();
82 cmma_load_from_shared::<SE, SS, N, SIO>(
83 &shared,
84 &mut t.matrix,
85 ident,
86 t.matrix_layout,
87 );
88 }
89 (Tile::None, Tile::Cmma(t)) => {
90 cmma_load_zeros::<N>(&mut t.matrix);
91 }
92
93 (Tile::SharedMemory(shared), Tile::Bounce(b)) => {
95 let shared = shared.view::<SS>();
96 cmma_load_from_shared::<SE, SS, N, SIO>(
97 &shared,
98 &mut b.cmma.matrix,
99 ident,
100 b.cmma.matrix_layout,
101 );
102 }
103 (Tile::None, Tile::Bounce(b)) => {
104 cmma_load_zeros::<N>(&mut b.cmma.matrix);
105 }
106
107 (Tile::SharedMemory(shared), Tile::Mma(t)) => {
109 let shared = shared.view::<SS>();
110 match &mut t.fragment {
111 MmaFragment::Lhs(f) => mma_load_lhs_from_shared::<SE, SS, N, R, A, SIO>(
112 &shared,
113 f,
114 t.matrix_layout,
115 t.config,
116 ),
117 MmaFragment::Rhs(f) => mma_load_rhs_from_shared::<SE, SS, N, L, A, SIO>(
118 &shared,
119 f,
120 t.matrix_layout,
121 t.config,
122 ),
123 MmaFragment::Acc(f) => mma_load_acc_from_shared::<SE, SS, N, L, R, SIO>(
124 &shared,
125 f,
126 t.matrix_layout,
127 t.config,
128 ),
129 }
130 }
131 (Tile::None, Tile::Mma(t)) => match &mut t.fragment {
132 MmaFragment::Acc(f) => {
133 mma_load_acc_zeros::<SE, SS, N, L, R>(f, t.matrix_layout, t.config);
134 }
135 MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
136 panic!("Mma zero-load only supported for Acc role")
137 }
138 },
139
140 (Tile::SharedMemory(shared), Tile::Register(t)) => {
142 let shared = shared.view::<SS>();
143 register_load_from_shared::<SE, SS, N, SIO>(
144 &shared,
145 &mut t.data,
146 t.matrix_layout,
147 t.config,
148 ident,
149 );
150 }
151 (Tile::None, Tile::Register(t)) => {
152 register_load_zeros::<N>(&mut t.data, t.config, ident);
153 }
154
155 (Tile::SharedMemory(shared), Tile::PlaneVec(t)) => {
157 let shared = shared.view::<SS>();
158 planevec_load_from_shared::<SE, SS, N, SIO>(&shared, &mut t.data, t.config, ident);
159 }
160 (Tile::None, Tile::PlaneVec(t)) => {
161 planevec_load_zeros::<N>(&mut t.data, t.config);
162 }
163
164 (Tile::SharedMemory(shared), Tile::Interleaved(t)) => {
166 let shared = shared.view::<SS>();
167 interleaved_load_from_shared::<SE, SS, N, SIO>(
168 &shared,
169 &mut t.data,
170 t.config,
171 ident,
172 );
173 }
174 (Tile::None, Tile::Interleaved(t)) => {
175 interleaved_load_zeros::<N>(&mut t.data, t.config);
176 }
177
178 (Tile::Cmma(t), Tile::SharedMemory(shared)) => {
180 let mut shared = shared.view::<SS>();
181 cmma_write_to_shared::<N, SS, SE>(&mut shared, &t.matrix);
182 }
183 (Tile::Bounce(b), Tile::SharedMemory(shared)) => {
184 let mut shared = shared.view::<SS>();
185 cmma_write_to_shared::<N, SS, SE>(&mut shared, &b.cmma.matrix);
186 }
187 (Tile::Mma(t), Tile::SharedMemory(shared)) => {
188 let mut shared = shared.view::<SS>();
189 match &t.fragment {
190 MmaFragment::Acc(f) => {
191 mma_write_to_shared::<N, SS, SE, L, R>(&mut shared, f, t.config);
192 }
193 MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
194 panic!("Mma write_to_shared only supported for Acc role")
195 }
196 }
197 }
198 (Tile::Register(t), Tile::SharedMemory(shared)) => {
199 let mut shared = shared.view::<SS>();
200 register_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
201 }
202 (Tile::PlaneVec(t), Tile::SharedMemory(shared)) => {
203 let mut shared = shared.view::<SS>();
204 planevec_write_to_shared::<SE, N, SS>(&mut shared, &t.data, t.config);
205 }
206 (Tile::Interleaved(t), Tile::SharedMemory(shared)) => {
207 let mut shared = shared.view::<SS>();
208 interleaved_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
209 }
210
211 _ => panic!("Unsupported storage pair for copy_from"),
212 }
213 }
214}