// SPIR-V opcode and enum values used by the XMX GEMM module builders below.
// Numeric values come from the SPIR-V specification and the
// SPV_KHR_cooperative_matrix extension.

// Cooperative-matrix instructions (SPV_KHR_cooperative_matrix).
const OP_TYPE_COOPERATIVE_MATRIX_KHR: u32 = 4456;
const OP_COOPERATIVE_MATRIX_LOAD_KHR: u32 = 4457;
const OP_COOPERATIVE_MATRIX_STORE_KHR: u32 = 4458;
const OP_COOPERATIVE_MATRIX_MUL_ADD_KHR: u32 = 4459;

// Capabilities declared by the generated modules.
const CAPABILITY_COOPERATIVE_MATRIX_KHR: u32 = 6022;
const CAPABILITY_SHADER: u32 = 1;
const CAPABILITY_FLOAT16: u32 = 9;

// OpMemoryModel operands.
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
const MEMORY_MODEL_GLSL450: u32 = 1;

// Entry-point execution model and execution modes.
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;

// Storage classes.
const STORAGE_CLASS_STORAGE_BUFFER: u32 = 12;
const STORAGE_CLASS_INPUT: u32 = 1;

// Decorations and built-in ids.
const DECORATION_DESCRIPTOR_SET: u32 = 34;
const DECORATION_BINDING: u32 = 33;
const DECORATION_BLOCK: u32 = 2;
const DECORATION_BUILTIN: u32 = 11;
const DECORATION_NON_WRITABLE: u32 = 24;
const BUILTIN_WORKGROUP_ID: u32 = 26; // WorkgroupId built-in

// Core instruction opcodes.
const OP_EXTENSION: u32 = 10;
const OP_CAPABILITY: u32 = 17;
const OP_MEMORY_MODEL: u32 = 14;
const OP_ENTRY_POINT: u32 = 15;
const OP_EXECUTION_MODE: u32 = 16;
const OP_DECORATE: u32 = 71;
const OP_MEMBER_DECORATE: u32 = 72;
const OP_TYPE_VOID: u32 = 19;
const OP_TYPE_INT: u32 = 21;
const OP_TYPE_FLOAT: u32 = 22;
const OP_TYPE_POINTER: u32 = 32;
const OP_TYPE_FUNCTION: u32 = 33;
const OP_TYPE_STRUCT: u32 = 30;
const OP_TYPE_RUNTIME_ARRAY: u32 = 29;
const OP_CONSTANT: u32 = 43;
const OP_FUNCTION: u32 = 54;
const OP_FUNCTION_END: u32 = 56;
const OP_VARIABLE: u32 = 59;
const OP_LOAD: u32 = 61;
const OP_ACCESS_CHAIN: u32 = 65;
const OP_IN_BOUNDS_ACCESS_CHAIN: u32 = 66;
const OP_LABEL: u32 = 248;
const OP_RETURN: u32 = 253;
const OP_COMPOSITE_EXTRACT: u32 = 81;
const OP_I_MUL: u32 = 132;
const OP_I_ADD: u32 = 128;

// Scope value for cooperative-matrix operations (Subgroup).
const SCOPE_SUBGROUP: u32 = 3;

// CooperativeMatrixUse values.
const MATRIX_USE_A: u32 = 0;
const MATRIX_USE_B: u32 = 1;
const MATRIX_USE_ACCUMULATOR: u32 = 2;

// CooperativeMatrixLayout values.
const MATRIX_LAYOUT_ROW_MAJOR: u32 = 0;

// CooperativeMatrixOperands bitmask with no flags set.
const COOPERATIVE_MATRIX_OPERANDS_NONE: u32 = 0;

// Module header words 0 and 1: magic number and version 1.6.
const SPIRV_MAGIC: u32 = 0x07230203;
const SPIRV_VERSION_1_6: u32 = 0x0001_0600;
// Module header word 2: generator magic (high 16 bits = registered tool id,
// low 16 bits = tool version).
const SPIRV_GENERATOR: u32 = 0x000D_0003;

/// Cooperative-matrix tile shape: the generated kernel computes an `m x n`
/// output tile from an `m x k` A tile and a `k x n` B tile.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct XmxTileConfig {
    // Rows of the A tile and of the accumulator tile.
    pub m: u32,
    // Columns of the B tile and of the accumulator tile.
    pub n: u32,
    // Shared inner dimension of the A and B tiles.
    pub k: u32,
}

impl XmxTileConfig {
    /// Tile shape named for Xe-HPC fp16 hardware.
    pub const XE_HPC_FP16: Self = Self { m: 8, n: 16, k: 16 };

    /// Fallback tile shape (currently identical to `XE_HPC_FP16`).
    pub const XE_DEFAULT: Self = Self { m: 8, n: 16, k: 16 };

    /// Number of elements in one accumulator tile (`m * n`).
    pub fn accum_elements(&self) -> u32 {
        self.m * self.n
    }
}

