1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 lookups::Matrix,
5 variable::Variable,
6};
7use cubecl_core::ir::{self as core, CoopMma, ElemType, Id, MatrixLayout};
8use rspirv::spirv::{
9 Capability, CooperativeMatrixLayout, CooperativeMatrixOperands, CooperativeMatrixUse,
10 StorageClass,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14 pub fn compile_cmma(&mut self, cmma: CoopMma, out: Option<core::Variable>) {
15 self.capabilities.insert(Capability::CooperativeMatrixKHR);
16 let out = out.unwrap();
17 match cmma {
18 CoopMma::Fill { value } => self.compile_fill(out, value),
19 CoopMma::Load {
20 value,
21 stride,
22 layout,
23 offset,
24 } => self.compile_load(out, value, stride, offset, layout),
25 CoopMma::Execute {
26 mat_a,
27 mat_b,
28 mat_c,
29 } => self.compile_execute(mat_a, mat_b, mat_c, out),
30 CoopMma::Store {
31 mat,
32 stride,
33 layout,
34 offset,
35 } => self.compile_store(mat, out, stride, offset, layout),
36 CoopMma::Cast { input } => self.compile_cast(input, out),
37 CoopMma::RowIndex { .. }
38 | CoopMma::ColIndex { .. }
39 | CoopMma::LoadMatrix { .. }
40 | CoopMma::StoreMatrix { .. }
41 | CoopMma::ExecuteManual { .. }
42 | CoopMma::ExecuteScaled { .. } => {
43 panic!("Manual register management not currently supported in SPIR-V")
44 }
45 }
46 }
47
48 fn compile_load(
49 &mut self,
50 mat: core::Variable,
51 value: core::Variable,
52 stride: core::Variable,
53 offset: core::Variable,
54 layout: Option<MatrixLayout>,
55 ) {
56 let mat = self.compile_variable(mat);
57 let mat = self.matrix_var(&mat).1;
58
59 let value = self.compile_variable(value);
60 let stride = self.compile_variable(stride);
61 let stride_item = stride.item();
62 let mut stride = self.read(&stride);
63
64 if let Item::Vector(_, line_size) = value.item() {
65 let shift = stride_item.const_u32(self, line_size.trailing_zeros());
66 let stride_ty = stride_item.id(self);
67 stride = self
68 .shift_right_logical(stride_ty, None, stride, shift)
69 .unwrap();
70 }
71
72 let layout = layout
73 .and_then(compile_layout)
74 .or(mat.layout)
75 .unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
76 let memory_layout = self.const_u32(layout as u32);
77
78 let offset = self.compile_variable(offset);
79 let ptr = self.index_ptr(&value, &offset);
80 let out_ty = self.item(&mat);
81 let ty = out_ty.id(self);
82
83 let mat_id = self
84 .cooperative_matrix_load_khr(ty, None, ptr, memory_layout, Some(stride), None, vec![])
85 .unwrap();
86
87 self.store(mat.id, mat_id, None, vec![]).unwrap();
88 }
89
90 fn compile_fill(&mut self, mat: core::Variable, value: core::Variable) {
91 let mat = self.compile_variable(mat);
92 let value = self.compile_variable(value);
93 let mat = self.matrix_var(&mat).1;
94 let item = self.item(&mat);
95 let ty = item.id(self);
96 let mat_id = match value {
97 Variable::ConstantScalar(id, _, _) => self.constant_composite(ty, vec![id]),
98 var => {
99 let var = self.read(&var);
100 self.composite_construct(ty, None, vec![var]).unwrap()
101 }
102 };
103
104 self.store(mat.id, mat_id, None, vec![]).unwrap();
105 }
106
107 fn compile_store(
108 &mut self,
109 mat: core::Variable,
110 out: core::Variable,
111 stride: core::Variable,
112 offset: core::Variable,
113 layout: MatrixLayout,
114 ) {
115 let mat = self.compile_variable(mat);
116 let mat = self.matrix_var(&mat).1;
117 let item = self.item(&mat);
118 let ty = item.id(self);
119 let mat_obj = self.load(ty, None, mat.id, None, vec![]).unwrap();
120 let out = self.compile_variable(out);
123 let stride = self.compile_variable(stride);
124 let stride_item = stride.item();
125 let mut stride = self.read(&stride);
126 let layout = compile_layout(layout).unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
127 let memory_layout = self.const_u32(layout as u32);
128 let offset = self.compile_variable(offset);
129 let ptr = self.index_ptr(&out, &offset);
130
131 if let Item::Vector(_, line_size) = out.item() {
132 let shift = stride_item.const_u32(self, line_size.trailing_zeros());
133 let stride_ty = stride_item.id(self);
134 stride = self
135 .shift_right_logical(stride_ty, None, stride, shift)
136 .unwrap();
137 }
138
139 self.cooperative_matrix_store_khr(ptr, mat_obj, memory_layout, Some(stride), None, vec![])
140 .unwrap();
141 }
142
143 fn compile_execute(
144 &mut self,
145 mat_a: core::Variable,
146 mat_b: core::Variable,
147 mat_c: core::Variable,
148 mat_d: core::Variable,
149 ) {
150 let mat_a = self.compile_variable(mat_a);
151 let mat_b = self.compile_variable(mat_b);
152 let mat_c = self.compile_variable(mat_c);
153 let mat_d = self.compile_variable(mat_d);
154
155 let mat_a = self.matrix_var(&mat_a).1;
156 let mat_b = self.matrix_var(&mat_b).1;
157 let mat_c = self.matrix_var(&mat_c).1;
158 let mat_d = self.matrix_var(&mat_d).1;
159
160 let mat_a_ty = self.item(&mat_a).id(self);
161 let mat_b_ty = self.item(&mat_b).id(self);
162 let mat_c_ty = self.item(&mat_c).id(self);
163
164 let mat_a_id = self.load(mat_a_ty, None, mat_a.id, None, vec![]).unwrap();
165 let mat_b_id = self.load(mat_b_ty, None, mat_b.id, None, vec![]).unwrap();
166 let mat_c_id = self.load(mat_c_ty, None, mat_c.id, None, vec![]).unwrap();
167
168 let ty = self.item(&mat_d).id(self);
169
170 let mut operands = CooperativeMatrixOperands::NONE_KHR;
171 if matches!(mat_a.elem, Elem::Int(_, true)) {
172 operands |= CooperativeMatrixOperands::MATRIX_A_SIGNED_COMPONENTS_KHR;
173 }
174 if matches!(mat_b.elem, Elem::Int(_, true)) {
175 operands |= CooperativeMatrixOperands::MATRIX_B_SIGNED_COMPONENTS_KHR;
176 }
177 if matches!(mat_c.elem, Elem::Int(_, true)) {
178 operands |= CooperativeMatrixOperands::MATRIX_C_SIGNED_COMPONENTS_KHR;
179 }
180 if matches!(mat_d.elem, Elem::Int(_, true)) {
181 operands |= CooperativeMatrixOperands::MATRIX_RESULT_SIGNED_COMPONENTS_KHR;
182 }
183
184 let mat_d_id = self
185 .cooperative_matrix_mul_add_khr(ty, None, mat_a_id, mat_b_id, mat_c_id, Some(operands))
186 .unwrap();
187
188 self.store(mat_d.id, mat_d_id, None, vec![]).unwrap();
189 }
190
191 fn compile_cast(&mut self, input: core::Variable, output: core::Variable) {
192 let input = self.compile_variable(input);
193 let output = self.compile_variable(output);
194
195 let input = self.matrix_var(&input).1;
196 let output = self.matrix_var(&output).1;
197
198 let input_ty = self.item(&input).id(self);
199 let output_ty = self.item(&output).id(self);
200
201 let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap();
202
203 let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap();
204
205 self.store(output.id, frag_new, None, vec![]).unwrap();
206 }
207
208 fn matrix_var(&mut self, var: &Variable) -> (Id, Matrix) {
209 let id = match var {
210 Variable::CoopMatrix(id, _) => *id,
211 _ => unreachable!(),
212 };
213 let mat = self.state.matrices[&id];
214 (id, mat)
215 }
216
217 fn rows(&mut self, mat: &Matrix) -> u32 {
218 let rows = match mat.ident {
219 CooperativeMatrixUse::MatrixAKHR => mat.m,
220 CooperativeMatrixUse::MatrixBKHR => mat.k,
221 CooperativeMatrixUse::MatrixAccumulatorKHR => mat.m,
222 };
223 self.const_u32(rows)
224 }
225
226 fn columns(&mut self, mat: &Matrix) -> u32 {
227 let columns = match mat.ident {
228 CooperativeMatrixUse::MatrixAKHR => mat.k,
229 CooperativeMatrixUse::MatrixBKHR => mat.n,
230 CooperativeMatrixUse::MatrixAccumulatorKHR => mat.n,
231 };
232 self.const_u32(columns)
233 }
234
235 pub fn item(&mut self, mat: &Matrix) -> Item {
236 Item::CoopMatrix {
237 ty: mat.elem,
238 rows: self.rows(mat),
239 columns: self.columns(mat),
240 ident: mat.ident,
241 }
242 }
243
244 pub fn init_coop_matrix(&mut self, mat: core::Matrix, var: core::Variable) -> Matrix {
245 if mat.storage.elem_type() == ElemType::Float(core::FloatKind::BF16) {
246 self.capabilities
247 .insert(Capability::BFloat16CooperativeMatrixKHR);
248 }
249 if matches!(
250 mat.storage.elem_type(),
251 ElemType::Float(core::FloatKind::E5M2 | core::FloatKind::E4M3)
252 ) {
253 self.capabilities
254 .insert(Capability::Float8CooperativeMatrixEXT);
255 }
256
257 let elem = self.compile_type(core::Type::new(mat.storage)).elem();
258 let ident = match mat.ident {
259 core::MatrixIdent::A => CooperativeMatrixUse::MatrixAKHR,
260 core::MatrixIdent::B => CooperativeMatrixUse::MatrixBKHR,
261 core::MatrixIdent::Accumulator => CooperativeMatrixUse::MatrixAccumulatorKHR,
262 };
263 let layout = compile_layout(mat.layout);
264
265 let mut mat = Matrix {
266 id: 0,
267 ident,
268 m: mat.m,
269 n: mat.n,
270 k: mat.k,
271 elem,
272 layout,
273 };
274
275 let item = Item::Pointer(StorageClass::Function, Box::new(self.item(&mat)));
276 let ty = item.id(self);
277 mat.id = self.declare_function_variable(ty);
278 self.debug_var_name(mat.id, var);
279
280 mat
281 }
282}
283
284fn compile_layout(layout: MatrixLayout) -> Option<CooperativeMatrixLayout> {
285 match layout {
286 core::MatrixLayout::ColMajor => Some(CooperativeMatrixLayout::ColumnMajorKHR),
287 core::MatrixLayout::RowMajor => Some(CooperativeMatrixLayout::RowMajorKHR),
288 core::MatrixLayout::Undefined => None,
289 }
290}