1use crate::{
2 SpirvCompiler, SpirvTarget,
3 item::{Elem, Item},
4 lookups::Matrix,
5 variable::Variable,
6};
7use cubecl_core::ir::{self as core, CoopMma, ElemType, Id, MatrixLayout};
8use rspirv::spirv::{
9 Capability, CooperativeMatrixLayout, CooperativeMatrixOperands, CooperativeMatrixUse,
10 StorageClass,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14 pub fn compile_cmma(&mut self, cmma: CoopMma, out: Option<core::Variable>) {
15 self.capabilities.insert(Capability::CooperativeMatrixKHR);
16 let out = out.unwrap();
17 match cmma {
18 CoopMma::Fill { value } => self.compile_fill(out, value),
19 CoopMma::Load {
20 value,
21 stride,
22 layout,
23 offset,
24 } => self.compile_load(out, value, stride, offset, layout),
25 CoopMma::Execute {
26 mat_a,
27 mat_b,
28 mat_c,
29 } => self.compile_execute(mat_a, mat_b, mat_c, out),
30 CoopMma::Store {
31 mat,
32 stride,
33 layout,
34 offset,
35 } => self.compile_store(mat, out, stride, offset, layout),
36 CoopMma::Cast { input } => self.compile_cast(input, out),
37 CoopMma::RowIndex { .. }
38 | CoopMma::ColIndex { .. }
39 | CoopMma::ExecuteManual { .. }
40 | CoopMma::ExecuteScaled { .. } => {
41 panic!("Manual register management not currently supported in SPIR-V")
42 }
43 }
44 }
45
46 fn compile_load(
47 &mut self,
48 mat: core::Variable,
49 value: core::Variable,
50 stride: core::Variable,
51 offset: core::Variable,
52 layout: Option<MatrixLayout>,
53 ) {
54 let mat = self.compile_variable(mat);
55 let mat = self.matrix_var(&mat).1;
56
57 let value = self.compile_variable(value);
58 let stride = self.compile_variable(stride);
59 let stride_item = stride.item();
60 let mut stride = self.read(&stride);
61
62 if let Item::Vector(_, line_size) = value.item() {
63 let shift = stride_item.const_u32(self, line_size.trailing_zeros());
64 let stride_ty = stride_item.id(self);
65 stride = self
66 .shift_right_logical(stride_ty, None, stride, shift)
67 .unwrap();
68 }
69
70 let layout = layout
71 .and_then(compile_layout)
72 .or(mat.layout)
73 .unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
74 let memory_layout = self.const_u32(layout as u32);
75
76 let offset = self.compile_variable(offset);
77 let ptr = self.index_ptr(&value, &offset);
78 let out_ty = self.item(&mat);
79 let ty = out_ty.id(self);
80
81 let mat_id = self
82 .cooperative_matrix_load_khr(ty, None, ptr, memory_layout, Some(stride), None, vec![])
83 .unwrap();
84
85 self.store(mat.id, mat_id, None, vec![]).unwrap();
86 }
87
88 fn compile_fill(&mut self, mat: core::Variable, value: core::Variable) {
89 let mat = self.compile_variable(mat);
90 let value = self.compile_variable(value);
91 let mat = self.matrix_var(&mat).1;
92 let item = self.item(&mat);
93 let ty = item.id(self);
94 let mat_id = match value {
95 Variable::ConstantScalar(id, _, _) => self.constant_composite(ty, vec![id]),
96 var => {
97 let var = self.read(&var);
98 self.composite_construct(ty, None, vec![var]).unwrap()
99 }
100 };
101
102 self.store(mat.id, mat_id, None, vec![]).unwrap();
103 }
104
105 fn compile_store(
106 &mut self,
107 mat: core::Variable,
108 out: core::Variable,
109 stride: core::Variable,
110 offset: core::Variable,
111 layout: MatrixLayout,
112 ) {
113 let mat = self.compile_variable(mat);
114 let mat = self.matrix_var(&mat).1;
115 let item = self.item(&mat);
116 let ty = item.id(self);
117 let mat_obj = self.load(ty, None, mat.id, None, vec![]).unwrap();
118 let out = self.compile_variable(out);
121 let stride = self.compile_variable(stride);
122 let stride_item = stride.item();
123 let mut stride = self.read(&stride);
124 let layout = compile_layout(layout).unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
125 let memory_layout = self.const_u32(layout as u32);
126 let offset = self.compile_variable(offset);
127 let ptr = self.index_ptr(&out, &offset);
128
129 if let Item::Vector(_, line_size) = out.item() {
130 let shift = stride_item.const_u32(self, line_size.trailing_zeros());
131 let stride_ty = stride_item.id(self);
132 stride = self
133 .shift_right_logical(stride_ty, None, stride, shift)
134 .unwrap();
135 }
136
137 self.cooperative_matrix_store_khr(ptr, mat_obj, memory_layout, Some(stride), None, vec![])
138 .unwrap();
139 }
140
141 fn compile_execute(
142 &mut self,
143 mat_a: core::Variable,
144 mat_b: core::Variable,
145 mat_c: core::Variable,
146 mat_d: core::Variable,
147 ) {
148 let mat_a = self.compile_variable(mat_a);
149 let mat_b = self.compile_variable(mat_b);
150 let mat_c = self.compile_variable(mat_c);
151 let mat_d = self.compile_variable(mat_d);
152
153 let mat_a = self.matrix_var(&mat_a).1;
154 let mat_b = self.matrix_var(&mat_b).1;
155 let mat_c = self.matrix_var(&mat_c).1;
156 let mat_d = self.matrix_var(&mat_d).1;
157
158 let mat_a_ty = self.item(&mat_a).id(self);
159 let mat_b_ty = self.item(&mat_b).id(self);
160 let mat_c_ty = self.item(&mat_c).id(self);
161
162 let mat_a_id = self.load(mat_a_ty, None, mat_a.id, None, vec![]).unwrap();
163 let mat_b_id = self.load(mat_b_ty, None, mat_b.id, None, vec![]).unwrap();
164 let mat_c_id = self.load(mat_c_ty, None, mat_c.id, None, vec![]).unwrap();
165
166 let ty = self.item(&mat_d).id(self);
167
168 let mut operands = CooperativeMatrixOperands::NONE_KHR;
169 if matches!(mat_a.elem, Elem::Int(_, true)) {
170 operands |= CooperativeMatrixOperands::MATRIX_A_SIGNED_COMPONENTS_KHR;
171 }
172 if matches!(mat_b.elem, Elem::Int(_, true)) {
173 operands |= CooperativeMatrixOperands::MATRIX_B_SIGNED_COMPONENTS_KHR;
174 }
175 if matches!(mat_c.elem, Elem::Int(_, true)) {
176 operands |= CooperativeMatrixOperands::MATRIX_C_SIGNED_COMPONENTS_KHR;
177 }
178 if matches!(mat_d.elem, Elem::Int(_, true)) {
179 operands |= CooperativeMatrixOperands::MATRIX_RESULT_SIGNED_COMPONENTS_KHR;
180 }
181
182 let mat_d_id = self
183 .cooperative_matrix_mul_add_khr(ty, None, mat_a_id, mat_b_id, mat_c_id, Some(operands))
184 .unwrap();
185
186 self.store(mat_d.id, mat_d_id, None, vec![]).unwrap();
187 }
188
189 fn compile_cast(&mut self, input: core::Variable, output: core::Variable) {
190 let input = self.compile_variable(input);
191 let output = self.compile_variable(output);
192
193 let input = self.matrix_var(&input).1;
194 let output = self.matrix_var(&output).1;
195
196 let input_ty = self.item(&input).id(self);
197 let output_ty = self.item(&output).id(self);
198
199 let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap();
200
201 let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap();
202
203 self.store(output.id, frag_new, None, vec![]).unwrap();
204 }
205
206 fn matrix_var(&mut self, var: &Variable) -> (Id, Matrix) {
207 let id = match var {
208 Variable::CoopMatrix(id, _) => *id,
209 _ => unreachable!(),
210 };
211 let mat = self.state.matrices[&id];
212 (id, mat)
213 }
214
215 fn rows(&mut self, mat: &Matrix) -> u32 {
216 let rows = match mat.ident {
217 CooperativeMatrixUse::MatrixAKHR => mat.m,
218 CooperativeMatrixUse::MatrixBKHR => mat.k,
219 CooperativeMatrixUse::MatrixAccumulatorKHR => mat.m,
220 };
221 self.const_u32(rows)
222 }
223
224 fn columns(&mut self, mat: &Matrix) -> u32 {
225 let columns = match mat.ident {
226 CooperativeMatrixUse::MatrixAKHR => mat.k,
227 CooperativeMatrixUse::MatrixBKHR => mat.n,
228 CooperativeMatrixUse::MatrixAccumulatorKHR => mat.n,
229 };
230 self.const_u32(columns)
231 }
232
233 pub fn item(&mut self, mat: &Matrix) -> Item {
234 Item::CoopMatrix {
235 ty: mat.elem,
236 rows: self.rows(mat),
237 columns: self.columns(mat),
238 ident: mat.ident,
239 }
240 }
241
242 pub fn init_coop_matrix(&mut self, mat: core::Matrix, var: core::Variable) -> Matrix {
243 if mat.storage.elem_type() == ElemType::Float(core::FloatKind::BF16) {
244 self.capabilities
245 .insert(Capability::BFloat16CooperativeMatrixKHR);
246 }
247 if matches!(
248 mat.storage.elem_type(),
249 ElemType::Float(core::FloatKind::E5M2 | core::FloatKind::E4M3)
250 ) {
251 self.capabilities
252 .insert(Capability::Float8CooperativeMatrixEXT);
253 }
254
255 let elem = self.compile_type(core::Type::new(mat.storage)).elem();
256 let ident = match mat.ident {
257 core::MatrixIdent::A => CooperativeMatrixUse::MatrixAKHR,
258 core::MatrixIdent::B => CooperativeMatrixUse::MatrixBKHR,
259 core::MatrixIdent::Accumulator => CooperativeMatrixUse::MatrixAccumulatorKHR,
260 };
261 let layout = compile_layout(mat.layout);
262
263 let mut mat = Matrix {
264 id: 0,
265 ident,
266 m: mat.m,
267 n: mat.n,
268 k: mat.k,
269 elem,
270 layout,
271 };
272
273 let item = Item::Pointer(StorageClass::Function, Box::new(self.item(&mat)));
274 let ty = item.id(self);
275 mat.id = self.declare_function_variable(ty);
276 self.debug_var_name(mat.id, var);
277
278 mat
279 }
280}
281
282fn compile_layout(layout: MatrixLayout) -> Option<CooperativeMatrixLayout> {
283 match layout {
284 core::MatrixLayout::ColMajor => Some(CooperativeMatrixLayout::ColumnMajorKHR),
285 core::MatrixLayout::RowMajor => Some(CooperativeMatrixLayout::RowMajorKHR),
286 core::MatrixLayout::Undefined => None,
287 }
288}