//! SPIR-V compute kernel generators for the Level Zero backend.
//!
//! This module provides:
//! - A lightweight [`SpvModule`] builder for emitting valid SPIR-V binaries.
//! - Generator functions for **unary**, **binary**, **reduce**, and **GEMM**
//!   compute kernels consumed by Level Zero's `zeModuleCreate`.
//! - The original [`trivial_compute_shader`] placeholder used for validation.
//!
//! All generated kernels use the OpenCL SPIR-V execution model (`Kernel`)
//! with `Physical64`/`OpenCL` memory model.  Buffer parameters are
//! `CrossWorkgroup` pointers; scalar parameters are passed by value via
//! `zeKernelSetArgumentValue`.
//!
//! The generated SPIR-V targets version 1.2 (widely supported by Level Zero).

use oxicuda_backend::{BinaryOp, ReduceOp, UnaryOp};

// ─── Constants ──────────────────────────────────────────────

/// SPIR-V magic number (first word of every module).
pub const SPIRV_MAGIC: u32 = 0x07230203;
/// SPIR-V version 1.2, encoded as `0x00MMmm00` (major in byte 2, minor in byte 1).
pub const SPIRV_VERSION_1_2: u32 = 0x0001_0200;
/// Generator magic — OxiCUDA Level Zero backend.
pub const SPIRV_GENERATOR: u32 = 0x000D_0002;

// ─── SPIR-V opcodes ─────────────────────────────────────────
// Numeric opcode values as defined by the SPIR-V specification.  Only the
// opcodes this module actually emits (plus a few reserved for future use)
// are listed.

// Module-level / declaration opcodes.
pub(crate) const OP_EXT_INST_IMPORT: u32 = 11;
pub(crate) const OP_EXT_INST: u32 = 12;
pub(crate) const OP_MEMORY_MODEL: u32 = 14;
pub(crate) const OP_ENTRY_POINT: u32 = 15;
pub(crate) const OP_EXECUTION_MODE: u32 = 16;
pub(crate) const OP_CAPABILITY: u32 = 17;
// Type declarations.
pub(crate) const OP_TYPE_VOID: u32 = 19;
pub(crate) const OP_TYPE_BOOL: u32 = 20;
pub(crate) const OP_TYPE_INT: u32 = 21;
pub(crate) const OP_TYPE_FLOAT: u32 = 22;
pub(crate) const OP_TYPE_VECTOR: u32 = 23;
pub(crate) const OP_TYPE_ARRAY: u32 = 28;
pub(crate) const OP_TYPE_POINTER: u32 = 32;
pub(crate) const OP_TYPE_FUNCTION: u32 = 33;
pub(crate) const OP_CONSTANT: u32 = 43;
// Functions, variables, memory access.
pub(crate) const OP_FUNCTION: u32 = 54;
pub(crate) const OP_FUNCTION_PARAMETER: u32 = 55;
pub(crate) const OP_FUNCTION_END: u32 = 56;
pub(crate) const OP_VARIABLE: u32 = 59;
pub(crate) const OP_LOAD: u32 = 61;
pub(crate) const OP_STORE: u32 = 62;
pub(crate) const OP_IN_BOUNDS_PTR_ACCESS_CHAIN: u32 = 70;
pub(crate) const OP_DECORATE: u32 = 71;
pub(crate) const OP_COMPOSITE_EXTRACT: u32 = 81;
// Conversions and arithmetic.
pub(crate) const OP_CONVERT_U_TO_F: u32 = 112;
pub(crate) const OP_F_NEGATE: u32 = 127;
pub(crate) const OP_I_ADD: u32 = 128;
pub(crate) const OP_F_ADD: u32 = 129;
pub(crate) const OP_F_SUB: u32 = 131;
pub(crate) const OP_I_MUL: u32 = 132;
pub(crate) const OP_F_MUL: u32 = 133;
pub(crate) const OP_U_DIV: u32 = 134;
pub(crate) const OP_F_DIV: u32 = 136;
pub(crate) const OP_U_MOD: u32 = 137;
// Comparison and control flow.
pub(crate) const OP_U_LESS_THAN: u32 = 176;
pub(crate) const OP_LOOP_MERGE: u32 = 246;
pub(crate) const OP_SELECTION_MERGE: u32 = 247;
pub(crate) const OP_LABEL: u32 = 248;
pub(crate) const OP_BRANCH: u32 = 249;
pub(crate) const OP_BRANCH_CONDITIONAL: u32 = 250;
pub(crate) const OP_CONTROL_BARRIER: u32 = 224;
pub(crate) const OP_PHI: u32 = 245;
pub(crate) const OP_RETURN: u32 = 253;

// GroupNonUniform opcodes (SPIR-V 1.3+)
// NOTE(review): these require SPIR-V 1.3 and the GroupNonUniform*
// capabilities, while `SPIRV_VERSION_1_2` is what `SpvModule::new` writes
// into the header — confirm any module using them bumps the version word.
pub(crate) const OP_GROUP_NON_UNIFORM_FADD: u32 = 350;
pub(crate) const OP_GROUP_NON_UNIFORM_SHUFFLE: u32 = 345;

// Capabilities
const CAPABILITY_SHADER: u32 = 1;
const CAPABILITY_ADDRESSES: u32 = 4;
const CAPABILITY_KERNEL: u32 = 6;

// Addressing / memory model
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
const ADDRESSING_MODEL_PHYSICAL64: u32 = 2;
const MEMORY_MODEL_GLSL450: u32 = 1;
const MEMORY_MODEL_OPENCL: u32 = 2;

// Execution model / mode
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
pub(crate) const EXECUTION_MODEL_KERNEL: u32 = 6;
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;

// Function control
pub(crate) const FUNCTION_CONTROL_NONE: u32 = 0;

// Decorations
const DECORATION_BUILTIN: u32 = 11;

// BuiltIn values
const BUILTIN_GLOBAL_INVOCATION_ID: u32 = 28;

// Storage classes
const STORAGE_CLASS_INPUT: u32 = 1;
const STORAGE_CLASS_CROSS_WORKGROUP: u32 = 5;
pub(crate) const STORAGE_CLASS_FUNCTION: u32 = 7;

// Selection/loop control
const SELECTION_CONTROL_NONE: u32 = 0;
const LOOP_CONTROL_NONE: u32 = 0;

// OpenCL.std extended instruction numbers
const OPENCL_EXP: u32 = 19;
const OPENCL_FABS: u32 = 23;
pub(crate) const OPENCL_FMAX: u32 = 27;
const OPENCL_FMIN: u32 = 28;
const OPENCL_LOG: u32 = 37;
const OPENCL_SQRT: u32 = 61;
const OPENCL_TANH: u32 = 63;

/// Workgroup size for 1-D compute kernels.
pub(crate) const WORKGROUP_SIZE: u32 = 256;

// ─── Minimal SPIR-V builder ──────────────────────────────────

