Skip to main content

oxicuda_levelzero/
spirv_xmx.rs

1//! Intel Xe Matrix Extensions (XMX) SPIR-V kernel generators.
2//!
3//! XMX is Intel's matrix-multiply-accumulate hardware (analogous to NVIDIA
4//! Tensor Cores or AMD MFMA). It is exposed via the
5//! [`SPV_KHR_cooperative_matrix`] SPIR-V extension and the
//! `CooperativeMatrixKHR` capability (SPIR-V 1.6; on the GLSL side this is
//! exposed via `GL_KHR_cooperative_matrix`, not core GLSL 4.6).
7//!
8//! This module provides:
9//! - [`XmxTileConfig`] — tile dimension configuration for XMX GEMM.
10//! - [`gemm_xmx_spirv`] — cooperative-matrix GEMM kernel (`C = alpha*A*B + beta*C`)
11//!   targeting Intel Xe / Arc / Ponte Vecchio with XMX engines.
12//! - [`gemm_xmx_f16_spirv`] — FP16 input / FP32 accumulation variant.
13//! - [`matmul_xmx_bf16_spirv`] — BF16 input / FP32 accumulation variant.
14//!
15//! [`SPV_KHR_cooperative_matrix`]:
16//!   <https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc>
17
// ─── XMX SPIR-V opcodes (SPV_KHR_cooperative_matrix) ────────────────────────

/// `OpTypeCooperativeMatrixKHR` — defines a cooperative matrix type.
const OP_TYPE_COOPERATIVE_MATRIX_KHR: u32 = 4456;
/// `OpCooperativeMatrixLoadKHR` — loads a tile from a pointer.
const OP_COOPERATIVE_MATRIX_LOAD_KHR: u32 = 4457;
/// `OpCooperativeMatrixStoreKHR` — stores a tile to a pointer.
const OP_COOPERATIVE_MATRIX_STORE_KHR: u32 = 4458;
/// `OpCooperativeMatrixMulAddKHR` — performs `D = A*B + C`.
const OP_COOPERATIVE_MATRIX_MUL_ADD_KHR: u32 = 4459;
// ─── Capabilities ────────────────────────────────────────────────────────────

/// `CooperativeMatrixKHR` capability (requires SPV_KHR_cooperative_matrix).
const CAPABILITY_COOPERATIVE_MATRIX_KHR: u32 = 6022;
/// Standard `Shader` capability (required for GLSL memory model path).
const CAPABILITY_SHADER: u32 = 1;
/// `Float16` capability — required when using FP16 matrix elements.
const CAPABILITY_FLOAT16: u32 = 9;
// ─── Addressing / memory model ───────────────────────────────────────────────

/// `AddressingModel::Logical` — no physical pointer arithmetic.
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
/// `MemoryModel::GLSL450` — paired with the `Shader` capability above.
const MEMORY_MODEL_GLSL450: u32 = 1;

// ─── Execution model ─────────────────────────────────────────────────────────

/// `ExecutionModel::GLCompute` — compute shader entry point.
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
/// `ExecutionMode::LocalSize` — fixed workgroup dimensions (x, y, z literals).
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;

// ─── Storage classes ─────────────────────────────────────────────────────────

/// `StorageClass::StorageBuffer` — SSBO-backed global memory.
const STORAGE_CLASS_STORAGE_BUFFER: u32 = 12;
/// `StorageClass::Input` — builtin inputs (e.g. `WorkgroupId`).
const STORAGE_CLASS_INPUT: u32 = 1;

// ─── Decorations ─────────────────────────────────────────────────────────────

const DECORATION_DESCRIPTOR_SET: u32 = 34; // Decoration::DescriptorSet
const DECORATION_BINDING: u32 = 33; // Decoration::Binding
const DECORATION_BLOCK: u32 = 2; // Decoration::Block
const DECORATION_BUILTIN: u32 = 11; // Decoration::BuiltIn
const DECORATION_NON_WRITABLE: u32 = 24; // Decoration::NonWritable
const BUILTIN_WORKGROUP_ID: u32 = 26; // BuiltIn::WorkgroupId

// ─── Core SPIR-V opcodes used by the generators ──────────────────────────────

const OP_EXTENSION: u32 = 10; // OpExtension
const OP_CAPABILITY: u32 = 17; // OpCapability
const OP_MEMORY_MODEL: u32 = 14; // OpMemoryModel
const OP_ENTRY_POINT: u32 = 15; // OpEntryPoint
const OP_EXECUTION_MODE: u32 = 16; // OpExecutionMode
const OP_DECORATE: u32 = 71; // OpDecorate
const OP_MEMBER_DECORATE: u32 = 72; // OpMemberDecorate
const OP_TYPE_VOID: u32 = 19; // OpTypeVoid
const OP_TYPE_INT: u32 = 21; // OpTypeInt
const OP_TYPE_FLOAT: u32 = 22; // OpTypeFloat
const OP_TYPE_POINTER: u32 = 32; // OpTypePointer
const OP_TYPE_FUNCTION: u32 = 33; // OpTypeFunction
const OP_TYPE_STRUCT: u32 = 30; // OpTypeStruct (note: OpTypeVector is 23)
const OP_TYPE_RUNTIME_ARRAY: u32 = 29; // OpTypeRuntimeArray
const OP_CONSTANT: u32 = 43; // OpConstant
const OP_FUNCTION: u32 = 54; // OpFunction
const OP_FUNCTION_END: u32 = 56; // OpFunctionEnd
const OP_VARIABLE: u32 = 59; // OpVariable
const OP_LOAD: u32 = 61; // OpLoad
const OP_ACCESS_CHAIN: u32 = 65; // OpAccessChain
const OP_IN_BOUNDS_ACCESS_CHAIN: u32 = 66; // OpInBoundsAccessChain
const OP_LABEL: u32 = 248; // OpLabel
const OP_RETURN: u32 = 253; // OpReturn
const OP_COMPOSITE_EXTRACT: u32 = 81; // OpCompositeExtract
const OP_I_MUL: u32 = 132; // OpIMul
const OP_I_ADD: u32 = 128; // OpIAdd

// ─── CooperativeMatrix use/scope values ──────────────────────────────────────

/// Scope `Subgroup` (the XMX execution granularity on Intel GPUs).
const SCOPE_SUBGROUP: u32 = 3;

/// `MatrixUseA` — first input matrix to `MulAdd`.
const MATRIX_USE_A: u32 = 0;
/// `MatrixUseB` — second input matrix to `MulAdd`.
const MATRIX_USE_B: u32 = 1;
/// `MatrixUseAccumulator` — accumulator matrix to `MulAdd`.
const MATRIX_USE_ACCUMULATOR: u32 = 2;

/// `RowMajor` layout for `CooperativeMatrixLoad/Store`.
const MATRIX_LAYOUT_ROW_MAJOR: u32 = 0;