impl Default for XmxTileConfig {
    /// Defaults to the Xe-HPC fp16 tile shape.
    fn default() -> Self {
        Self::XE_HPC_FP16
    }
}
153
/// Incremental SPIR-V binary writer: instructions are appended to `words`
/// in order; result ids are handed out sequentially.
struct XmxSpvModule {
    // The raw SPIR-V word stream, starting with the 5-word header.
    words: Vec<u32>,
    // Next unused result id; patched into header word 3 by `finalize`.
    id_bound: u32,
}
160
161impl XmxSpvModule {
162 fn new() -> Self {
163 let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_6, SPIRV_GENERATOR, 0, 0];
164 Self { words, id_bound: 1 }
165 }
166
167 fn alloc_id(&mut self) -> u32 {
168 let id = self.id_bound;
169 self.id_bound += 1;
170 id
171 }
172
173 fn emit(&mut self, opcode: u32, operands: &[u32]) {
174 let word_count = (1 + operands.len()) as u32;
175 self.words.push((word_count << 16) | opcode);
176 self.words.extend_from_slice(operands);
177 }
178
179 fn string_words(s: &str) -> Vec<u32> {
180 let bytes = s.as_bytes();
181 let padded_len = (bytes.len() + 4) & !3;
182 let mut out = vec![0u32; padded_len / 4];
183 for (i, &b) in bytes.iter().enumerate() {
184 out[i / 4] |= (b as u32) << ((i % 4) * 8);
185 }
186 out
187 }
188
189 fn finalize(mut self) -> Vec<u32> {
190 self.words[3] = self.id_bound;
191 self.words
192 }
193
194 fn emit_capability(&mut self, cap: u32) {
197 self.emit(OP_CAPABILITY, &[cap]);
198 }
199
200 fn emit_extension(&mut self, name: &str) {
201 let mut ops = Self::string_words(name);
202 let word_count = (1 + ops.len()) as u32;
204 self.words.push((word_count << 16) | OP_EXTENSION);
205 self.words.append(&mut ops);
206 }
207
208 fn emit_memory_model(&mut self, addr: u32, model: u32) {
209 self.emit(OP_MEMORY_MODEL, &[addr, model]);
210 }
211
212 fn emit_entry_point(&mut self, model: u32, func_id: u32, name: &str, interfaces: &[u32]) {
213 let mut ops = vec![model, func_id];
214 ops.extend(Self::string_words(name));
215 ops.extend_from_slice(interfaces);
216 self.emit(OP_ENTRY_POINT, &ops);
217 }
218
219 fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
220 self.emit(
221 OP_EXECUTION_MODE,
222 &[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
223 );
224 }
225
226 fn emit_decorate(&mut self, target: u32, decoration: u32, extra: &[u32]) {
227 let mut ops = vec![target, decoration];
228 ops.extend_from_slice(extra);
229 self.emit(OP_DECORATE, &ops);
230 }
231
232 fn emit_member_decorate(
233 &mut self,
234 struct_id: u32,
235 member: u32,
236 decoration: u32,
237 extra: &[u32],
238 ) {
239 let mut ops = vec![struct_id, member, decoration];
240 ops.extend_from_slice(extra);
241 self.emit(OP_MEMBER_DECORATE, &ops);
242 }
243
244 fn emit_type_void(&mut self, id: u32) {
245 self.emit(OP_TYPE_VOID, &[id]);
246 }
247 fn emit_type_int(&mut self, id: u32, width: u32, sign: u32) {
248 self.emit(OP_TYPE_INT, &[id, width, sign]);
249 }
250 fn emit_type_float(&mut self, id: u32, width: u32) {
251 self.emit(OP_TYPE_FLOAT, &[id, width]);
252 }
253 fn emit_type_ptr(&mut self, id: u32, sc: u32, pointee: u32) {
254 self.emit(OP_TYPE_POINTER, &[id, sc, pointee]);
255 }
256 fn emit_type_fn(&mut self, id: u32, ret: u32, params: &[u32]) {
257 let mut ops = vec![id, ret];
258 ops.extend_from_slice(params);
259 self.emit(OP_TYPE_FUNCTION, &ops);
260 }
261 fn emit_type_struct(&mut self, id: u32, members: &[u32]) {
262 let mut ops = vec![id];
263 ops.extend_from_slice(members);
264 self.emit(OP_TYPE_STRUCT, &ops);
265 }
266 fn emit_type_runtime_array(&mut self, id: u32, elem: u32) {
267 self.emit(OP_TYPE_RUNTIME_ARRAY, &[id, elem]);
268 }
269
270 fn emit_const_u32(&mut self, ty: u32, id: u32, val: u32) {
271 self.emit(OP_CONSTANT, &[ty, id, val]);
272 }
273 fn emit_variable(&mut self, ty: u32, id: u32, sc: u32) {
274 self.emit(OP_VARIABLE, &[ty, id, sc]);
275 }
276 fn emit_load(&mut self, ty: u32, id: u32, ptr: u32) {
277 self.emit(OP_LOAD, &[ty, id, ptr]);
278 }
279 fn emit_label(&mut self, id: u32) {
280 self.emit(OP_LABEL, &[id]);
281 }
282 fn emit_return(&mut self) {
283 self.emit(OP_RETURN, &[]);
284 }
285 fn emit_function_end(&mut self) {
286 self.emit(OP_FUNCTION_END, &[]);
287 }
288 fn emit_function(&mut self, ret_ty: u32, id: u32, ctrl: u32, fn_ty: u32) {
289 self.emit(OP_FUNCTION, &[ret_ty, id, ctrl, fn_ty]);
290 }
291 fn emit_i_add(&mut self, ty: u32, id: u32, a: u32, b: u32) {
292 self.emit(OP_I_ADD, &[ty, id, a, b]);
293 }
294 fn emit_i_mul(&mut self, ty: u32, id: u32, a: u32, b: u32) {
295 self.emit(OP_I_MUL, &[ty, id, a, b]);
296 }
297 fn emit_composite_extract(&mut self, ty: u32, id: u32, composite: u32, idx: u32) {
298 self.emit(OP_COMPOSITE_EXTRACT, &[ty, id, composite, idx]);
299 }
300
301 fn emit_access_chain(&mut self, ty: u32, id: u32, base: u32, indices: &[u32]) {
302 let mut ops = vec![ty, id, base];
303 ops.extend_from_slice(indices);
304 self.emit(OP_ACCESS_CHAIN, &ops);
305 }
306
307 fn emit_in_bounds_access_chain(&mut self, ty: u32, id: u32, base: u32, indices: &[u32]) {
308 let mut ops = vec![ty, id, base];
309 ops.extend_from_slice(indices);
310 self.emit(OP_IN_BOUNDS_ACCESS_CHAIN, &ops);
311 }
312
313 fn emit_type_cooperative_matrix(
317 &mut self,
318 id: u32,
319 component_type: u32,
320 scope: u32,
321 rows: u32,
322 cols: u32,
323 matrix_use: u32,
324 ) {
325 self.emit(
326 OP_TYPE_COOPERATIVE_MATRIX_KHR,
327 &[id, component_type, scope, rows, cols, matrix_use],
328 );
329 }
330
331 fn emit_coop_matrix_load(
333 &mut self,
334 result_ty: u32,
335 result: u32,
336 pointer: u32,
337 layout: u32,
338 stride: u32,
339 ) {
340 self.emit(
341 OP_COOPERATIVE_MATRIX_LOAD_KHR,
342 &[
343 result_ty,
344 result,
345 pointer,
346 layout,
347 stride,
348 COOPERATIVE_MATRIX_OPERANDS_NONE,
349 ],
350 );
351 }
352
353 fn emit_coop_matrix_store(&mut self, pointer: u32, object: u32, layout: u32, stride: u32) {
355 self.emit(
356 OP_COOPERATIVE_MATRIX_STORE_KHR,
357 &[
358 pointer,
359 object,
360 layout,
361 stride,
362 COOPERATIVE_MATRIX_OPERANDS_NONE,
363 ],
364 );
365 }
366
367 fn emit_coop_matrix_muladd(
369 &mut self,
370 result_ty: u32,
371 result: u32,
372 a: u32,
373 b: u32,
374 c: u32,
375 operands: u32,
376 ) {
377 self.emit(
378 OP_COOPERATIVE_MATRIX_MUL_ADD_KHR,
379 &[result_ty, result, a, b, c, operands],
380 );
381 }
382}
383
384pub fn gemm_xmx_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
417 let mut m = XmxSpvModule::new();
418
419 m.emit_capability(CAPABILITY_SHADER);
421 m.emit_capability(CAPABILITY_COOPERATIVE_MATRIX_KHR);
422
423 m.emit_extension("SPV_KHR_cooperative_matrix");
425
426 m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
428
429 let ty_void = m.alloc_id();
431 let ty_u32 = m.alloc_id();
432 let ty_f32 = m.alloc_id();
433
434 let ty_rt_f32 = m.alloc_id(); let ty_rt_u32 = m.alloc_id(); let ty_sb_f32 = m.alloc_id(); let ty_sb_u32 = m.alloc_id(); let ty_ptr_sb_f32 = m.alloc_id();
440 let ty_ptr_sb_u32 = m.alloc_id();
441 let ty_ptr_f32_sb = m.alloc_id();
442 let ty_ptr_u32_sb = m.alloc_id();
443
444 let ty_cmat_a = m.alloc_id(); let ty_cmat_b = m.alloc_id(); let ty_cmat_c = m.alloc_id(); let ty_fn_void = m.alloc_id();
451
452 let ty_v3u32 = m.alloc_id();
454 let ty_ptr_in_v3u32 = m.alloc_id();
455
456 let c0 = m.alloc_id();
458 let c1 = m.alloc_id();
459 let c_tile_m = m.alloc_id();
460 let c_tile_n = m.alloc_id();
461 let c_tile_k = m.alloc_id();
462
463 let var_a = m.alloc_id();
465 let var_b = m.alloc_id();
466 let var_c = m.alloc_id();
467 let var_dim = m.alloc_id();
468
469 let var_wg_id = m.alloc_id();
471
472 let fn_main = m.alloc_id();
474 let lbl_entry = m.alloc_id();
475
476 m.emit_entry_point(
478 EXECUTION_MODEL_GLCOMPUTE,
479 fn_main,
480 "gemm_xmx_f32",
481 &[var_a, var_b, var_c, var_dim, var_wg_id],
482 );
483 m.emit_execution_mode_local_size(fn_main, wg_x, wg_y, 1);
484
485 m.emit_decorate(ty_rt_f32, 6 , &[4]);
487 m.emit_decorate(ty_rt_u32, 6 , &[4]);
488
489 m.emit_decorate(ty_sb_f32, DECORATION_BLOCK, &[]);
490 m.emit_decorate(ty_sb_u32, DECORATION_BLOCK, &[]);
491
492 m.emit_member_decorate(ty_sb_f32, 0, 35 , &[0]);
493 m.emit_member_decorate(ty_sb_u32, 0, 35 , &[0]);
494
495 m.emit_decorate(var_a, DECORATION_DESCRIPTOR_SET, &[0]);
496 m.emit_decorate(var_a, DECORATION_BINDING, &[0]);
497 m.emit_decorate(var_a, DECORATION_NON_WRITABLE, &[]);
498 m.emit_decorate(var_b, DECORATION_DESCRIPTOR_SET, &[0]);
499 m.emit_decorate(var_b, DECORATION_BINDING, &[1]);
500 m.emit_decorate(var_b, DECORATION_NON_WRITABLE, &[]);
501 m.emit_decorate(var_c, DECORATION_DESCRIPTOR_SET, &[0]);
502 m.emit_decorate(var_c, DECORATION_BINDING, &[2]);
503 m.emit_decorate(var_dim, DECORATION_DESCRIPTOR_SET, &[0]);
504 m.emit_decorate(var_dim, DECORATION_BINDING, &[3]);
505 m.emit_decorate(var_dim, DECORATION_NON_WRITABLE, &[]);
506 m.emit_decorate(var_wg_id, DECORATION_BUILTIN, &[BUILTIN_WORKGROUP_ID]);
507
508 m.emit_type_void(ty_void);
510 m.emit_type_int(ty_u32, 32, 0);
511 m.emit_type_float(ty_f32, 32);
512
513 m.emit_type_runtime_array(ty_rt_f32, ty_f32);
514 m.emit_type_runtime_array(ty_rt_u32, ty_u32);
515 m.emit_type_struct(ty_sb_f32, &[ty_rt_f32]);
516 m.emit_type_struct(ty_sb_u32, &[ty_rt_u32]);
517 m.emit_type_ptr(ty_ptr_sb_f32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f32);
518 m.emit_type_ptr(ty_ptr_sb_u32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_u32);
519 m.emit_type_ptr(ty_ptr_f32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f32);
520 m.emit_type_ptr(ty_ptr_u32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_u32);
521
522 m.emit_type_cooperative_matrix(
524 ty_cmat_a,
525 ty_f32,
526 SCOPE_SUBGROUP,
527 tile.m,
528 tile.k,
529 MATRIX_USE_A,
530 );
531 m.emit_type_cooperative_matrix(
532 ty_cmat_b,
533 ty_f32,
534 SCOPE_SUBGROUP,
535 tile.k,
536 tile.n,
537 MATRIX_USE_B,
538 );
539 m.emit_type_cooperative_matrix(
540 ty_cmat_c,
541 ty_f32,
542 SCOPE_SUBGROUP,
543 tile.m,
544 tile.n,
545 MATRIX_USE_ACCUMULATOR,
546 );
547
548 let ty_v3u32_actual = ty_v3u32;
550 m.emit(30 , &[ty_v3u32_actual, ty_u32, 3]);
551 m.emit_type_ptr(ty_ptr_in_v3u32, STORAGE_CLASS_INPUT, ty_v3u32_actual);
552
553 m.emit_type_fn(ty_fn_void, ty_void, &[]);
554
555 m.emit_const_u32(ty_u32, c0, 0);
557 m.emit_const_u32(ty_u32, c1, 1);
558 m.emit_const_u32(ty_u32, c_tile_m, tile.m);
559 m.emit_const_u32(ty_u32, c_tile_n, tile.n);
560 m.emit_const_u32(ty_u32, c_tile_k, tile.k);
561
562 m.emit_variable(ty_ptr_sb_f32, var_a, STORAGE_CLASS_STORAGE_BUFFER);
564 m.emit_variable(ty_ptr_sb_f32, var_b, STORAGE_CLASS_STORAGE_BUFFER);
565 m.emit_variable(ty_ptr_sb_f32, var_c, STORAGE_CLASS_STORAGE_BUFFER);
566 m.emit_variable(ty_ptr_sb_u32, var_dim, STORAGE_CLASS_STORAGE_BUFFER);
567 m.emit_variable(ty_ptr_in_v3u32, var_wg_id, STORAGE_CLASS_INPUT);
568
569 m.emit_function(ty_void, fn_main, 0, ty_fn_void);
571 m.emit_label(lbl_entry);
572
573 let wg_id = m.alloc_id();
575 m.emit_load(ty_v3u32_actual, wg_id, var_wg_id);
576
577 let wg_col = m.alloc_id();
579 let wg_row = m.alloc_id();
580 m.emit_composite_extract(ty_u32, wg_col, wg_id, 0);
581 m.emit_composite_extract(ty_u32, wg_row, wg_id, 1);
582
583 let ptr_m = m.alloc_id();
585 let ptr_n = m.alloc_id();
586 let ptr_k = m.alloc_id();
587 let dim_m = m.alloc_id();
588 let dim_n = m.alloc_id();
589 let dim_k = m.alloc_id();
590 m.emit_access_chain(ty_ptr_u32_sb, ptr_m, var_dim, &[c0, c0]);
591 m.emit_access_chain(ty_ptr_u32_sb, ptr_n, var_dim, &[c0, c1]);
592 let c2 = m.alloc_id();
593 m.emit_const_u32(ty_u32, c2, 2);
594 m.emit_access_chain(ty_ptr_u32_sb, ptr_k, var_dim, &[c0, c2]);
595 m.emit_load(ty_u32, dim_m, ptr_m);
596 m.emit_load(ty_u32, dim_n, ptr_n);
597 m.emit_load(ty_u32, dim_k, ptr_k);
598
599 let row_base = m.alloc_id();
601 let col_base = m.alloc_id();
602 m.emit_i_mul(ty_u32, row_base, wg_row, c_tile_m);
603 m.emit_i_mul(ty_u32, col_base, wg_col, c_tile_n);
604
605 let c_row_stride = dim_n; let c_base_flat = m.alloc_id();
609 let c_base_tmp = m.alloc_id();
610 m.emit_i_mul(ty_u32, c_base_tmp, row_base, c_row_stride);
611 m.emit_i_add(ty_u32, c_base_flat, c_base_tmp, col_base);
612 let ptr_c_tile = m.alloc_id();
613 m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_c_tile, var_c, &[c0, c_base_flat]);
614
615 let mat_c_init = m.alloc_id();
616 m.emit_coop_matrix_load(
617 ty_cmat_c,
618 mat_c_init,
619 ptr_c_tile,
620 MATRIX_LAYOUT_ROW_MAJOR,
621 c_row_stride,
622 );
623
624 let mat_acc_after = {
629 let a_base_flat = m.alloc_id();
639 m.emit_i_mul(ty_u32, a_base_flat, row_base, dim_k);
640 let ptr_a_tile = m.alloc_id();
641 m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_a_tile, var_a, &[c0, a_base_flat]);
642 let mat_a = m.alloc_id();
643 m.emit_coop_matrix_load(ty_cmat_a, mat_a, ptr_a_tile, MATRIX_LAYOUT_ROW_MAJOR, dim_k);
644
645 let ptr_b_tile = m.alloc_id();
647 m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_b_tile, var_b, &[c0, col_base]);
648 let mat_b = m.alloc_id();
649 m.emit_coop_matrix_load(ty_cmat_b, mat_b, ptr_b_tile, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
650
651 let mat_tmp = m.alloc_id();
653 m.emit_coop_matrix_muladd(
654 ty_cmat_c,
655 mat_tmp,
656 mat_a,
657 mat_b,
658 mat_c_init,
659 COOPERATIVE_MATRIX_OPERANDS_NONE,
660 );
661 mat_tmp
662 };
663
664 m.emit_coop_matrix_store(
666 ptr_c_tile,
667 mat_acc_after,
668 MATRIX_LAYOUT_ROW_MAJOR,
669 c_row_stride,
670 );
671
672 m.emit_return();
673 m.emit_function_end();
674
675 m.finalize()
676}
677
678pub fn gemm_xmx_f16_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
689 let mut m = XmxSpvModule::new();
690
691 m.emit_capability(CAPABILITY_SHADER);
693 m.emit_capability(CAPABILITY_FLOAT16);
694 m.emit_capability(CAPABILITY_COOPERATIVE_MATRIX_KHR);
695 m.emit_extension("SPV_KHR_cooperative_matrix");
696 m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
697
698 let ty_void = m.alloc_id();
700 let ty_u32 = m.alloc_id();
701 let ty_f16 = m.alloc_id();
702 let ty_f32 = m.alloc_id();
703
704 let ty_rt_f16 = m.alloc_id();
705 let ty_rt_f32 = m.alloc_id();
706 let ty_rt_u32 = m.alloc_id();
707 let ty_sb_f16 = m.alloc_id();
708 let ty_sb_f32 = m.alloc_id();
709 let ty_sb_u32 = m.alloc_id();
710 let ty_ptr_sb_f16 = m.alloc_id();
711 let ty_ptr_sb_f32 = m.alloc_id();
712 let ty_ptr_sb_u32 = m.alloc_id();
713 let ty_ptr_f16_sb = m.alloc_id();
714 let ty_ptr_f32_sb = m.alloc_id();
715 let ty_ptr_u32_sb = m.alloc_id();
716
717 let ty_cmat_a = m.alloc_id();
719 let ty_cmat_b = m.alloc_id();
720 let ty_cmat_c = m.alloc_id();
721
722 let ty_v3u32 = m.alloc_id();
723 let ty_ptr_in_v3u32 = m.alloc_id();
724 let ty_fn_void = m.alloc_id();
725
726 let var_a = m.alloc_id();
728 let var_b = m.alloc_id();
729 let var_c = m.alloc_id();
730 let var_dim = m.alloc_id();
731 let var_wg = m.alloc_id();
732 let fn_main = m.alloc_id();
733 let lbl = m.alloc_id();
734
735 m.emit_entry_point(
737 EXECUTION_MODEL_GLCOMPUTE,
738 fn_main,
739 "gemm_xmx_f16",
740 &[var_a, var_b, var_c, var_dim, var_wg],
741 );
742 m.emit_execution_mode_local_size(fn_main, wg_x, wg_y, 1);
743
744 m.emit_decorate(ty_rt_f16, 6, &[2]); m.emit_decorate(ty_rt_f32, 6, &[4]);
747 m.emit_decorate(ty_rt_u32, 6, &[4]);
748 m.emit_decorate(ty_sb_f16, DECORATION_BLOCK, &[]);
749 m.emit_decorate(ty_sb_f32, DECORATION_BLOCK, &[]);
750 m.emit_decorate(ty_sb_u32, DECORATION_BLOCK, &[]);
751 m.emit_member_decorate(ty_sb_f16, 0, 35, &[0]);
752 m.emit_member_decorate(ty_sb_f32, 0, 35, &[0]);
753 m.emit_member_decorate(ty_sb_u32, 0, 35, &[0]);
754 for (var, set, binding, writable) in [
755 (var_a, 0u32, 0u32, false),
756 (var_b, 0, 1, false),
757 (var_c, 0, 2, true),
758 (var_dim, 0, 3, false),
759 ] {
760 m.emit_decorate(var, DECORATION_DESCRIPTOR_SET, &[set]);
761 m.emit_decorate(var, DECORATION_BINDING, &[binding]);
762 if !writable {
763 m.emit_decorate(var, DECORATION_NON_WRITABLE, &[]);
764 }
765 }
766 m.emit_decorate(var_wg, DECORATION_BUILTIN, &[BUILTIN_WORKGROUP_ID]);
767
768 m.emit_type_void(ty_void);
770 m.emit_type_int(ty_u32, 32, 0);
771 m.emit_type_float(ty_f16, 16);
772 m.emit_type_float(ty_f32, 32);
773 m.emit_type_runtime_array(ty_rt_f16, ty_f16);
774 m.emit_type_runtime_array(ty_rt_f32, ty_f32);
775 m.emit_type_runtime_array(ty_rt_u32, ty_u32);
776 m.emit_type_struct(ty_sb_f16, &[ty_rt_f16]);
777 m.emit_type_struct(ty_sb_f32, &[ty_rt_f32]);
778 m.emit_type_struct(ty_sb_u32, &[ty_rt_u32]);
779 m.emit_type_ptr(ty_ptr_sb_f16, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f16);
780 m.emit_type_ptr(ty_ptr_sb_f32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f32);
781 m.emit_type_ptr(ty_ptr_sb_u32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_u32);
782 m.emit_type_ptr(ty_ptr_f16_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f16);
783 m.emit_type_ptr(ty_ptr_f32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f32);
784 m.emit_type_ptr(ty_ptr_u32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_u32);
785 m.emit_type_cooperative_matrix(
787 ty_cmat_a,
788 ty_f16,
789 SCOPE_SUBGROUP,
790 tile.m,
791 tile.k,
792 MATRIX_USE_A,
793 );
794 m.emit_type_cooperative_matrix(
795 ty_cmat_b,
796 ty_f16,
797 SCOPE_SUBGROUP,
798 tile.k,
799 tile.n,
800 MATRIX_USE_B,
801 );
802 m.emit_type_cooperative_matrix(
803 ty_cmat_c,
804 ty_f32,
805 SCOPE_SUBGROUP,
806 tile.m,
807 tile.n,
808 MATRIX_USE_ACCUMULATOR,
809 );
810 m.emit(30, &[ty_v3u32, ty_u32, 3]); m.emit_type_ptr(ty_ptr_in_v3u32, STORAGE_CLASS_INPUT, ty_v3u32);
812 m.emit_type_fn(ty_fn_void, ty_void, &[]);
813
814 let c0 = m.alloc_id();
816 m.emit_const_u32(ty_u32, c0, 0);
817 let c1 = m.alloc_id();
818 m.emit_const_u32(ty_u32, c1, 1);
819 let c2 = m.alloc_id();
820 m.emit_const_u32(ty_u32, c2, 2);
821 let c_tm = m.alloc_id();
822 m.emit_const_u32(ty_u32, c_tm, tile.m);
823 let c_tn = m.alloc_id();
824 m.emit_const_u32(ty_u32, c_tn, tile.n);
825 let c_tk = m.alloc_id();
826 m.emit_const_u32(ty_u32, c_tk, tile.k);
827
828 m.emit_variable(ty_ptr_sb_f16, var_a, STORAGE_CLASS_STORAGE_BUFFER);
830 m.emit_variable(ty_ptr_sb_f16, var_b, STORAGE_CLASS_STORAGE_BUFFER);
831 m.emit_variable(ty_ptr_sb_f32, var_c, STORAGE_CLASS_STORAGE_BUFFER);
832 m.emit_variable(ty_ptr_sb_u32, var_dim, STORAGE_CLASS_STORAGE_BUFFER);
833 m.emit_variable(ty_ptr_in_v3u32, var_wg, STORAGE_CLASS_INPUT);
834
835 m.emit_function(ty_void, fn_main, 0, ty_fn_void);
837 m.emit_label(lbl);
838
839 let wg_id = m.alloc_id();
840 m.emit_load(ty_v3u32, wg_id, var_wg);
841 let wg_col = m.alloc_id();
842 m.emit_composite_extract(ty_u32, wg_col, wg_id, 0);
843 let wg_row = m.alloc_id();
844 m.emit_composite_extract(ty_u32, wg_row, wg_id, 1);
845
846 let ptr_m = m.alloc_id();
847 m.emit_access_chain(ty_ptr_u32_sb, ptr_m, var_dim, &[c0, c0]);
848 let ptr_n = m.alloc_id();
849 m.emit_access_chain(ty_ptr_u32_sb, ptr_n, var_dim, &[c0, c1]);
850 let ptr_k = m.alloc_id();
851 m.emit_access_chain(ty_ptr_u32_sb, ptr_k, var_dim, &[c0, c2]);
852 let dim_m = m.alloc_id();
853 m.emit_load(ty_u32, dim_m, ptr_m);
854 let dim_n = m.alloc_id();
855 m.emit_load(ty_u32, dim_n, ptr_n);
856 let dim_k = m.alloc_id();
857 m.emit_load(ty_u32, dim_k, ptr_k);
858
859 let row_base = m.alloc_id();
860 m.emit_i_mul(ty_u32, row_base, wg_row, c_tm);
861 let col_base = m.alloc_id();
862 m.emit_i_mul(ty_u32, col_base, wg_col, c_tn);
863
864 let c_base_tmp = m.alloc_id();
866 m.emit_i_mul(ty_u32, c_base_tmp, row_base, dim_n);
867 let c_base_flat = m.alloc_id();
868 m.emit_i_add(ty_u32, c_base_flat, c_base_tmp, col_base);
869 let ptr_c_tile = m.alloc_id();
870 m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_c_tile, var_c, &[c0, c_base_flat]);
871 let mat_c_init = m.alloc_id();
872 m.emit_coop_matrix_load(
873 ty_cmat_c,
874 mat_c_init,
875 ptr_c_tile,
876 MATRIX_LAYOUT_ROW_MAJOR,
877 dim_n,
878 );
879
880 let a_base = m.alloc_id();
882 m.emit_i_mul(ty_u32, a_base, row_base, dim_k);
883 let ptr_a = m.alloc_id();
884 m.emit_in_bounds_access_chain(ty_ptr_f16_sb, ptr_a, var_a, &[c0, a_base]);
885 let mat_a = m.alloc_id();
886 m.emit_coop_matrix_load(ty_cmat_a, mat_a, ptr_a, MATRIX_LAYOUT_ROW_MAJOR, dim_k);
887
888 let ptr_b = m.alloc_id();
890 m.emit_in_bounds_access_chain(ty_ptr_f16_sb, ptr_b, var_b, &[c0, col_base]);
891 let mat_b = m.alloc_id();
892 m.emit_coop_matrix_load(ty_cmat_b, mat_b, ptr_b, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
893
894 let mat_out = m.alloc_id();
896 m.emit_coop_matrix_muladd(
897 ty_cmat_c,
898 mat_out,
899 mat_a,
900 mat_b,
901 mat_c_init,
902 COOPERATIVE_MATRIX_OPERANDS_NONE,
903 );
904
905 m.emit_coop_matrix_store(ptr_c_tile, mat_out, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
907
908 m.emit_return();
909 m.emit_function_end();
910 m.finalize()
911}
912
913pub fn matmul_xmx_bf16_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
925 let mut words = gemm_xmx_f16_spirv(tile, wg_x, wg_y);
940
941 let old = b"gemm_xmx_f16\0\0\0\0";
948 let new = b"matmul_xmx_bf\0\0\0"; patch_entry_point_name(&mut words, old, new);
950
951 words
952}
953
/// Replaces the first occurrence of the 4-word sequence encoding `old` with
/// the words encoding `new`. Both byte strings must already be NUL-padded to
/// exactly 16 bytes; if `old` is not found, `words` is left untouched.
fn patch_entry_point_name(words: &mut [u32], old: &[u8; 16], new: &[u8; 16]) {
    // Pack 16 bytes into the four little-endian words SPIR-V stores.
    fn pack(name: &[u8; 16]) -> [u32; 4] {
        let mut packed = [0u32; 4];
        for (dst, chunk) in packed.iter_mut().zip(name.chunks_exact(4)) {
            *dst = u32::from_le_bytes(chunk.try_into().unwrap());
        }
        packed
    }

    let needle = pack(old);
    let replacement = pack(new);
    if let Some(start) = words.windows(4).position(|w| w == needle.as_slice()) {
        words[start..start + 4].copy_from_slice(&replacement);
    }
}
983
/// Heuristic check (case-insensitive substring match) for device names that
/// identify XMX-capable Intel GPUs.
pub fn device_supports_xmx(device_name: &str) -> bool {
    // Substrings matched against the lowercased device name. Note that
    // "max 12" is subsumed by "max 1" and kept only to mirror the original
    // marker list.
    const XMX_MARKERS: [&str; 7] = [
        "arc",
        "data center gpu max",
        "ponte vecchio",
        "max 1",
        "max 12",
        "iris xe",
        "uhd graphics",
    ];
    let lowered = device_name.to_ascii_lowercase();
    XMX_MARKERS.iter().any(|marker| lowered.contains(marker))
}
1008
1009pub fn best_xmx_tile(device_name: &str) -> XmxTileConfig {
1013 let name = device_name.to_ascii_lowercase();
1014 if name.contains("max") || name.contains("ponte vecchio") {
1015 XmxTileConfig { m: 8, n: 32, k: 16 }
1017 } else if name.contains("arc") || name.contains("iris xe") {
1018 XmxTileConfig::XE_HPC_FP16
1019 } else {
1020 XmxTileConfig::XE_DEFAULT
1021 }
1022}
1023
#[cfg(test)]
mod tests {
    use super::*;