/// Lightweight SPIR-V word-stream builder.
///
/// Emits valid SPIR-V instructions for simple compute shaders without
/// pulling in a full compiler.
pub struct SpvModule {
    /// Raw word stream; starts with the 5-word module header, whose ID-bound
    /// slot (index 3) is patched in `finalize`.
    words: Vec<u32>,
    /// Next available result ID.
    id_bound: u32,
}

135impl SpvModule {
136    /// Create a new module with a placeholder header (bound filled at finalise).
137    pub fn new() -> Self {
138        let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_2, SPIRV_GENERATOR, 0, 0];
139        Self { words, id_bound: 1 }
140    }
141
142    /// Allocate a fresh result ID.
143    pub fn alloc_id(&mut self) -> u32 {
144        let id = self.id_bound;
145        self.id_bound += 1;
146        id
147    }
148
149    /// Emit a SPIR-V instruction.
150    pub fn emit(&mut self, opcode: u32, operands: &[u32]) {
151        let word_count = (1 + operands.len()) as u32;
152        self.words.push((word_count << 16) | opcode);
153        self.words.extend_from_slice(operands);
154    }
155
156    /// Emit a string as null-terminated UTF-8 packed into 32-bit words.
157    pub fn string_words(s: &str) -> Vec<u32> {
158        let bytes = s.as_bytes();
159        let padded_len = (bytes.len() + 4) & !3;
160        let mut out = vec![0u32; padded_len / 4];
161        for (i, &b) in bytes.iter().enumerate() {
162            out[i / 4] |= (b as u32) << ((i % 4) * 8);
163        }
164        out
165    }
166
167    /// Finalise the module: patch the ID bound and return the word vector.
168    pub fn finalize(mut self) -> Vec<u32> {
169        self.words[3] = self.id_bound;
170        self.words
171    }
172
173    // ── Convenience emitters ─────────────────────────────────
174
175    pub(crate) fn emit_capability(&mut self, cap: u32) {
176        self.emit(OP_CAPABILITY, &[cap]);
177    }
178
179    pub(crate) fn emit_ext_inst_import(&mut self, id: u32, name: &str) {
180        let mut ops = vec![id];
181        ops.extend(Self::string_words(name));
182        self.emit(OP_EXT_INST_IMPORT, &ops);
183    }
184
185    pub(crate) fn emit_memory_model(&mut self, addressing: u32, memory: u32) {
186        self.emit(OP_MEMORY_MODEL, &[addressing, memory]);
187    }
188
189    pub(crate) fn emit_entry_point(
190        &mut self,
191        model: u32,
192        func_id: u32,
193        name: &str,
194        interfaces: &[u32],
195    ) {
196        let mut ops = vec![model, func_id];
197        ops.extend(Self::string_words(name));
198        ops.extend_from_slice(interfaces);
199        self.emit(OP_ENTRY_POINT, &ops);
200    }
201
202    pub(crate) fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
203        self.emit(
204            OP_EXECUTION_MODE,
205            &[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
206        );
207    }
208
209    pub(crate) fn emit_decorate(&mut self, target: u32, decoration: u32, operands: &[u32]) {
210        let mut ops = vec![target, decoration];
211        ops.extend_from_slice(operands);
212        self.emit(OP_DECORATE, &ops);
213    }
214
215    pub(crate) fn emit_type_void(&mut self, id: u32) {
216        self.emit(OP_TYPE_VOID, &[id]);
217    }
218
219    pub(crate) fn emit_type_bool(&mut self, id: u32) {
220        self.emit(OP_TYPE_BOOL, &[id]);
221    }
222
223    pub(crate) fn emit_type_int(&mut self, id: u32, width: u32, signedness: u32) {
224        self.emit(OP_TYPE_INT, &[id, width, signedness]);
225    }
226
227    pub(crate) fn emit_type_float(&mut self, id: u32, width: u32) {
228        self.emit(OP_TYPE_FLOAT, &[id, width]);
229    }
230
231    pub(crate) fn emit_type_vector(&mut self, id: u32, component: u32, count: u32) {
232        self.emit(OP_TYPE_VECTOR, &[id, component, count]);
233    }
234
235    pub(crate) fn emit_type_pointer(&mut self, id: u32, storage_class: u32, pointee: u32) {
236        self.emit(OP_TYPE_POINTER, &[id, storage_class, pointee]);
237    }
238
239    pub(crate) fn emit_type_function(&mut self, id: u32, return_type: u32, params: &[u32]) {
240        let mut ops = vec![id, return_type];
241        ops.extend_from_slice(params);
242        self.emit(OP_TYPE_FUNCTION, &ops);
243    }
244
245    pub(crate) fn emit_constant_u32(&mut self, ty: u32, id: u32, value: u32) {
246        self.emit(OP_CONSTANT, &[ty, id, value]);
247    }
248
249    pub(crate) fn emit_constant_f32(&mut self, ty: u32, id: u32, value: f32) {
250        self.emit(OP_CONSTANT, &[ty, id, value.to_bits()]);
251    }
252
253    pub(crate) fn emit_variable(&mut self, ty: u32, id: u32, storage_class: u32) {
254        self.emit(OP_VARIABLE, &[ty, id, storage_class]);
255    }
256
257    pub(crate) fn emit_load(&mut self, result_ty: u32, result: u32, pointer: u32) {
258        self.emit(OP_LOAD, &[result_ty, result, pointer]);
259    }
260
261    pub(crate) fn emit_store(&mut self, pointer: u32, value: u32) {
262        self.emit(OP_STORE, &[pointer, value]);
263    }
264
265    pub(crate) fn emit_in_bounds_ptr_access_chain(
266        &mut self,
267        result_ty: u32,
268        result: u32,
269        base: u32,
270        element: u32,
271    ) {
272        self.emit(
273            OP_IN_BOUNDS_PTR_ACCESS_CHAIN,
274            &[result_ty, result, base, element],
275        );
276    }
277
278    pub(crate) fn emit_function(&mut self, result_ty: u32, result: u32, control: u32, fn_ty: u32) {
279        self.emit(OP_FUNCTION, &[result_ty, result, control, fn_ty]);
280    }
281
282    pub(crate) fn emit_function_parameter(&mut self, result_ty: u32, result: u32) {
283        self.emit(OP_FUNCTION_PARAMETER, &[result_ty, result]);
284    }
285
286    pub(crate) fn emit_label(&mut self, id: u32) {
287        self.emit(OP_LABEL, &[id]);
288    }
289
290    pub(crate) fn emit_return(&mut self) {
291        self.emit(OP_RETURN, &[]);
292    }
293
294    pub(crate) fn emit_function_end(&mut self) {
295        self.emit(OP_FUNCTION_END, &[]);
296    }
297
298    pub(crate) fn emit_branch(&mut self, target: u32) {
299        self.emit(OP_BRANCH, &[target]);
300    }
301
302    pub(crate) fn emit_branch_conditional(&mut self, cond: u32, true_label: u32, false_label: u32) {
303        self.emit(OP_BRANCH_CONDITIONAL, &[cond, true_label, false_label]);
304    }
305
306    pub(crate) fn emit_selection_merge(&mut self, merge_label: u32) {
307        self.emit(OP_SELECTION_MERGE, &[merge_label, SELECTION_CONTROL_NONE]);
308    }
309
310    pub(crate) fn emit_loop_merge(&mut self, merge_label: u32, continue_label: u32) {
311        self.emit(
312            OP_LOOP_MERGE,
313            &[merge_label, continue_label, LOOP_CONTROL_NONE],
314        );
315    }
316
317    pub(crate) fn emit_opencl_ext(
318        &mut self,
319        ext_id: u32,
320        result_ty: u32,
321        result: u32,
322        inst: u32,
323        args: &[u32],
324    ) {
325        let mut ops = vec![result_ty, result, ext_id, inst];
326        ops.extend_from_slice(args);
327        self.emit(OP_EXT_INST, &ops);
328    }
329}
330
331impl Default for SpvModule {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
// ─── Common preamble for OpenCL SPIR-V kernels ──────────────

/// IDs shared by all compute kernels.
///
/// Produced by [`emit_preamble`]; each field holds the result ID of the
/// corresponding SPIR-V type, constant, or variable.
pub(crate) struct BaseIds {
    // `OpTypeVoid`.
    pub(crate) ty_void: u32,
    // `OpTypeBool`.
    pub(crate) ty_bool: u32,
    // 32-bit unsigned `OpTypeInt`.
    pub(crate) ty_uint: u32,
    // 32-bit `OpTypeFloat`.
    pub(crate) ty_float: u32,
    // `uint3` vector type (for GlobalInvocationId).
    #[allow(dead_code)]
    pub(crate) ty_v3uint: u32,
    // `void()` function type (unused by current kernels, kept for callers).
    #[allow(dead_code)]
    pub(crate) ty_fn_void: u32,
    // Input-storage pointer to `uint3` (GlobalInvocationId's pointer type).
    #[allow(dead_code)]
    pub(crate) ty_ptr_input_v3uint: u32,
    // CrossWorkgroup (global buffer) pointer to float.
    pub(crate) ty_ptr_cross_float: u32,
    // Function-local pointer to float (accumulators).
    pub(crate) ty_ptr_func_float: u32,
    // Function-local pointer to uint (loop counters).
    pub(crate) ty_ptr_func_uint: u32,
    // Constants 0u, 1u, 0.0f, 1.0f.
    pub(crate) c_uint_0: u32,
    pub(crate) c_uint_1: u32,
    pub(crate) c_float_0: u32,
    pub(crate) c_float_1: u32,
    // The `GlobalInvocationId` Input variable.
    pub(crate) var_gid: u32,
    // Result ID of the imported "OpenCL.std" instruction set.
    pub(crate) opencl_ext: u32,
}

/// Emit the preamble shared by all OpenCL-style compute kernels.
///
/// This emits capabilities, memory model, types, constants, and the
/// `GlobalInvocationId` Input variable.  The caller must separately emit
/// `OpEntryPoint`, `OpExecutionMode`, and the function body.
///
/// NOTE(review): because the caller emits `OpEntryPoint` *after* this
/// function has already emitted decorations and types, the word stream does
/// not follow the SPIR-V logical-layout order (entry points are specified to
/// precede annotations and type declarations) — confirm the Level Zero IL
/// compiler tolerates this, e.g. by running `spirv-val` on a generated kernel.
pub(crate) fn emit_preamble(m: &mut SpvModule) -> BaseIds {
    // Allocate every shared ID up front; the allocation order here fixes the
    // numeric IDs that appear in the emitted words below.
    let ty_void = m.alloc_id();
    let ty_bool = m.alloc_id();
    let ty_uint = m.alloc_id();
    let ty_float = m.alloc_id();
    let ty_v3uint = m.alloc_id();
    let ty_fn_void = m.alloc_id();
    let ty_ptr_input_v3uint = m.alloc_id();
    let ty_ptr_cross_float = m.alloc_id();
    let ty_ptr_func_float = m.alloc_id();
    let ty_ptr_func_uint = m.alloc_id();
    let c_uint_0 = m.alloc_id();
    let c_uint_1 = m.alloc_id();
    let c_float_0 = m.alloc_id();
    let c_float_1 = m.alloc_id();
    let var_gid = m.alloc_id();
    let opencl_ext = m.alloc_id();

    // Capabilities: Kernel (OpenCL execution model) + Addresses (Physical64).
    m.emit_capability(CAPABILITY_KERNEL);
    m.emit_capability(CAPABILITY_ADDRESSES);

    // Extension import
    m.emit_ext_inst_import(opencl_ext, "OpenCL.std");

    // Memory model: Physical64 + OpenCL
    m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);

    // NOTE: OpEntryPoint and OpExecutionMode are emitted by the caller after
    // allocating the main function ID, so we skip them here.

    // Decoration: GlobalInvocationId on var_gid
    m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);

    // Types
    m.emit_type_void(ty_void);
    m.emit_type_bool(ty_bool);
    m.emit_type_int(ty_uint, 32, 0);
    m.emit_type_float(ty_float, 32);
    m.emit_type_vector(ty_v3uint, ty_uint, 3);
    m.emit_type_function(ty_fn_void, ty_void, &[]);
    m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
    m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
    m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
    m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);

    // Constants
    m.emit_constant_u32(ty_uint, c_uint_0, 0);
    m.emit_constant_u32(ty_uint, c_uint_1, 1);
    m.emit_constant_f32(ty_float, c_float_0, 0.0);
    m.emit_constant_f32(ty_float, c_float_1, 1.0);

    // GlobalInvocationId input variable
    m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);

    BaseIds {
        ty_void,
        ty_bool,
        ty_uint,
        ty_float,
        ty_v3uint,
        ty_fn_void,
        ty_ptr_input_v3uint,
        ty_ptr_cross_float,
        ty_ptr_func_float,
        ty_ptr_func_uint,
        c_uint_0,
        c_uint_1,
        c_float_0,
        c_float_1,
        var_gid,
        opencl_ext,
    }
}