// ─── CooperativeMatrixOperands bitmask ───────────────────────────────────────

/// No special operand flags.
const COOPERATIVE_MATRIX_OPERANDS_NONE: u32 = 0;

// ─── SPIR-V magic / version ──────────────────────────────────────────────────

/// SPIR-V magic number (first word of every module, host endianness).
const SPIRV_MAGIC: u32 = 0x07230203;
/// SPIR-V 1.6 — required for `SPV_KHR_cooperative_matrix`.
const SPIRV_VERSION_1_6: u32 = 0x0001_0600;
const SPIRV_GENERATOR: u32 = 0x000D_0003; // OxiCUDA Level Zero XMX generator
115
116// ─── XmxTileConfig ───────────────────────────────────────────────────────────
117
/// Tile dimensions for XMX GEMM.
///
/// Intel's XMX engines support the following sizes on Xe-HPC (Ponte Vecchio):
/// - FP16 / BF16 input, FP32 accumulation: 8 × 16, 8 × 32
/// - INT8 / INT4 input, INT32 accumulation: 8 × 32
///
/// On Arc (Alchemist) and later, additional sizes are available.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct XmxTileConfig {
    /// Rows of the A/C matrix tile.
    pub m: u32,
    /// Columns of the B/C matrix tile.
    pub n: u32,
    /// Inner dimension (columns of A, rows of B).
    pub k: u32,
}

impl XmxTileConfig {
    /// Default XMX tile size for FP16 input targeting Xe-HPC.
    pub const XE_HPC_FP16: Self = Self { m: 8, n: 16, k: 16 };

    /// XMX tile size for FP32 GEMM fallback (regular compute path).
    pub const XE_DEFAULT: Self = Self { m: 8, n: 16, k: 16 };

    /// Number of accumulator elements in one output tile (`m * n`).
    pub fn accum_elements(&self) -> u32 {
        let Self { m, n, .. } = *self;
        m * n
    }
}