    // Header word 0 is the SPIR-V magic number.
    #[test]
    fn gemm_xmx_spirv_starts_with_magic() {
        let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
        assert!(!words.is_empty(), "output must not be empty");
        assert_eq!(words[0], 0x07230203, "first word must be SPIR-V magic");
    }

    // Header word 1 is the version (1.6 = 0x00010600).
    #[test]
    fn gemm_xmx_spirv_version_1_6() {
        let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
        assert_eq!(words[1], 0x0001_0600, "version must be SPIR-V 1.6");
    }

    // Header word 3 is the id bound, patched in by finalize().
    #[test]
    fn gemm_xmx_spirv_id_bound_nonzero() {
        let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
        assert!(words[3] > 0, "ID bound must be > 0");
    }

    #[test]
    fn gemm_xmx_f16_produces_valid_header() {
        let words = gemm_xmx_f16_spirv(XmxTileConfig::XE_HPC_FP16, 16, 16);
        assert_eq!(words[0], SPIRV_MAGIC);
        assert_eq!(words[1], SPIRV_VERSION_1_6);
        assert!(words.len() > 20, "module must have non-trivial content");
    }

    // The bf16 variant is the f16 module with a patched entry-point name;
    // the header words are untouched by the patch.
    #[test]
    fn matmul_xmx_bf16_produces_valid_header() {
        let words = matmul_xmx_bf16_spirv(XmxTileConfig::default(), 16, 16);
        assert_eq!(words[0], SPIRV_MAGIC);
    }