442/// Load `GlobalInvocationId.x` into a uint result.
443pub(crate) fn load_gid_x(m: &mut SpvModule, b: &BaseIds) -> u32 {
444    let gid_val = m.alloc_id();
445    m.emit_load(b.ty_v3uint, gid_val, b.var_gid);
446    let gid_x = m.alloc_id();
447    m.emit(OP_COMPOSITE_EXTRACT, &[b.ty_uint, gid_x, gid_val, 0]);
448    gid_x
449}
450
// ─── Unary compute kernel ───────────────────────────────────

/// Generate an OpenCL SPIR-V compute kernel for an element-wise unary operation.
///
/// Kernel parameters: `(CrossWorkgroup float* input, CrossWorkgroup float* output, uint count)`.
///
/// Shape: one thread per element, guarded by a `gid < count` bounds check so
/// the dispatch may be rounded up to a multiple of [`WORKGROUP_SIZE`].
pub fn unary_compute_shader(op: UnaryOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_input = m.alloc_id();
    let p_output = m.alloc_id();
    let p_count = m.alloc_id();

    // Function type: void(CrossWorkgroup float*, CrossWorkgroup float*, uint)
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[b.ty_ptr_cross_float, b.ty_ptr_cross_float, b.ty_uint],
    );

    // Entry point and execution mode
    // NOTE(review): this OpEntryPoint follows the types/decorations already
    // emitted by `emit_preamble`, deviating from the spec's logical module
    // layout — see the note on `emit_preamble`.
    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    // Labels: entry block, guarded body, and the selection-merge block.
    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    // Function
    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
    m.emit_function_parameter(b.ty_uint, p_count);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // Bounds check: gid < count
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // input_ptr = &input[gid]
    let inp_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, gid);
    let inp_val = m.alloc_id();
    m.emit_load(b.ty_float, inp_val, inp_ptr);

    let result = emit_unary_op(&mut m, &b, op, inp_val);

    // output_ptr = &output[gid]
    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
    m.emit_store(out_ptr, result);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}