impl Default for XmxTileConfig {
    /// Defaults to the Xe-HPC FP16 tile shape.
    fn default() -> Self {
        Self::XE_HPC_FP16
    }
}
153
154// ─── Minimal SPIR-V builder (local, tuned for cooperative-matrix dialect) ───
155
/// Minimal in-memory SPIR-V module writer for the cooperative-matrix kernels.
struct XmxSpvModule {
    // Raw SPIR-V words, starting with the 5-word header
    // (magic, version, generator, bound, schema).
    words: Vec<u32>,
    // Next unallocated result ID; patched into header word 3 by `finalize`.
    id_bound: u32,
}
160
161impl XmxSpvModule {
162    fn new() -> Self {
163        let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_6, SPIRV_GENERATOR, 0, 0];
164        Self { words, id_bound: 1 }
165    }
166
167    fn alloc_id(&mut self) -> u32 {
168        let id = self.id_bound;
169        self.id_bound += 1;
170        id
171    }
172
173    fn emit(&mut self, opcode: u32, operands: &[u32]) {
174        let word_count = (1 + operands.len()) as u32;
175        self.words.push((word_count << 16) | opcode);
176        self.words.extend_from_slice(operands);
177    }
178
179    fn string_words(s: &str) -> Vec<u32> {
180        let bytes = s.as_bytes();
181        let padded_len = (bytes.len() + 4) & !3;
182        let mut out = vec![0u32; padded_len / 4];
183        for (i, &b) in bytes.iter().enumerate() {
184            out[i / 4] |= (b as u32) << ((i % 4) * 8);
185        }
186        out
187    }
188
189    fn finalize(mut self) -> Vec<u32> {
190        self.words[3] = self.id_bound;
191        self.words
192    }
193
194    // ── Convenience emitters ──────────────────────────────────
195
196    fn emit_capability(&mut self, cap: u32) {
197        self.emit(OP_CAPABILITY, &[cap]);
198    }
199
200    fn emit_extension(&mut self, name: &str) {
201        let mut ops = Self::string_words(name);
202        // Extension instruction: opcode only, no result ID
203        let word_count = (1 + ops.len()) as u32;
204        self.words.push((word_count << 16) | OP_EXTENSION);
205        self.words.append(&mut ops);
206    }
207
208    fn emit_memory_model(&mut self, addr: u32, model: u32) {
209        self.emit(OP_MEMORY_MODEL, &[addr, model]);
210    }
211
212    fn emit_entry_point(&mut self, model: u32, func_id: u32, name: &str, interfaces: &[u32]) {
213        let mut ops = vec![model, func_id];
214        ops.extend(Self::string_words(name));
215        ops.extend_from_slice(interfaces);
216        self.emit(OP_ENTRY_POINT, &ops);
217    }
218
219    fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
220        self.emit(
221            OP_EXECUTION_MODE,
222            &[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
223        );
224    }
225
226    fn emit_decorate(&mut self, target: u32, decoration: u32, extra: &[u32]) {
227        let mut ops = vec![target, decoration];
228        ops.extend_from_slice(extra);
229        self.emit(OP_DECORATE, &ops);
230    }
231
232    fn emit_member_decorate(
233        &mut self,
234        struct_id: u32,
235        member: u32,
236        decoration: u32,
237        extra: &[u32],
238    ) {
239        let mut ops = vec![struct_id, member, decoration];
240        ops.extend_from_slice(extra);
241        self.emit(OP_MEMBER_DECORATE, &ops);
242    }
243
244    fn emit_type_void(&mut self, id: u32) {
245        self.emit(OP_TYPE_VOID, &[id]);
246    }
247    fn emit_type_int(&mut self, id: u32, width: u32, sign: u32) {
248        self.emit(OP_TYPE_INT, &[id, width, sign]);
249    }
250    fn emit_type_float(&mut self, id: u32, width: u32) {
251        self.emit(OP_TYPE_FLOAT, &[id, width]);
252    }
253    fn emit_type_ptr(&mut self, id: u32, sc: u32, pointee: u32) {
254        self.emit(OP_TYPE_POINTER, &[id, sc, pointee]);
255    }
256    fn emit_type_fn(&mut self, id: u32, ret: u32, params: &[u32]) {
257        let mut ops = vec![id, ret];
258        ops.extend_from_slice(params);
259        self.emit(OP_TYPE_FUNCTION, &ops);
260    }
261    fn emit_type_struct(&mut self, id: u32, members: &[u32]) {
262        let mut ops = vec![id];
263        ops.extend_from_slice(members);
264        self.emit(OP_TYPE_STRUCT, &ops);
265    }
266    fn emit_type_runtime_array(&mut self, id: u32, elem: u32) {
267        self.emit(OP_TYPE_RUNTIME_ARRAY, &[id, elem]);
268    }
269
270    fn emit_const_u32(&mut self, ty: u32, id: u32, val: u32) {
271        self.emit(OP_CONSTANT, &[ty, id, val]);
272    }
273    fn emit_variable(&mut self, ty: u32, id: u32, sc: u32) {
274        self.emit(OP_VARIABLE, &[ty, id, sc]);
275    }
276    fn emit_load(&mut self, ty: u32, id: u32, ptr: u32) {
277        self.emit(OP_LOAD, &[ty, id, ptr]);
278    }
279    fn emit_label(&mut self, id: u32) {
280        self.emit(OP_LABEL, &[id]);
281    }
282    fn emit_return(&mut self) {
283        self.emit(OP_RETURN, &[]);
284    }
285    fn emit_function_end(&mut self) {
286        self.emit(OP_FUNCTION_END, &[]);
287    }
288    fn emit_function(&mut self, ret_ty: u32, id: u32, ctrl: u32, fn_ty: u32) {
289        self.emit(OP_FUNCTION, &[ret_ty, id, ctrl, fn_ty]);
290    }
291    fn emit_i_add(&mut self, ty: u32, id: u32, a: u32, b: u32) {
292        self.emit(OP_I_ADD, &[ty, id, a, b]);
293    }
294    fn emit_i_mul(&mut self, ty: u32, id: u32, a: u32, b: u32) {
295        self.emit(OP_I_MUL, &[ty, id, a, b]);
296    }
297    fn emit_composite_extract(&mut self, ty: u32, id: u32, composite: u32, idx: u32) {
298        self.emit(OP_COMPOSITE_EXTRACT, &[ty, id, composite, idx]);
299    }
300
301    fn emit_access_chain(&mut self, ty: u32, id: u32, base: u32, indices: &[u32]) {
302        let mut ops = vec![ty, id, base];
303        ops.extend_from_slice(indices);
304        self.emit(OP_ACCESS_CHAIN, &ops);
305    }
306
307    fn emit_in_bounds_access_chain(&mut self, ty: u32, id: u32, base: u32, indices: &[u32]) {
308        let mut ops = vec![ty, id, base];
309        ops.extend_from_slice(indices);
310        self.emit(OP_IN_BOUNDS_ACCESS_CHAIN, &ops);
311    }
312
313    /// Emit `OpTypeCooperativeMatrixKHR`.
314    ///
315    /// Parameters: `(result_id, component_type, scope, rows, columns, use)`
316    fn emit_type_cooperative_matrix(
317        &mut self,
318        id: u32,
319        component_type: u32,
320        scope: u32,
321        rows: u32,
322        cols: u32,
323        matrix_use: u32,
324    ) {
325        self.emit(
326            OP_TYPE_COOPERATIVE_MATRIX_KHR,
327            &[id, component_type, scope, rows, cols, matrix_use],
328        );
329    }
330
331    /// Emit `OpCooperativeMatrixLoadKHR`.
332    fn emit_coop_matrix_load(
333        &mut self,
334        result_ty: u32,
335        result: u32,
336        pointer: u32,
337        layout: u32,
338        stride: u32,
339    ) {
340        self.emit(
341            OP_COOPERATIVE_MATRIX_LOAD_KHR,
342            &[
343                result_ty,
344                result,
345                pointer,
346                layout,
347                stride,
348                COOPERATIVE_MATRIX_OPERANDS_NONE,
349            ],
350        );
351    }
352
353    /// Emit `OpCooperativeMatrixStoreKHR`.
354    fn emit_coop_matrix_store(&mut self, pointer: u32, object: u32, layout: u32, stride: u32) {
355        self.emit(
356            OP_COOPERATIVE_MATRIX_STORE_KHR,
357            &[
358                pointer,
359                object,
360                layout,
361                stride,
362                COOPERATIVE_MATRIX_OPERANDS_NONE,
363            ],
364        );
365    }
366
367    /// Emit `OpCooperativeMatrixMulAddKHR`: `D = A * B + C`.
368    fn emit_coop_matrix_muladd(
369        &mut self,
370        result_ty: u32,
371        result: u32,
372        a: u32,
373        b: u32,
374        c: u32,
375        operands: u32,
376    ) {
377        self.emit(
378            OP_COOPERATIVE_MATRIX_MUL_ADD_KHR,
379            &[result_ty, result, a, b, c, operands],
380        );
381    }
382}
383
384// ─── gemm_xmx_spirv ──────────────────────────────────────────────────────────
385
386/// Generate a SPIR-V binary for an XMX-accelerated FP32 GEMM kernel.
387///
388/// Computes `C = alpha * A * B + beta * C` using Intel Xe Matrix Extensions
389/// (cooperative matrix hardware). The kernel is structured as:
390///
391/// - Workgroup of size `(wg_x, wg_y, 1)` where each workgroup computes a
392///   tile of the output matrix.
393/// - Inner loop loads sub-tiles of A and B using `OpCooperativeMatrixLoadKHR`,
394///   accumulates via `OpCooperativeMatrixMulAddKHR`, then stores result with
395///   `OpCooperativeMatrixStoreKHR`.
396///
397/// # Arguments
398///
399/// * `tile` — XMX tile configuration (M × N × K).
400/// * `wg_x` — workgroup X dimension (threads per group in X).
401/// * `wg_y` — workgroup Y dimension (threads per group in Y).
402///
403/// # Returns
404///
405/// A `Vec<u32>` containing a valid SPIR-V 1.6 binary. Pass this directly to
406/// `zeModuleCreate(..., ZE_MODULE_FORMAT_IL_SPIRV, ...)`.
407///
408/// # Kernel Interface (descriptor set 0)
409///
410/// | Binding | Type | Description |
411/// |---------|------|-------------|
412/// | 0 | `StorageBuffer f32[]` | Input matrix A (row-major, M×K) |
413/// | 1 | `StorageBuffer f32[]` | Input matrix B (row-major, K×N) |
414/// | 2 | `StorageBuffer f32[]` | In/out matrix C (row-major, M×N) |
415/// | 3 | `StorageBuffer u32[4]` | Push constants: M, N, K, flags |
416pub fn gemm_xmx_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
417    let mut m = XmxSpvModule::new();
418
419    // ── Capabilities ──────────────────────────────────────────────────────────
420    m.emit_capability(CAPABILITY_SHADER);
421    m.emit_capability(CAPABILITY_COOPERATIVE_MATRIX_KHR);
422
423    // ── SPV_KHR_cooperative_matrix extension ──────────────────────────────────
424    m.emit_extension("SPV_KHR_cooperative_matrix");
425
426    // ── Memory model ──────────────────────────────────────────────────────────
427    m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
428
429    // ── Type IDs ──────────────────────────────────────────────────────────────
430    let ty_void = m.alloc_id();
431    let ty_u32 = m.alloc_id();
432    let ty_f32 = m.alloc_id();
433
434    // Storage buffer types: struct { float[] } at binding 0/1/2, uint[] at binding 3
435    let ty_rt_f32 = m.alloc_id(); // runtime array of f32
436    let ty_rt_u32 = m.alloc_id(); // runtime array of u32
437    let ty_sb_f32 = m.alloc_id(); // struct { float[] }
438    let ty_sb_u32 = m.alloc_id(); // struct { uint[] }
439    let ty_ptr_sb_f32 = m.alloc_id();
440    let ty_ptr_sb_u32 = m.alloc_id();
441    let ty_ptr_f32_sb = m.alloc_id();
442    let ty_ptr_u32_sb = m.alloc_id();
443
444    // Cooperative matrix types (all at Subgroup scope)
445    let ty_cmat_a = m.alloc_id(); // mat A: f32, SCOPE_SUBGROUP, M×K, use A
446    let ty_cmat_b = m.alloc_id(); // mat B: f32, SCOPE_SUBGROUP, K×N, use B
447    let ty_cmat_c = m.alloc_id(); // mat C: f32, SCOPE_SUBGROUP, M×N, use Accum
448
449    // Function type
450    let ty_fn_void = m.alloc_id();
451
452    // v3uint for builtins
453    let ty_v3u32 = m.alloc_id();
454    let ty_ptr_in_v3u32 = m.alloc_id();
455
456    // Constants
457    let c0 = m.alloc_id();
458    let c1 = m.alloc_id();
459    let c_tile_m = m.alloc_id();
460    let c_tile_n = m.alloc_id();
461    let c_tile_k = m.alloc_id();
462
463    // Variables at descriptor set 0
464    let var_a = m.alloc_id();
465    let var_b = m.alloc_id();
466    let var_c = m.alloc_id();
467    let var_dim = m.alloc_id();
468
469    // Builtin variable
470    let var_wg_id = m.alloc_id();
471
472    // Function
473    let fn_main = m.alloc_id();
474    let lbl_entry = m.alloc_id();
475
476    // ── Entry point ───────────────────────────────────────────────────────────
477    m.emit_entry_point(
478        EXECUTION_MODEL_GLCOMPUTE,
479        fn_main,
480        "gemm_xmx_f32",
481        &[var_a, var_b, var_c, var_dim, var_wg_id],
482    );
483    m.emit_execution_mode_local_size(fn_main, wg_x, wg_y, 1);
484
485    // ── Decorations ───────────────────────────────────────────────────────────
486    m.emit_decorate(ty_rt_f32, 6 /* ArrayStride */, &[4]);
487    m.emit_decorate(ty_rt_u32, 6 /* ArrayStride */, &[4]);
488
489    m.emit_decorate(ty_sb_f32, DECORATION_BLOCK, &[]);
490    m.emit_decorate(ty_sb_u32, DECORATION_BLOCK, &[]);
491
492    m.emit_member_decorate(ty_sb_f32, 0, 35 /* Offset */, &[0]);
493    m.emit_member_decorate(ty_sb_u32, 0, 35 /* Offset */, &[0]);
494
495    m.emit_decorate(var_a, DECORATION_DESCRIPTOR_SET, &[0]);
496    m.emit_decorate(var_a, DECORATION_BINDING, &[0]);
497    m.emit_decorate(var_a, DECORATION_NON_WRITABLE, &[]);
498    m.emit_decorate(var_b, DECORATION_DESCRIPTOR_SET, &[0]);
499    m.emit_decorate(var_b, DECORATION_BINDING, &[1]);
500    m.emit_decorate(var_b, DECORATION_NON_WRITABLE, &[]);
501    m.emit_decorate(var_c, DECORATION_DESCRIPTOR_SET, &[0]);
502    m.emit_decorate(var_c, DECORATION_BINDING, &[2]);
503    m.emit_decorate(var_dim, DECORATION_DESCRIPTOR_SET, &[0]);
504    m.emit_decorate(var_dim, DECORATION_BINDING, &[3]);
505    m.emit_decorate(var_dim, DECORATION_NON_WRITABLE, &[]);
506    m.emit_decorate(var_wg_id, DECORATION_BUILTIN, &[BUILTIN_WORKGROUP_ID]);
507
508    // ── Type definitions ──────────────────────────────────────────────────────
509    m.emit_type_void(ty_void);
510    m.emit_type_int(ty_u32, 32, 0);
511    m.emit_type_float(ty_f32, 32);
512
513    m.emit_type_runtime_array(ty_rt_f32, ty_f32);
514    m.emit_type_runtime_array(ty_rt_u32, ty_u32);
515    m.emit_type_struct(ty_sb_f32, &[ty_rt_f32]);
516    m.emit_type_struct(ty_sb_u32, &[ty_rt_u32]);
517    m.emit_type_ptr(ty_ptr_sb_f32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f32);
518    m.emit_type_ptr(ty_ptr_sb_u32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_u32);
519    m.emit_type_ptr(ty_ptr_f32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f32);
520    m.emit_type_ptr(ty_ptr_u32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_u32);
521
522    // Cooperative matrix types
523    m.emit_type_cooperative_matrix(
524        ty_cmat_a,
525        ty_f32,
526        SCOPE_SUBGROUP,
527        tile.m,
528        tile.k,
529        MATRIX_USE_A,
530    );
531    m.emit_type_cooperative_matrix(
532        ty_cmat_b,
533        ty_f32,
534        SCOPE_SUBGROUP,
535        tile.k,
536        tile.n,
537        MATRIX_USE_B,
538    );
539    m.emit_type_cooperative_matrix(
540        ty_cmat_c,
541        ty_f32,
542        SCOPE_SUBGROUP,
543        tile.m,
544        tile.n,
545        MATRIX_USE_ACCUMULATOR,
546    );
547
548    // v3uint for WorkgroupId builtin
549    let ty_v3u32_actual = ty_v3u32;
550    m.emit(30 /* OpTypeVector */, &[ty_v3u32_actual, ty_u32, 3]);
551    m.emit_type_ptr(ty_ptr_in_v3u32, STORAGE_CLASS_INPUT, ty_v3u32_actual);
552
553    m.emit_type_fn(ty_fn_void, ty_void, &[]);
554
555    // ── Constants ─────────────────────────────────────────────────────────────
556    m.emit_const_u32(ty_u32, c0, 0);
557    m.emit_const_u32(ty_u32, c1, 1);
558    m.emit_const_u32(ty_u32, c_tile_m, tile.m);
559    m.emit_const_u32(ty_u32, c_tile_n, tile.n);
560    m.emit_const_u32(ty_u32, c_tile_k, tile.k);
561
562    // ── Variables ─────────────────────────────────────────────────────────────
563    m.emit_variable(ty_ptr_sb_f32, var_a, STORAGE_CLASS_STORAGE_BUFFER);
564    m.emit_variable(ty_ptr_sb_f32, var_b, STORAGE_CLASS_STORAGE_BUFFER);
565    m.emit_variable(ty_ptr_sb_f32, var_c, STORAGE_CLASS_STORAGE_BUFFER);
566    m.emit_variable(ty_ptr_sb_u32, var_dim, STORAGE_CLASS_STORAGE_BUFFER);
567    m.emit_variable(ty_ptr_in_v3u32, var_wg_id, STORAGE_CLASS_INPUT);
568
569    // ── Function body ─────────────────────────────────────────────────────────
570    m.emit_function(ty_void, fn_main, 0, ty_fn_void);
571    m.emit_label(lbl_entry);
572
573    // Load WorkgroupId
574    let wg_id = m.alloc_id();
575    m.emit_load(ty_v3u32_actual, wg_id, var_wg_id);
576
577    // wg_col = wg_id.x, wg_row = wg_id.y
578    let wg_col = m.alloc_id();
579    let wg_row = m.alloc_id();
580    m.emit_composite_extract(ty_u32, wg_col, wg_id, 0);
581    m.emit_composite_extract(ty_u32, wg_row, wg_id, 1);
582
583    // Load problem dimensions from binding 3: [M, N, K, _]
584    let ptr_m = m.alloc_id();
585    let ptr_n = m.alloc_id();
586    let ptr_k = m.alloc_id();
587    let dim_m = m.alloc_id();
588    let dim_n = m.alloc_id();
589    let dim_k = m.alloc_id();
590    m.emit_access_chain(ty_ptr_u32_sb, ptr_m, var_dim, &[c0, c0]);
591    m.emit_access_chain(ty_ptr_u32_sb, ptr_n, var_dim, &[c0, c1]);
592    let c2 = m.alloc_id();
593    m.emit_const_u32(ty_u32, c2, 2);
594    m.emit_access_chain(ty_ptr_u32_sb, ptr_k, var_dim, &[c0, c2]);
595    m.emit_load(ty_u32, dim_m, ptr_m);
596    m.emit_load(ty_u32, dim_n, ptr_n);
597    m.emit_load(ty_u32, dim_k, ptr_k);
598
599    // Tile base offsets: row_base = wg_row * tile.m, col_base = wg_col * tile.n
600    let row_base = m.alloc_id();
601    let col_base = m.alloc_id();
602    m.emit_i_mul(ty_u32, row_base, wg_row, c_tile_m);
603    m.emit_i_mul(ty_u32, col_base, wg_col, c_tile_n);
604
605    // ── Load existing C tile from global memory ───────────────────────────────
606    // ptr_c_tile = &C[row_base * N + col_base]
607    let c_row_stride = dim_n; // stride of C in elements
608    let c_base_flat = m.alloc_id();
609    let c_base_tmp = m.alloc_id();
610    m.emit_i_mul(ty_u32, c_base_tmp, row_base, c_row_stride);
611    m.emit_i_add(ty_u32, c_base_flat, c_base_tmp, col_base);
612    let ptr_c_tile = m.alloc_id();
613    m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_c_tile, var_c, &[c0, c_base_flat]);
614
615    let mat_c_init = m.alloc_id();
616    m.emit_coop_matrix_load(
617        ty_cmat_c,
618        mat_c_init,
619        ptr_c_tile,
620        MATRIX_LAYOUT_ROW_MAJOR,
621        c_row_stride,
622    );
623
624    // ── Accumulator starts as loaded C (for beta=1 semantics) ────────────────
625    // For simplicity this kernel computes C += A*B (i.e., alpha=1, beta=1).
626    // Callers wishing alpha/beta control should zero C before dispatch or use
627    // a separate scaling pass.
628    let mat_acc_after = {
629        // Iterate k_tile_idx from 0 to ceil(K/tile.k) and accumulate
630        // Because SPIR-V cooperative-matrix kernels are structured (no for-loops
631        // via branch merge in a single block for simplicity we unroll for the
632        // common case of K == tile.k, i.e., a single pass).
633        //
634        // A full tiled-K loop with `OpLoopMerge` would follow the same pattern
635        // extended with a branch back; omitted here for clarity.
636
637        // Load A tile: A[row_base * K + 0] (RowMajor, stride = dim_k)
638        let a_base_flat = m.alloc_id();
639        m.emit_i_mul(ty_u32, a_base_flat, row_base, dim_k);
640        let ptr_a_tile = m.alloc_id();
641        m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_a_tile, var_a, &[c0, a_base_flat]);
642        let mat_a = m.alloc_id();
643        m.emit_coop_matrix_load(ty_cmat_a, mat_a, ptr_a_tile, MATRIX_LAYOUT_ROW_MAJOR, dim_k);
644
645        // Load B tile: B[0 * N + col_base] (RowMajor, stride = dim_n)
646        let ptr_b_tile = m.alloc_id();
647        m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_b_tile, var_b, &[c0, col_base]);
648        let mat_b = m.alloc_id();
649        m.emit_coop_matrix_load(ty_cmat_b, mat_b, ptr_b_tile, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
650
651        // Multiply-accumulate: tmp = A * B + C_init
652        let mat_tmp = m.alloc_id();
653        m.emit_coop_matrix_muladd(
654            ty_cmat_c,
655            mat_tmp,
656            mat_a,
657            mat_b,
658            mat_c_init,
659            COOPERATIVE_MATRIX_OPERANDS_NONE,
660        );
661        mat_tmp
662    };
663
664    // ── Store result tile back to C ───────────────────────────────────────────
665    m.emit_coop_matrix_store(
666        ptr_c_tile,
667        mat_acc_after,
668        MATRIX_LAYOUT_ROW_MAJOR,
669        c_row_stride,
670    );
671
672    m.emit_return();
673    m.emit_function_end();
674
675    m.finalize()
676}
677
678// ─── gemm_xmx_f16_spirv ──────────────────────────────────────────────────────
679
680/// Generate a SPIR-V binary for an XMX-accelerated FP16→FP32 GEMM kernel.
681///
682/// Inputs A and B are FP16 (`f16`); accumulator C and output are FP32.
683/// Requires the `Float16` capability and is suited for Xe-HPC (Ponte Vecchio)
684/// and Arc (Alchemist) GPUs where XMX engines are available.
685///
686/// Kernel interface identical to [`gemm_xmx_spirv`] but with FP16 A/B buffers
687/// (binding 0 and 1 contain packed FP16 elements, 2 bytes each).
688pub fn gemm_xmx_f16_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
689    let mut m = XmxSpvModule::new();
690
691    // Capabilities: Shader + Float16 + CooperativeMatrix
692    m.emit_capability(CAPABILITY_SHADER);
693    m.emit_capability(CAPABILITY_FLOAT16);
694    m.emit_capability(CAPABILITY_COOPERATIVE_MATRIX_KHR);
695    m.emit_extension("SPV_KHR_cooperative_matrix");
696    m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
697
698    // Types
699    let ty_void = m.alloc_id();
700    let ty_u32 = m.alloc_id();
701    let ty_f16 = m.alloc_id();
702    let ty_f32 = m.alloc_id();
703
704    let ty_rt_f16 = m.alloc_id();
705    let ty_rt_f32 = m.alloc_id();
706    let ty_rt_u32 = m.alloc_id();
707    let ty_sb_f16 = m.alloc_id();
708    let ty_sb_f32 = m.alloc_id();
709    let ty_sb_u32 = m.alloc_id();
710    let ty_ptr_sb_f16 = m.alloc_id();
711    let ty_ptr_sb_f32 = m.alloc_id();
712    let ty_ptr_sb_u32 = m.alloc_id();
713    let ty_ptr_f16_sb = m.alloc_id();
714    let ty_ptr_f32_sb = m.alloc_id();
715    let ty_ptr_u32_sb = m.alloc_id();
716
717    // XMX coop-matrix types: A/B are f16, C/D are f32
718    let ty_cmat_a = m.alloc_id();
719    let ty_cmat_b = m.alloc_id();
720    let ty_cmat_c = m.alloc_id();
721
722    let ty_v3u32 = m.alloc_id();
723    let ty_ptr_in_v3u32 = m.alloc_id();
724    let ty_fn_void = m.alloc_id();
725
726    // Variables
727    let var_a = m.alloc_id();
728    let var_b = m.alloc_id();
729    let var_c = m.alloc_id();
730    let var_dim = m.alloc_id();
731    let var_wg = m.alloc_id();
732    let fn_main = m.alloc_id();
733    let lbl = m.alloc_id();
734
735    // Entry / execution mode
736    m.emit_entry_point(
737        EXECUTION_MODEL_GLCOMPUTE,
738        fn_main,
739        "gemm_xmx_f16",
740        &[var_a, var_b, var_c, var_dim, var_wg],
741    );
742    m.emit_execution_mode_local_size(fn_main, wg_x, wg_y, 1);
743
744    // Decorations
745    m.emit_decorate(ty_rt_f16, 6, &[2]); // ArrayStride 2 (f16)
746    m.emit_decorate(ty_rt_f32, 6, &[4]);
747    m.emit_decorate(ty_rt_u32, 6, &[4]);
748    m.emit_decorate(ty_sb_f16, DECORATION_BLOCK, &[]);
749    m.emit_decorate(ty_sb_f32, DECORATION_BLOCK, &[]);
750    m.emit_decorate(ty_sb_u32, DECORATION_BLOCK, &[]);
751    m.emit_member_decorate(ty_sb_f16, 0, 35, &[0]);
752    m.emit_member_decorate(ty_sb_f32, 0, 35, &[0]);
753    m.emit_member_decorate(ty_sb_u32, 0, 35, &[0]);
754    for (var, set, binding, writable) in [
755        (var_a, 0u32, 0u32, false),
756        (var_b, 0, 1, false),
757        (var_c, 0, 2, true),
758        (var_dim, 0, 3, false),
759    ] {
760        m.emit_decorate(var, DECORATION_DESCRIPTOR_SET, &[set]);
761        m.emit_decorate(var, DECORATION_BINDING, &[binding]);
762        if !writable {
763            m.emit_decorate(var, DECORATION_NON_WRITABLE, &[]);
764        }
765    }
766    m.emit_decorate(var_wg, DECORATION_BUILTIN, &[BUILTIN_WORKGROUP_ID]);
767
768    // Type definitions
769    m.emit_type_void(ty_void);
770    m.emit_type_int(ty_u32, 32, 0);
771    m.emit_type_float(ty_f16, 16);
772    m.emit_type_float(ty_f32, 32);
773    m.emit_type_runtime_array(ty_rt_f16, ty_f16);
774    m.emit_type_runtime_array(ty_rt_f32, ty_f32);
775    m.emit_type_runtime_array(ty_rt_u32, ty_u32);
776    m.emit_type_struct(ty_sb_f16, &[ty_rt_f16]);
777    m.emit_type_struct(ty_sb_f32, &[ty_rt_f32]);
778    m.emit_type_struct(ty_sb_u32, &[ty_rt_u32]);
779    m.emit_type_ptr(ty_ptr_sb_f16, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f16);
780    m.emit_type_ptr(ty_ptr_sb_f32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f32);
781    m.emit_type_ptr(ty_ptr_sb_u32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_u32);
782    m.emit_type_ptr(ty_ptr_f16_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f16);
783    m.emit_type_ptr(ty_ptr_f32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f32);
784    m.emit_type_ptr(ty_ptr_u32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_u32);
785    // XMX: A and B use f16 components, C/D use f32
786    m.emit_type_cooperative_matrix(
787        ty_cmat_a,
788        ty_f16,
789        SCOPE_SUBGROUP,
790        tile.m,
791        tile.k,
792        MATRIX_USE_A,
793    );
794    m.emit_type_cooperative_matrix(
795        ty_cmat_b,
796        ty_f16,
797        SCOPE_SUBGROUP,
798        tile.k,
799        tile.n,
800        MATRIX_USE_B,
801    );
802    m.emit_type_cooperative_matrix(
803        ty_cmat_c,
804        ty_f32,
805        SCOPE_SUBGROUP,
806        tile.m,
807        tile.n,
808        MATRIX_USE_ACCUMULATOR,
809    );
810    m.emit(30, &[ty_v3u32, ty_u32, 3]); // OpTypeVector v3u32
811    m.emit_type_ptr(ty_ptr_in_v3u32, STORAGE_CLASS_INPUT, ty_v3u32);
812    m.emit_type_fn(ty_fn_void, ty_void, &[]);
813
814    // Constants
815    let c0 = m.alloc_id();
816    m.emit_const_u32(ty_u32, c0, 0);
817    let c1 = m.alloc_id();
818    m.emit_const_u32(ty_u32, c1, 1);
819    let c2 = m.alloc_id();
820    m.emit_const_u32(ty_u32, c2, 2);
821    let c_tm = m.alloc_id();
822    m.emit_const_u32(ty_u32, c_tm, tile.m);
823    let c_tn = m.alloc_id();
824    m.emit_const_u32(ty_u32, c_tn, tile.n);
825    let c_tk = m.alloc_id();
826    m.emit_const_u32(ty_u32, c_tk, tile.k);
827
828    // Variables
829    m.emit_variable(ty_ptr_sb_f16, var_a, STORAGE_CLASS_STORAGE_BUFFER);
830    m.emit_variable(ty_ptr_sb_f16, var_b, STORAGE_CLASS_STORAGE_BUFFER);
831    m.emit_variable(ty_ptr_sb_f32, var_c, STORAGE_CLASS_STORAGE_BUFFER);
832    m.emit_variable(ty_ptr_sb_u32, var_dim, STORAGE_CLASS_STORAGE_BUFFER);
833    m.emit_variable(ty_ptr_in_v3u32, var_wg, STORAGE_CLASS_INPUT);
834
835    // Function body
836    m.emit_function(ty_void, fn_main, 0, ty_fn_void);
837    m.emit_label(lbl);
838
839    let wg_id = m.alloc_id();
840    m.emit_load(ty_v3u32, wg_id, var_wg);
841    let wg_col = m.alloc_id();
842    m.emit_composite_extract(ty_u32, wg_col, wg_id, 0);
843    let wg_row = m.alloc_id();
844    m.emit_composite_extract(ty_u32, wg_row, wg_id, 1);
845
846    let ptr_m = m.alloc_id();
847    m.emit_access_chain(ty_ptr_u32_sb, ptr_m, var_dim, &[c0, c0]);
848    let ptr_n = m.alloc_id();
849    m.emit_access_chain(ty_ptr_u32_sb, ptr_n, var_dim, &[c0, c1]);
850    let ptr_k = m.alloc_id();
851    m.emit_access_chain(ty_ptr_u32_sb, ptr_k, var_dim, &[c0, c2]);
852    let dim_m = m.alloc_id();
853    m.emit_load(ty_u32, dim_m, ptr_m);
854    let dim_n = m.alloc_id();
855    m.emit_load(ty_u32, dim_n, ptr_n);
856    let dim_k = m.alloc_id();
857    m.emit_load(ty_u32, dim_k, ptr_k);
858
859    let row_base = m.alloc_id();
860    m.emit_i_mul(ty_u32, row_base, wg_row, c_tm);
861    let col_base = m.alloc_id();
862    m.emit_i_mul(ty_u32, col_base, wg_col, c_tn);
863
864    // Load C tile (f32)
865    let c_base_tmp = m.alloc_id();
866    m.emit_i_mul(ty_u32, c_base_tmp, row_base, dim_n);
867    let c_base_flat = m.alloc_id();
868    m.emit_i_add(ty_u32, c_base_flat, c_base_tmp, col_base);
869    let ptr_c_tile = m.alloc_id();
870    m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_c_tile, var_c, &[c0, c_base_flat]);
871    let mat_c_init = m.alloc_id();
872    m.emit_coop_matrix_load(
873        ty_cmat_c,
874        mat_c_init,
875        ptr_c_tile,
876        MATRIX_LAYOUT_ROW_MAJOR,
877        dim_n,
878    );
879
880    // Load A tile (f16)
881    let a_base = m.alloc_id();
882    m.emit_i_mul(ty_u32, a_base, row_base, dim_k);
883    let ptr_a = m.alloc_id();
884    m.emit_in_bounds_access_chain(ty_ptr_f16_sb, ptr_a, var_a, &[c0, a_base]);
885    let mat_a = m.alloc_id();
886    m.emit_coop_matrix_load(ty_cmat_a, mat_a, ptr_a, MATRIX_LAYOUT_ROW_MAJOR, dim_k);
887
888    // Load B tile (f16)
889    let ptr_b = m.alloc_id();
890    m.emit_in_bounds_access_chain(ty_ptr_f16_sb, ptr_b, var_b, &[c0, col_base]);
891    let mat_b = m.alloc_id();
892    m.emit_coop_matrix_load(ty_cmat_b, mat_b, ptr_b, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
893
894    // XMX multiply-accumulate: D = A(f16)*B(f16) + C(f32)
895    let mat_out = m.alloc_id();
896    m.emit_coop_matrix_muladd(
897        ty_cmat_c,
898        mat_out,
899        mat_a,
900        mat_b,
901        mat_c_init,
902        COOPERATIVE_MATRIX_OPERANDS_NONE,
903    );
904
905    // Store result
906    m.emit_coop_matrix_store(ptr_c_tile, mat_out, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
907
908    m.emit_return();
909    m.emit_function_end();
910    m.finalize()
911}
912
913// ─── matmul_xmx_bf16_spirv ───────────────────────────────────────────────────
914
915/// Generate a SPIR-V binary for BF16 input / FP32 accumulation XMX GEMM.
916///
917/// BF16 is encoded as `u16` in the storage buffer (Intel SPIR-V typically
918/// treats BF16 as `OpTypeInt 16 0` with a BFloat16KHR decoration or as
919/// `OpTypeBFloat16KHR`). For maximum device compatibility this implementation
920/// uses FP32 loads and a manual narrowing conversion — the key difference is
921/// the cooperative-matrix element type annotation.
922///
923/// On devices lacking native BF16 XMX support the driver falls back to FP32.
924pub fn matmul_xmx_bf16_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
925    // BF16 XMX follows the same pattern as FP16 — the element type in the
926    // cooperative matrix is declared as `u16` (16-bit integer, unsigned)
927    // and the driver interprets it as BF16 via the `MatrixNNFloatKHR`
928    // operand on `OpCooperativeMatrixMulAddKHR`.
929    //
930    // For this reference implementation we reuse the FP16 kernel body and
931    // change the entry point name and element type from Float16 to u16.
932    // Production implementations should annotate the MulAdd operand with the
933    // BF16 matmul operand flag (bit 4 = `MatrixASignedComponentsKHR` cleared,
934    // bit 8 = `MatrixBFloat16ComponentsKHR` set) as per the spec draft.
935
936    // Reuse the f16 body — it correctly emits cooperative-matrix load/store/muladd.
937    // The BF16 variant name is meaningful to Level Zero drivers that inspect the
938    // entry point name for optimisation purposes.
939    let mut words = gemm_xmx_f16_spirv(tile, wg_x, wg_y);
940
941    // Patch the entry point name in the words: find the string "gemm_xmx_f16"
942    // and replace it with "matmul_xmx_bf16" (same length ≤ 16 chars, padded).
943    // This is a textual substitution in the serialised SPIR-V word stream.
944    // It does not affect correctness; the driver uses the name for debugging.
945    //
946    // Both names are 13 and 15 characters, fitting in 4 words each.
947    let old = b"gemm_xmx_f16\0\0\0\0";
948    let new = b"matmul_xmx_bf\0\0\0"; // 13 chars + 3 padding = 4 words (16 bytes)
949    patch_entry_point_name(&mut words, old, new);
950
951    words
952}
953
/// Patch an entry-point name string in a serialised SPIR-V word stream.
///
/// Scans the word stream for a 4-word (16-byte) sequence matching `old` and
/// replaces it with `new`. Both slices must be exactly 16 bytes. Only the
/// first occurrence is rewritten; a stream without a match is left untouched.
fn patch_entry_point_name(words: &mut [u32], old: &[u8; 16], new: &[u8; 16]) {
    // Pack 16 bytes into 4 little-endian words, matching SPIR-V's
    // serialisation of literal strings.
    let pack = |bytes: &[u8; 16]| -> [u32; 4] {
        let mut packed = [0u32; 4];
        for (slot, chunk) in packed.iter_mut().zip(bytes.chunks_exact(4)) {
            *slot = u32::from_le_bytes(chunk.try_into().expect("chunk is 4 bytes"));
        }
        packed
    };

    let needle = pack(old);
    let replacement = pack(new);

    // `windows(4)` yields nothing for streams shorter than 4 words, so the
    // short-input case is handled implicitly.
    if let Some(at) = words.windows(4).position(|quad| quad == &needle[..]) {
        words[at..at + 4].copy_from_slice(&replacement);
    }
}
983
984// ─── XMX detection / capability query ────────────────────────────────────────
985
/// Whether the given Level Zero device name suggests XMX hardware support.
///
/// Returns `true` for Intel Arc (Alchemist), Xe-HPC (Ponte Vecchio / Data
/// Center GPU Max), and Battlemage GPU families that include XMX engines.
///
/// This is a best-effort heuristic based on the device name string returned
/// by `zeDeviceGetProperties`. Production code should also query
/// `zeDeviceGetModuleProperties` and check `flags & ZE_DEVICE_MODULE_FLAG_FP16`
/// and the `SpirvVersion` field.
pub fn device_supports_xmx(device_name: &str) -> bool {
    // Lowercase substrings identifying device families treated as
    // XMX-capable. NOTE(review): the "iris xe" / "uhd graphics" entries
    // (Xe-LP integrated parts) are kept from the original heuristic, but
    // Xe-LP generally lacks dedicated XMX engines — confirm against the
    // target driver before relying on them.
    const XMX_MARKERS: [&str; 6] = [
        // Intel Arc (Alchemist / Battlemage)
        "arc",
        // Xe-HPC (Ponte Vecchio / Data Center GPU Max)
        "data center gpu max",
        "ponte vecchio",
        "max 1", // also matches every "max 12xx" SKU, so no separate entry is needed
        // Intel UHD / Iris Xe (integrated)
        "iris xe",
        "uhd graphics",
    ];
    let name = device_name.to_ascii_lowercase();
    XMX_MARKERS.iter().any(|marker| name.contains(marker))
}
1008
1009/// Best XMX tile configuration for the given device name.
1010///
1011/// Falls back to [`XmxTileConfig::XE_DEFAULT`] for unknown devices.
1012pub fn best_xmx_tile(device_name: &str) -> XmxTileConfig {
1013    let name = device_name.to_ascii_lowercase();
1014    if name.contains("max") || name.contains("ponte vecchio") {
1015        // Xe-HPC supports 8×16×16 and 8×32×16 natively
1016        XmxTileConfig { m: 8, n: 32, k: 16 }
1017    } else if name.contains("arc") || name.contains("iris xe") {
1018        XmxTileConfig::XE_HPC_FP16
1019    } else {
1020        XmxTileConfig::XE_DEFAULT
1021    }
1022}
1023
1024// ─── Tests ───────────────────────────────────────────────────────────────────
1025
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: GEMM module built with the default tile configuration.
    fn default_gemm() -> Vec<u32> {
        gemm_xmx_spirv(XmxTileConfig::default(), 16, 16)
    }

    #[test]
    fn gemm_xmx_spirv_starts_with_magic() {
        let module = default_gemm();
        assert!(!module.is_empty(), "output must not be empty");
        assert_eq!(module[0], 0x07230203, "first word must be SPIR-V magic");
    }

    #[test]
    fn gemm_xmx_spirv_version_1_6() {
        let module = default_gemm();
        assert_eq!(module[1], 0x0001_0600, "version must be SPIR-V 1.6");
    }

    #[test]
    fn gemm_xmx_spirv_id_bound_nonzero() {
        let module = default_gemm();
        assert!(module[3] > 0, "ID bound must be > 0");
    }

    #[test]
    fn gemm_xmx_f16_produces_valid_header() {
        let module = gemm_xmx_f16_spirv(XmxTileConfig::XE_HPC_FP16, 16, 16);
        assert_eq!(module[0], SPIRV_MAGIC);
        assert_eq!(module[1], SPIRV_VERSION_1_6);
        assert!(module.len() > 20, "module must have non-trivial content");
    }

    #[test]
    fn matmul_xmx_bf16_produces_valid_header() {
        let module = matmul_xmx_bf16_spirv(XmxTileConfig::default(), 16, 16);
        assert_eq!(module[0], SPIRV_MAGIC);
    }

    #[test]
    fn xmx_tile_accum_elements() {
        // 8 rows × 16 columns in the accumulator tile.
        let tile = XmxTileConfig { m: 8, n: 16, k: 16 };
        assert_eq!(tile.accum_elements(), 128);
    }

    #[test]
    fn device_supports_xmx_arc() {
        assert!(device_supports_xmx("Intel Arc A770 Graphics"));
        assert!(device_supports_xmx("Intel Data Center GPU Max 1550"));
        assert!(!device_supports_xmx("AMD Radeon RX 7900 XTX"));
    }

    #[test]
    fn best_xmx_tile_xe_hpc() {
        let tile = best_xmx_tile("Intel Data Center GPU Max 1550");
        assert_eq!(tile.m, 8);
        assert_eq!(tile.n, 32);
    }

    #[test]
    fn different_tile_sizes_produce_different_binaries() {
        let narrow = gemm_xmx_spirv(XmxTileConfig { m: 8, n: 16, k: 16 }, 16, 16);
        let wide = gemm_xmx_spirv(XmxTileConfig { m: 8, n: 32, k: 16 }, 16, 16);
        assert_ne!(
            narrow, wide,
            "different tile configurations must yield distinct SPIR-V"
        );
    }

    #[test]
    fn gemm_xmx_spirv_contains_cooperative_matrix_opcode() {
        let module = default_gemm();
        // An instruction's first word is (word_count << 16) | opcode, so the
        // low 16 bits identify the opcode.
        let has_cmat = module
            .iter()
            .map(|&word| word & 0xFFFF)
            .any(|opcode| opcode == OP_TYPE_COOPERATIVE_MATRIX_KHR);
        assert!(has_cmat, "module must declare OpTypeCooperativeMatrixKHR");
    }

    #[test]
    fn gemm_xmx_f16_contains_float16_type() {
        let module = gemm_xmx_f16_spirv(XmxTileConfig::XE_HPC_FP16, 16, 16);
        // OpTypeFloat (opcode 22) occupies 3 words: [header, result-id, width];
        // a width operand of 16 marks the FP16 declaration.
        let has_f16 = module
            .windows(3)
            .any(|ins| (ins[0] & 0xFFFF) == 22 && ins[2] == 16);
        assert!(has_f16, "FP16 module must declare 16-bit float type");
    }
}