    #[test]
    fn xmx_tile_accum_elements() {
        let tile = XmxTileConfig { m: 8, n: 16, k: 16 };
        assert_eq!(tile.accum_elements(), 128);
    }

    #[test]
    fn device_supports_xmx_arc() {
        assert!(device_supports_xmx("Intel Arc A770 Graphics"));
        assert!(device_supports_xmx("Intel Data Center GPU Max 1550"));
        assert!(!device_supports_xmx("AMD Radeon RX 7900 XTX"));
    }

    #[test]
    fn best_xmx_tile_xe_hpc() {
        let tile = best_xmx_tile("Intel Data Center GPU Max 1550");
        assert_eq!(tile.m, 8);
        assert_eq!(tile.n, 32);
    }

    // Tile dimensions feed the emitted module, so the binaries must differ.
    #[test]
    fn different_tile_sizes_produce_different_binaries() {
        let a = gemm_xmx_spirv(XmxTileConfig { m: 8, n: 16, k: 16 }, 16, 16);
        let b = gemm_xmx_spirv(XmxTileConfig { m: 8, n: 32, k: 16 }, 16, 16);
        assert_ne!(
            a, b,
            "different tile configurations must yield distinct SPIR-V"
        );
    }

    // The low 16 bits of an instruction's first word hold the opcode.
    #[test]
    fn gemm_xmx_spirv_contains_cooperative_matrix_opcode() {
        let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
        let has_cmat = words
            .iter()
            .any(|&w| (w & 0xFFFF) == OP_TYPE_COOPERATIVE_MATRIX_KHR);
        assert!(has_cmat, "module must declare OpTypeCooperativeMatrixKHR");
    }

    // Looks for OpTypeFloat (opcode 22) with width operand 16.
    #[test]
    fn gemm_xmx_f16_contains_float16_type() {
        let words = gemm_xmx_f16_spirv(XmxTileConfig::XE_HPC_FP16, 16, 16);
        let has_f16 = words.windows(3).any(|w| {
            (w[0] & 0xFFFF) == 22 && w[2] == 16 });
        assert!(has_f16, "FP16 module must declare 16-bit float type");
    }
}
1112}