521/// Emit the SPIR-V instructions for a unary operation, returning the result ID.
522fn emit_unary_op(m: &mut SpvModule, b: &BaseIds, op: UnaryOp, x: u32) -> u32 {
523    let result = m.alloc_id();
524    match op {
525        UnaryOp::Relu => {
526            m.emit_opencl_ext(
527                b.opencl_ext,
528                b.ty_float,
529                result,
530                OPENCL_FMAX,
531                &[b.c_float_0, x],
532            );
533        }
534        UnaryOp::Sigmoid => {
535            let neg_x = m.alloc_id();
536            m.emit(OP_F_NEGATE, &[b.ty_float, neg_x, x]);
537            let exp_neg_x = m.alloc_id();
538            m.emit_opencl_ext(b.opencl_ext, b.ty_float, exp_neg_x, OPENCL_EXP, &[neg_x]);
539            let one_plus = m.alloc_id();
540            m.emit(OP_F_ADD, &[b.ty_float, one_plus, b.c_float_1, exp_neg_x]);
541            m.emit(OP_F_DIV, &[b.ty_float, result, b.c_float_1, one_plus]);
542        }
543        UnaryOp::Tanh => {
544            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_TANH, &[x]);
545        }
546        UnaryOp::Exp => {
547            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_EXP, &[x]);
548        }
549        UnaryOp::Log => {
550            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_LOG, &[x]);
551        }
552        UnaryOp::Sqrt => {
553            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_SQRT, &[x]);
554        }
555        UnaryOp::Abs => {
556            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FABS, &[x]);
557        }
558        UnaryOp::Neg => {
559            m.emit(OP_F_NEGATE, &[b.ty_float, result, x]);
560        }
561    }
562    result
563}
564
// ─── Binary compute kernel ──────────────────────────────────

/// Generate an OpenCL SPIR-V compute kernel for an element-wise binary operation.
///
/// Kernel parameters: `(CrossWorkgroup float* a, CrossWorkgroup float* b,
///                      CrossWorkgroup float* output, uint count)`.
///
/// Shape: one thread per element, guarded by a `gid < count` bounds check.
pub fn binary_compute_shader(op: BinaryOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_a = m.alloc_id();
    let p_b = m.alloc_id();
    let p_out = m.alloc_id();
    let p_count = m.alloc_id();

    // Function type: void(float*, float*, float*, uint)
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_uint,
        ],
    );

    // NOTE(review): OpEntryPoint lands after the types/decorations emitted by
    // `emit_preamble`, deviating from the spec's logical module layout — see
    // the note on `emit_preamble`.
    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    // Labels: entry block, guarded body, and the selection-merge block.
    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_out);
    m.emit_function_parameter(b.ty_uint, p_count);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // Bounds check: gid < count
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // a_val = a[gid]
    let a_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, gid);
    let a_val = m.alloc_id();
    m.emit_load(b.ty_float, a_val, a_ptr);

    // b_val = b[gid]
    let b_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, gid);
    let b_val = m.alloc_id();
    m.emit_load(b.ty_float, b_val, b_ptr);

    let result = emit_binary_op(&mut m, &b, op, a_val, b_val);

    // output[gid] = result
    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_out, gid);
    m.emit_store(out_ptr, result);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}

642fn emit_binary_op(m: &mut SpvModule, b: &BaseIds, op: BinaryOp, lhs: u32, rhs: u32) -> u32 {
643    let result = m.alloc_id();
644    match op {
645        BinaryOp::Add => m.emit(OP_F_ADD, &[b.ty_float, result, lhs, rhs]),
646        BinaryOp::Sub => m.emit(OP_F_SUB, &[b.ty_float, result, lhs, rhs]),
647        BinaryOp::Mul => m.emit(OP_F_MUL, &[b.ty_float, result, lhs, rhs]),
648        BinaryOp::Div => m.emit(OP_F_DIV, &[b.ty_float, result, lhs, rhs]),
649        BinaryOp::Max => {
650            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMAX, &[lhs, rhs]);
651        }
652        BinaryOp::Min => {
653            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMIN, &[lhs, rhs]);
654        }
655    }
656    result
657}
658
659// ─── Reduce compute kernel ──────────────────────────────────
660
661/// Generate an OpenCL SPIR-V compute kernel for reduction along an axis.
662///
663/// Kernel parameters: `(CrossWorkgroup float* input, CrossWorkgroup float* output,
664///                      uint outer_size, uint reduce_size, uint inner_size)`.
665///
666/// Each thread computes one output element by iterating over the reduce dimension.
667pub fn reduce_compute_shader(op: ReduceOp) -> Vec<u32> {
668    let mut m = SpvModule::new();
669    let b = emit_preamble(&mut m);
670
671    let main_fn = m.alloc_id();
672    let fn_ty = m.alloc_id();
673    let p_input = m.alloc_id();
674    let p_output = m.alloc_id();
675    let p_outer = m.alloc_id();
676    let p_reduce = m.alloc_id();
677    let p_inner = m.alloc_id();
678
679    // Function type: void(float*, float*, uint, uint, uint)
680    m.emit_type_function(
681        fn_ty,
682        b.ty_void,
683        &[
684            b.ty_ptr_cross_float,
685            b.ty_ptr_cross_float,
686            b.ty_uint,
687            b.ty_uint,
688            b.ty_uint,
689        ],
690    );
691
692    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
693    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
694
695    let label_entry = m.alloc_id();
696    let label_bounds_body = m.alloc_id();
697    let label_bounds_merge = m.alloc_id();
698    let label_loop_header = m.alloc_id();
699    let label_loop_body = m.alloc_id();
700    let label_loop_continue = m.alloc_id();
701    let label_loop_merge = m.alloc_id();
702
703    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
704    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
705    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
706    m.emit_function_parameter(b.ty_uint, p_outer);
707    m.emit_function_parameter(b.ty_uint, p_reduce);
708    m.emit_function_parameter(b.ty_uint, p_inner);
709    m.emit_label(label_entry);
710
711    let gid = load_gid_x(&mut m, &b);
712
713    // total_output = outer_size * inner_size
714    let total_output = m.alloc_id();
715    m.emit(OP_I_MUL, &[b.ty_uint, total_output, p_outer, p_inner]);
716
717    // Bounds check: gid < total_output
718    let cond_bounds = m.alloc_id();
719    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond_bounds, gid, total_output]);
720    m.emit_selection_merge(label_bounds_merge);
721    m.emit_branch_conditional(cond_bounds, label_bounds_body, label_bounds_merge);
722
723    m.emit_label(label_bounds_body);
724
725    // outer_idx = gid / inner_size, inner_idx = gid % inner_size
726    let outer_idx = m.alloc_id();
727    m.emit(OP_U_DIV, &[b.ty_uint, outer_idx, gid, p_inner]);
728    let inner_idx = m.alloc_id();
729    m.emit(OP_U_MOD, &[b.ty_uint, inner_idx, gid, p_inner]);
730
731    // base = outer_idx * reduce_size * inner_size + inner_idx
732    let t1 = m.alloc_id();
733    m.emit(OP_I_MUL, &[b.ty_uint, t1, outer_idx, p_reduce]);
734    let t2 = m.alloc_id();
735    m.emit(OP_I_MUL, &[b.ty_uint, t2, t1, p_inner]);
736    let base_idx = m.alloc_id();
737    m.emit(OP_I_ADD, &[b.ty_uint, base_idx, t2, inner_idx]);
738
739    // Loop counter
740    let var_i = m.alloc_id();
741    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
742    m.emit_store(var_i, b.c_uint_0);
743
744    // Accumulator
745    let var_acc = m.alloc_id();
746    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
747    let init_val = match op {
748        ReduceOp::Sum | ReduceOp::Mean => b.c_float_0,
749        ReduceOp::Max => {
750            let neg_inf = m.alloc_id();
751            m.emit_constant_f32(b.ty_float, neg_inf, f32::NEG_INFINITY);
752            neg_inf
753        }
754        ReduceOp::Min => {
755            let pos_inf = m.alloc_id();
756            m.emit_constant_f32(b.ty_float, pos_inf, f32::INFINITY);
757            pos_inf
758        }
759    };
760    m.emit_store(var_acc, init_val);
761
762    m.emit_branch(label_loop_header);
763
764    // ── Loop header ──
765    m.emit_label(label_loop_header);
766    let i_val = m.alloc_id();
767    m.emit_load(b.ty_uint, i_val, var_i);
768    let loop_cond = m.alloc_id();
769    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_reduce]);
770    m.emit_loop_merge(label_loop_merge, label_loop_continue);
771    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
772
773    // ── Loop body ──
774    m.emit_label(label_loop_body);
775
776    // input_idx = base_idx + i * inner_size
777    let i_times_inner = m.alloc_id();
778    m.emit(OP_I_MUL, &[b.ty_uint, i_times_inner, i_val, p_inner]);
779    let input_idx = m.alloc_id();
780    m.emit(OP_I_ADD, &[b.ty_uint, input_idx, base_idx, i_times_inner]);
781
782    let inp_ptr = m.alloc_id();
783    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, input_idx);
784    let inp_val = m.alloc_id();
785    m.emit_load(b.ty_float, inp_val, inp_ptr);
786
787    let acc_val = m.alloc_id();
788    m.emit_load(b.ty_float, acc_val, var_acc);
789
790    let new_acc = m.alloc_id();
791    match op {
792        ReduceOp::Sum | ReduceOp::Mean => {
793            m.emit(OP_F_ADD, &[b.ty_float, new_acc, acc_val, inp_val]);
794        }
795        ReduceOp::Max => {
796            m.emit_opencl_ext(
797                b.opencl_ext,
798                b.ty_float,
799                new_acc,
800                OPENCL_FMAX,
801                &[acc_val, inp_val],
802            );
803        }
804        ReduceOp::Min => {
805            m.emit_opencl_ext(
806                b.opencl_ext,
807                b.ty_float,
808                new_acc,
809                OPENCL_FMIN,
810                &[acc_val, inp_val],
811            );
812        }
813    }
814    m.emit_store(var_acc, new_acc);
815
816    m.emit_branch(label_loop_continue);
817
818    // ── Loop continue ──
819    m.emit_label(label_loop_continue);
820    let i_inc = m.alloc_id();
821    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
822    m.emit_store(var_i, i_inc);
823    m.emit_branch(label_loop_header);
824
825    // ── Loop merge ──
826    m.emit_label(label_loop_merge);
827
828    let final_acc = m.alloc_id();
829    m.emit_load(b.ty_float, final_acc, var_acc);
830
831    let store_val = if op == ReduceOp::Mean {
832        let reduce_f = m.alloc_id();
833        m.emit(OP_CONVERT_U_TO_F, &[b.ty_float, reduce_f, p_reduce]);
834        let mean_val = m.alloc_id();
835        m.emit(OP_F_DIV, &[b.ty_float, mean_val, final_acc, reduce_f]);
836        mean_val
837    } else {
838        final_acc
839    };
840
841    let out_ptr = m.alloc_id();
842    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
843    m.emit_store(out_ptr, store_val);
844
845    m.emit_branch(label_bounds_merge);
846
847    m.emit_label(label_bounds_merge);
848    m.emit_return();
849    m.emit_function_end();
850
851    m.finalize()
852}
853
854// ─── GEMM compute kernel ────────────────────────────────────
855
856/// Generate an OpenCL SPIR-V compute kernel for GEMM: `C = alpha * A * B + beta * C`.
857///
858/// Naive reference (one thread per output element, row-major f32 layout).
859///
860/// Kernel parameters: `(CrossWorkgroup float* A, CrossWorkgroup float* B,
861///                      CrossWorkgroup float* C, uint m, uint n, uint k,
862///                      float alpha, float beta)`.
863pub fn gemm_compute_shader() -> Vec<u32> {
864    let mut m = SpvModule::new();
865    let b = emit_preamble(&mut m);
866
867    let main_fn = m.alloc_id();
868    let fn_ty = m.alloc_id();
869    let p_a = m.alloc_id();
870    let p_b = m.alloc_id();
871    let p_c = m.alloc_id();
872    let p_m = m.alloc_id();
873    let p_n = m.alloc_id();
874    let p_k = m.alloc_id();
875    let p_alpha = m.alloc_id();
876    let p_beta = m.alloc_id();
877
878    // Function type: void(float*, float*, float*, uint, uint, uint, float, float)
879    m.emit_type_function(
880        fn_ty,
881        b.ty_void,
882        &[
883            b.ty_ptr_cross_float,
884            b.ty_ptr_cross_float,
885            b.ty_ptr_cross_float,
886            b.ty_uint,
887            b.ty_uint,
888            b.ty_uint,
889            b.ty_float,
890            b.ty_float,
891        ],
892    );
893
894    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
895    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
896
897    let label_entry = m.alloc_id();
898    let label_bounds_body = m.alloc_id();
899    let label_bounds_merge = m.alloc_id();
900    let label_loop_header = m.alloc_id();
901    let label_loop_body = m.alloc_id();
902    let label_loop_continue = m.alloc_id();
903    let label_loop_merge = m.alloc_id();
904
905    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
906    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
907    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
908    m.emit_function_parameter(b.ty_ptr_cross_float, p_c);
909    m.emit_function_parameter(b.ty_uint, p_m);
910    m.emit_function_parameter(b.ty_uint, p_n);
911    m.emit_function_parameter(b.ty_uint, p_k);
912    m.emit_function_parameter(b.ty_float, p_alpha);
913    m.emit_function_parameter(b.ty_float, p_beta);
914    m.emit_label(label_entry);
915
916    let gid = load_gid_x(&mut m, &b);
917
918    // total = m * n
919    let total = m.alloc_id();
920    m.emit(OP_I_MUL, &[b.ty_uint, total, p_m, p_n]);
921
922    // Bounds check
923    let cond = m.alloc_id();
924    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
925    m.emit_selection_merge(label_bounds_merge);
926    m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
927
928    m.emit_label(label_bounds_body);
929
930    // row = gid / n, col = gid % n
931    let row = m.alloc_id();
932    m.emit(OP_U_DIV, &[b.ty_uint, row, gid, p_n]);
933    let col = m.alloc_id();
934    m.emit(OP_U_MOD, &[b.ty_uint, col, gid, p_n]);
935
936    // Loop counter + accumulator
937    let var_i = m.alloc_id();
938    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
939    m.emit_store(var_i, b.c_uint_0);
940    let var_acc = m.alloc_id();
941    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
942    m.emit_store(var_acc, b.c_float_0);
943
944    m.emit_branch(label_loop_header);
945
946    // ── Loop header ──
947    m.emit_label(label_loop_header);
948    let i_val = m.alloc_id();
949    m.emit_load(b.ty_uint, i_val, var_i);
950    let loop_cond = m.alloc_id();
951    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_k]);
952    m.emit_loop_merge(label_loop_merge, label_loop_continue);
953    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
954
955    // ── Loop body ──
956    m.emit_label(label_loop_body);
957
958    // a_idx = row * k + i
959    let row_k = m.alloc_id();
960    m.emit(OP_I_MUL, &[b.ty_uint, row_k, row, p_k]);
961    let a_idx = m.alloc_id();
962    m.emit(OP_I_ADD, &[b.ty_uint, a_idx, row_k, i_val]);
963
964    // b_idx = i * n + col
965    let i_n = m.alloc_id();
966    m.emit(OP_I_MUL, &[b.ty_uint, i_n, i_val, p_n]);
967    let b_idx = m.alloc_id();
968    m.emit(OP_I_ADD, &[b.ty_uint, b_idx, i_n, col]);
969
970    let a_ptr = m.alloc_id();
971    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, a_idx);
972    let a_val = m.alloc_id();
973    m.emit_load(b.ty_float, a_val, a_ptr);
974
975    let b_ptr = m.alloc_id();
976    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, b_idx);
977    let b_val = m.alloc_id();
978    m.emit_load(b.ty_float, b_val, b_ptr);
979
980    let prod = m.alloc_id();
981    m.emit(OP_F_MUL, &[b.ty_float, prod, a_val, b_val]);
982    let old_acc = m.alloc_id();
983    m.emit_load(b.ty_float, old_acc, var_acc);
984    let new_acc = m.alloc_id();
985    m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
986    m.emit_store(var_acc, new_acc);
987
988    m.emit_branch(label_loop_continue);
989
990    // ── Loop continue ──
991    m.emit_label(label_loop_continue);
992    let i_inc = m.alloc_id();
993    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
994    m.emit_store(var_i, i_inc);
995    m.emit_branch(label_loop_header);
996
997    // ── Loop merge ──
998    m.emit_label(label_loop_merge);
999
1000    // result = alpha * acc + beta * C[gid]
1001    let final_acc = m.alloc_id();
1002    m.emit_load(b.ty_float, final_acc, var_acc);
1003    let alpha_acc = m.alloc_id();
1004    m.emit(OP_F_MUL, &[b.ty_float, alpha_acc, p_alpha, final_acc]);
1005
1006    let c_ptr = m.alloc_id();
1007    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_ptr, p_c, gid);
1008    let c_old = m.alloc_id();
1009    m.emit_load(b.ty_float, c_old, c_ptr);
1010    let beta_c = m.alloc_id();
1011    m.emit(OP_F_MUL, &[b.ty_float, beta_c, p_beta, c_old]);
1012    let c_new = m.alloc_id();
1013    m.emit(OP_F_ADD, &[b.ty_float, c_new, alpha_acc, beta_c]);
1014    m.emit_store(c_ptr, c_new);
1015
1016    m.emit_branch(label_bounds_merge);
1017
1018    m.emit_label(label_bounds_merge);
1019    m.emit_return();
1020    m.emit_function_end();
1021
1022    m.finalize()
1023}
1024
1025// ─── Batched GEMM compute kernel ─────────────────────────────
1026
1027/// Load `GlobalInvocationId.z` into a uint result.
1028fn load_gid_z(m: &mut SpvModule, b: &BaseIds) -> u32 {
1029    let gid_val = m.alloc_id();
1030    m.emit_load(b.ty_v3uint, gid_val, b.var_gid);
1031    let gid_z = m.alloc_id();
1032    m.emit(OP_COMPOSITE_EXTRACT, &[b.ty_uint, gid_z, gid_val, 2]);
1033    gid_z
1034}
1035
1036/// Generate an OpenCL SPIR-V compute kernel for batched GEMM.
1037///
1038/// For each batch `b` in `0..batch_count`:
1039///   `C_b = alpha * A_b * B_b + beta * C_b`
1040/// where `A_b` starts at offset `b * stride_a`, etc.
1041///
1042/// Uses 3D global work size `(ceil(m*n / WG), 1, batch_count)`:
1043/// - `get_global_id(0)` = element index within a single m×n output
1044/// - `get_global_id(2)` = batch index
1045///
1046/// Kernel parameters:
1047/// `(CrossWorkgroup float* A, CrossWorkgroup float* B, CrossWorkgroup float* C,
1048///   uint m, uint n, uint k, float alpha, float beta,
1049///   uint batch_count, uint stride_a, uint stride_b, uint stride_c)`.
1050pub fn batched_gemm_compute_shader() -> Vec<u32> {
1051    let mut m = SpvModule::new();
1052    let b = emit_preamble(&mut m);
1053
1054    let main_fn = m.alloc_id();
1055    let fn_ty = m.alloc_id();
1056    let p_a = m.alloc_id();
1057    let p_b = m.alloc_id();
1058    let p_c = m.alloc_id();
1059    let p_m = m.alloc_id();
1060    let p_n = m.alloc_id();
1061    let p_k = m.alloc_id();
1062    let p_alpha = m.alloc_id();
1063    let p_beta = m.alloc_id();
1064    let p_batch_count = m.alloc_id();
1065    let p_stride_a = m.alloc_id();
1066    let p_stride_b = m.alloc_id();
1067    let p_stride_c = m.alloc_id();
1068
1069    // Function type: void(float*, float*, float*, uint, uint, uint, float, float,
1070    //                      uint, uint, uint, uint)
1071    m.emit_type_function(
1072        fn_ty,
1073        b.ty_void,
1074        &[
1075            b.ty_ptr_cross_float,
1076            b.ty_ptr_cross_float,
1077            b.ty_ptr_cross_float,
1078            b.ty_uint,
1079            b.ty_uint,
1080            b.ty_uint,
1081            b.ty_float,
1082            b.ty_float,
1083            b.ty_uint,
1084            b.ty_uint,
1085            b.ty_uint,
1086            b.ty_uint,
1087        ],
1088    );
1089
1090    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
1091    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
1092
1093    let label_entry = m.alloc_id();
1094    let label_bounds_body = m.alloc_id();
1095    let label_bounds_merge = m.alloc_id();
1096    let label_loop_header = m.alloc_id();
1097    let label_loop_body = m.alloc_id();
1098    let label_loop_continue = m.alloc_id();
1099    let label_loop_merge = m.alloc_id();
1100
1101    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
1102    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
1103    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
1104    m.emit_function_parameter(b.ty_ptr_cross_float, p_c);
1105    m.emit_function_parameter(b.ty_uint, p_m);
1106    m.emit_function_parameter(b.ty_uint, p_n);
1107    m.emit_function_parameter(b.ty_uint, p_k);
1108    m.emit_function_parameter(b.ty_float, p_alpha);
1109    m.emit_function_parameter(b.ty_float, p_beta);
1110    m.emit_function_parameter(b.ty_uint, p_batch_count);
1111    m.emit_function_parameter(b.ty_uint, p_stride_a);
1112    m.emit_function_parameter(b.ty_uint, p_stride_b);
1113    m.emit_function_parameter(b.ty_uint, p_stride_c);
1114    m.emit_label(label_entry);
1115
1116    // gid_x = element index within single GEMM output
1117    let gid = load_gid_x(&mut m, &b);
1118    // gid_z = batch index
1119    let batch_idx = load_gid_z(&mut m, &b);
1120
1121    // total = m * n
1122    let total = m.alloc_id();
1123    m.emit(OP_I_MUL, &[b.ty_uint, total, p_m, p_n]);
1124
1125    // Bounds check: gid < total && batch_idx < batch_count
1126    let cond1 = m.alloc_id();
1127    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond1, gid, total]);
1128    let cond2 = m.alloc_id();
1129    m.emit(
1130        OP_U_LESS_THAN,
1131        &[b.ty_bool, cond2, batch_idx, p_batch_count],
1132    );
1133    // Combined condition via OpLogicalAnd
1134    let cond = m.alloc_id();
1135    // OpLogicalAnd = 166
1136    m.emit(166, &[b.ty_bool, cond, cond1, cond2]);
1137    m.emit_selection_merge(label_bounds_merge);
1138    m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
1139
1140    m.emit_label(label_bounds_body);
1141
1142    // Compute batch offsets: a_base = batch_idx * stride_a, etc.
1143    let a_offset = m.alloc_id();
1144    m.emit(OP_I_MUL, &[b.ty_uint, a_offset, batch_idx, p_stride_a]);
1145    let b_offset = m.alloc_id();
1146    m.emit(OP_I_MUL, &[b.ty_uint, b_offset, batch_idx, p_stride_b]);
1147    let c_offset = m.alloc_id();
1148    m.emit(OP_I_MUL, &[b.ty_uint, c_offset, batch_idx, p_stride_c]);
1149
1150    // Offset the base pointers
1151    let a_batch = m.alloc_id();
1152    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_batch, p_a, a_offset);
1153    let b_batch = m.alloc_id();
1154    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_batch, p_b, b_offset);
1155    let c_batch = m.alloc_id();
1156    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_batch, p_c, c_offset);
1157
1158    // row = gid / n, col = gid % n
1159    let row = m.alloc_id();
1160    m.emit(OP_U_DIV, &[b.ty_uint, row, gid, p_n]);
1161    let col = m.alloc_id();
1162    m.emit(OP_U_MOD, &[b.ty_uint, col, gid, p_n]);
1163
1164    // Loop counter + accumulator
1165    let var_i = m.alloc_id();
1166    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
1167    m.emit_store(var_i, b.c_uint_0);
1168    let var_acc = m.alloc_id();
1169    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
1170    m.emit_store(var_acc, b.c_float_0);
1171
1172    m.emit_branch(label_loop_header);
1173
1174    // ── Loop header ──
1175    m.emit_label(label_loop_header);
1176    let i_val = m.alloc_id();
1177    m.emit_load(b.ty_uint, i_val, var_i);
1178    let loop_cond = m.alloc_id();
1179    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_k]);
1180    m.emit_loop_merge(label_loop_merge, label_loop_continue);
1181    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
1182
1183    // ── Loop body ──
1184    m.emit_label(label_loop_body);
1185
1186    // a_idx = row * k + i
1187    let row_k = m.alloc_id();
1188    m.emit(OP_I_MUL, &[b.ty_uint, row_k, row, p_k]);
1189    let a_idx = m.alloc_id();
1190    m.emit(OP_I_ADD, &[b.ty_uint, a_idx, row_k, i_val]);
1191
1192    // b_idx = i * n + col
1193    let i_n = m.alloc_id();
1194    m.emit(OP_I_MUL, &[b.ty_uint, i_n, i_val, p_n]);
1195    let b_idx = m.alloc_id();
1196    m.emit(OP_I_ADD, &[b.ty_uint, b_idx, i_n, col]);
1197
1198    let a_ptr = m.alloc_id();
1199    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, a_batch, a_idx);
1200    let a_val = m.alloc_id();
1201    m.emit_load(b.ty_float, a_val, a_ptr);
1202
1203    let b_ptr = m.alloc_id();
1204    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, b_batch, b_idx);
1205    let b_val = m.alloc_id();
1206    m.emit_load(b.ty_float, b_val, b_ptr);
1207
1208    let prod = m.alloc_id();
1209    m.emit(OP_F_MUL, &[b.ty_float, prod, a_val, b_val]);
1210    let old_acc = m.alloc_id();
1211    m.emit_load(b.ty_float, old_acc, var_acc);
1212    let new_acc = m.alloc_id();
1213    m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
1214    m.emit_store(var_acc, new_acc);
1215
1216    m.emit_branch(label_loop_continue);
1217
1218    // ── Loop continue ──
1219    m.emit_label(label_loop_continue);
1220    let i_inc = m.alloc_id();
1221    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
1222    m.emit_store(var_i, i_inc);
1223    m.emit_branch(label_loop_header);
1224
1225    // ── Loop merge ──
1226    m.emit_label(label_loop_merge);
1227
1228    // result = alpha * acc + beta * C_batch[gid]
1229    let final_acc = m.alloc_id();
1230    m.emit_load(b.ty_float, final_acc, var_acc);
1231    let alpha_acc = m.alloc_id();
1232    m.emit(OP_F_MUL, &[b.ty_float, alpha_acc, p_alpha, final_acc]);
1233
1234    let c_ptr = m.alloc_id();
1235    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_ptr, c_batch, gid);
1236    let c_old = m.alloc_id();
1237    m.emit_load(b.ty_float, c_old, c_ptr);
1238    let beta_c = m.alloc_id();
1239    m.emit(OP_F_MUL, &[b.ty_float, beta_c, p_beta, c_old]);
1240    let c_new = m.alloc_id();
1241    m.emit(OP_F_ADD, &[b.ty_float, c_new, alpha_acc, beta_c]);
1242    m.emit_store(c_ptr, c_new);
1243
1244    m.emit_branch(label_bounds_merge);
1245
1246    m.emit_label(label_bounds_merge);
1247    m.emit_return();
1248    m.emit_function_end();
1249
1250    m.finalize()
1251}
1252
1253// ─── Trivial placeholder ────────────────────────────────────
1254
1255/// Build a minimal valid Shader-style compute shader: `void main() {}`.
1256///
1257/// Uses `GLCompute` / Shader capability for basic Level Zero module validation.
1258pub fn trivial_compute_shader() -> Vec<u32> {
1259    let mut m = SpvModule::new();
1260
1261    let id_main_fn = m.alloc_id();
1262    let id_void = m.alloc_id();
1263    let id_void_fn = m.alloc_id();
1264    let id_label = m.alloc_id();
1265
1266    m.emit_capability(CAPABILITY_SHADER);
1267    m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
1268
1269    let mut entry_words = vec![EXECUTION_MODEL_GLCOMPUTE, id_main_fn];
1270    entry_words.extend(SpvModule::string_words("main"));
1271    m.emit(OP_ENTRY_POINT, &entry_words);
1272
1273    m.emit_execution_mode_local_size(id_main_fn, 1, 1, 1);
1274
1275    m.emit_type_void(id_void);
1276    m.emit_type_function(id_void_fn, id_void, &[]);
1277
1278    m.emit_function(id_void, id_main_fn, FUNCTION_CONTROL_NONE, id_void_fn);
1279    m.emit_label(id_label);
1280    m.emit_return();
1281    m.emit_function_end();
1282
1283    m.finalize()
1284}
1285
1286/// Return the trivial compute shader as a byte slice suitable for
1287/// passing to Level Zero module creation.
1288pub fn trivial_compute_shader_bytes() -> Vec<u8> {
1289    trivial_compute_shader()
1290        .iter()
1291        .flat_map(|w| w.to_ne_bytes())
1292        .collect()
1293}
1294
1295// ─── Tests ──────────────────────────────────────────────────
1296
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert the basic SPIR-V header invariants: correct magic word,
    /// positive ID bound, and a zero schema word.
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        assert_eq!(words[0], SPIRV_MAGIC, "bad magic");
        assert!(words[3] > 0, "ID bound must be > 0");
        assert_eq!(words[4], 0, "schema must be 0");
    }

    #[test]
    fn placeholder_spv_valid_magic() {
        check_valid_spirv(&trivial_compute_shader());
    }

    #[test]
    fn placeholder_spv_word_aligned() {
        assert_eq!(trivial_compute_shader_bytes().len() % 4, 0);
    }

    #[test]
    fn placeholder_spv_version_and_schema() {
        let words = trivial_compute_shader();
        assert!(words[1] >= 0x0001_0000);
        assert_eq!(words[4], 0);
    }

    #[test]
    fn placeholder_spv_nonzero_bound() {
        assert!(trivial_compute_shader()[3] > 0);
    }

    #[test]
    fn spv_module_id_allocation_is_monotonic() {
        let mut module = SpvModule::new();
        let first = module.alloc_id();
        let second = module.alloc_id();
        assert!(second > first);
    }

    #[test]
    fn string_words_null_terminated() {
        let words = SpvModule::string_words("abc");
        assert!(!words.is_empty());
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(bytes[..4], *b"abc\0");
    }

    #[test]
    fn string_words_empty_string() {
        let words = SpvModule::string_words("");
        assert!(!words.is_empty());
        // Even the empty string produces at least one word, whose first
        // byte is the NUL terminator.
        assert_eq!(words[0].to_le_bytes()[0], 0);
    }

    #[test]
    fn generator_magic_is_level_zero() {
        assert_eq!(SPIRV_GENERATOR, 0x000D_0002);
        assert_ne!(SPIRV_GENERATOR, 0x000D_0001);
    }

    // ── Compute kernel generation ────────────────────────────

    #[test]
    fn unary_shader_all_ops() {
        for op in [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ] {
            check_valid_spirv(&unary_compute_shader(op));
        }
    }

    #[test]
    fn binary_shader_all_ops() {
        for op in [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ] {
            check_valid_spirv(&binary_compute_shader(op));
        }
    }

    #[test]
    fn reduce_shader_all_ops() {
        for op in [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean] {
            check_valid_spirv(&reduce_compute_shader(op));
        }
    }

    #[test]
    fn gemm_shader_valid() {
        check_valid_spirv(&gemm_compute_shader());
    }

    #[test]
    fn batched_gemm_shader_valid() {
        check_valid_spirv(&batched_gemm_compute_shader());
    }

    #[test]
    fn batched_gemm_shader_word_aligned() {
        let byte_len = batched_gemm_compute_shader()
            .iter()
            .flat_map(|w| w.to_ne_bytes())
            .count();
        assert_eq!(byte_len % 4, 0);
    }

    #[test]
    fn batched_gemm_shader_uses_kernel_capability() {
        // After the 5-word header, the first instruction must be
        // OpCapability Kernel: word = (word_count << 16) | opcode.
        let words = batched_gemm_compute_shader();
        assert_eq!(words[5], (2u32 << 16) | OP_CAPABILITY);
        assert_eq!(words[6], 6); // CAPABILITY_KERNEL
    }

    #[test]
    fn all_kernel_shaders_word_aligned() {
        let byte_len = |words: &[u32]| words.iter().flat_map(|w| w.to_ne_bytes()).count();
        assert_eq!(byte_len(&unary_compute_shader(UnaryOp::Relu)) % 4, 0);
        assert_eq!(byte_len(&binary_compute_shader(BinaryOp::Add)) % 4, 0);
        assert_eq!(byte_len(&reduce_compute_shader(ReduceOp::Sum)) % 4, 0);
        assert_eq!(byte_len(&gemm_compute_shader()) % 4, 0);
        assert_eq!(byte_len(&batched_gemm_compute_shader()) % 4, 0);
    }

    #[test]
    fn kernel_shaders_use_opencl_memory_model() {
        // Kernel-model shaders declare Capability Kernel (6) while the
        // trivial placeholder declares Capability Shader (1). In both
        // cases the first instruction after the 5-word header is
        // OpCapability, encoded as (2 << 16) | 17 = 0x00020011.
        let cap_header = (2u32 << 16) | OP_CAPABILITY;

        let trivial = trivial_compute_shader();
        assert_eq!(trivial[5], cap_header);
        assert_eq!(trivial[6], CAPABILITY_SHADER);

        let unary = unary_compute_shader(UnaryOp::Relu);
        assert_eq!(unary[5], cap_header);
        assert_eq!(unary[6], CAPABILITY_KERNEL);
    }
}