// oxicuda_levelzero/spirv.rs
//! SPIR-V compute kernel generators for the Level Zero backend.
//!
//! This module provides:
//! - A lightweight [`SpvModule`] builder for emitting valid SPIR-V binaries.
//! - Generator functions for **unary**, **binary**, **reduce**, and **GEMM**
//!   compute kernels consumed by Level Zero's `zeModuleCreate`.
//! - The original [`trivial_compute_shader`] placeholder used for validation.
//!
//! All generated kernels use the OpenCL SPIR-V execution model (`Kernel`)
//! with `Physical64`/`OpenCL` memory model.  Buffer parameters are
//! `CrossWorkgroup` pointers; scalar parameters are passed by value via
//! `zeKernelSetArgumentValue`.
//!
//! The generated SPIR-V targets version 1.2 (widely supported by Level Zero).

use oxicuda_backend::{BinaryOp, ReduceOp, UnaryOp};

// ─── Constants ──────────────────────────────────────────────

/// SPIR-V magic number.
pub const SPIRV_MAGIC: u32 = 0x07230203;
/// SPIR-V version 1.2.
pub const SPIRV_VERSION_1_2: u32 = 0x0001_0200;
/// Generator magic — OxiCUDA Level Zero backend.
pub const SPIRV_GENERATOR: u32 = 0x000D_0002;

// ─── SPIR-V opcodes ─────────────────────────────────────────
// Opcode numbers as assigned by the SPIR-V specification
// ("Instructions" chapter, unified opcode table).

pub(crate) const OP_EXT_INST_IMPORT: u32 = 11;
pub(crate) const OP_EXT_INST: u32 = 12;
pub(crate) const OP_MEMORY_MODEL: u32 = 14;
pub(crate) const OP_ENTRY_POINT: u32 = 15;
pub(crate) const OP_EXECUTION_MODE: u32 = 16;
pub(crate) const OP_CAPABILITY: u32 = 17;
pub(crate) const OP_TYPE_VOID: u32 = 19;
pub(crate) const OP_TYPE_BOOL: u32 = 20;
pub(crate) const OP_TYPE_INT: u32 = 21;
pub(crate) const OP_TYPE_FLOAT: u32 = 22;
pub(crate) const OP_TYPE_VECTOR: u32 = 23;
pub(crate) const OP_TYPE_POINTER: u32 = 32;
pub(crate) const OP_TYPE_FUNCTION: u32 = 33;
pub(crate) const OP_CONSTANT: u32 = 43;
pub(crate) const OP_FUNCTION: u32 = 54;
pub(crate) const OP_FUNCTION_PARAMETER: u32 = 55;
pub(crate) const OP_FUNCTION_END: u32 = 56;
pub(crate) const OP_VARIABLE: u32 = 59;
pub(crate) const OP_LOAD: u32 = 61;
pub(crate) const OP_STORE: u32 = 62;
pub(crate) const OP_IN_BOUNDS_PTR_ACCESS_CHAIN: u32 = 70;
pub(crate) const OP_DECORATE: u32 = 71;
pub(crate) const OP_COMPOSITE_EXTRACT: u32 = 81;
pub(crate) const OP_CONVERT_U_TO_F: u32 = 112;
pub(crate) const OP_F_NEGATE: u32 = 127;
pub(crate) const OP_I_ADD: u32 = 128;
pub(crate) const OP_F_ADD: u32 = 129;
pub(crate) const OP_F_SUB: u32 = 131;
pub(crate) const OP_I_MUL: u32 = 132;
pub(crate) const OP_F_MUL: u32 = 133;
pub(crate) const OP_U_DIV: u32 = 134;
pub(crate) const OP_F_DIV: u32 = 136;
pub(crate) const OP_U_MOD: u32 = 137;
pub(crate) const OP_U_LESS_THAN: u32 = 176;
pub(crate) const OP_LOOP_MERGE: u32 = 246;
pub(crate) const OP_SELECTION_MERGE: u32 = 247;
pub(crate) const OP_LABEL: u32 = 248;
pub(crate) const OP_BRANCH: u32 = 249;
pub(crate) const OP_BRANCH_CONDITIONAL: u32 = 250;
pub(crate) const OP_RETURN: u32 = 253;

// Capabilities (OpCapability operand values)
const CAPABILITY_SHADER: u32 = 1;
const CAPABILITY_ADDRESSES: u32 = 4;
const CAPABILITY_KERNEL: u32 = 6;

// Addressing / memory model (OpMemoryModel operand values)
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
const ADDRESSING_MODEL_PHYSICAL64: u32 = 2;
const MEMORY_MODEL_GLSL450: u32 = 1;
const MEMORY_MODEL_OPENCL: u32 = 2;

// Execution model / mode
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
pub(crate) const EXECUTION_MODEL_KERNEL: u32 = 6;
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;

// Function control (OpFunction bitmask; None = no hints)
pub(crate) const FUNCTION_CONTROL_NONE: u32 = 0;

// Decorations
const DECORATION_BUILTIN: u32 = 11;

// BuiltIn values
const BUILTIN_GLOBAL_INVOCATION_ID: u32 = 28;

// Storage classes
const STORAGE_CLASS_INPUT: u32 = 1;
const STORAGE_CLASS_CROSS_WORKGROUP: u32 = 5;
pub(crate) const STORAGE_CLASS_FUNCTION: u32 = 7;

// Selection/loop control (bitmasks; None = no hints)
const SELECTION_CONTROL_NONE: u32 = 0;
const LOOP_CONTROL_NONE: u32 = 0;

// OpenCL.std extended instruction numbers (OpenCL Extended Instruction Set)
pub(crate) const OPENCL_EXP: u32 = 19;
const OPENCL_FABS: u32 = 23;
pub(crate) const OPENCL_FMAX: u32 = 27;
const OPENCL_FMIN: u32 = 28;
const OPENCL_LOG: u32 = 37;
const OPENCL_SQRT: u32 = 61;
const OPENCL_TANH: u32 = 63;

/// Workgroup size for 1-D compute kernels.
pub(crate) const WORKGROUP_SIZE: u32 = 256;
115
116// ─── Minimal SPIR-V builder ──────────────────────────────────
117
/// Lightweight SPIR-V word-stream builder.
///
/// Emits valid SPIR-V instructions for simple compute shaders without
/// pulling in a full compiler.
pub struct SpvModule {
    // Raw SPIR-V word stream, beginning with the 5-word module header
    // (magic, version, generator, bound placeholder, schema).
    words: Vec<u32>,
    /// Next available result ID.
    id_bound: u32,
}
127
128impl SpvModule {
129    /// Create a new module with a placeholder header (bound filled at finalise).
130    pub fn new() -> Self {
131        let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_2, SPIRV_GENERATOR, 0, 0];
132        Self { words, id_bound: 1 }
133    }
134
135    /// Allocate a fresh result ID.
136    pub fn alloc_id(&mut self) -> u32 {
137        let id = self.id_bound;
138        self.id_bound += 1;
139        id
140    }
141
142    /// Emit a SPIR-V instruction.
143    pub fn emit(&mut self, opcode: u32, operands: &[u32]) {
144        let word_count = (1 + operands.len()) as u32;
145        self.words.push((word_count << 16) | opcode);
146        self.words.extend_from_slice(operands);
147    }
148
149    /// Emit a string as null-terminated UTF-8 packed into 32-bit words.
150    pub fn string_words(s: &str) -> Vec<u32> {
151        let bytes = s.as_bytes();
152        let padded_len = (bytes.len() + 4) & !3;
153        let mut out = vec![0u32; padded_len / 4];
154        for (i, &b) in bytes.iter().enumerate() {
155            out[i / 4] |= (b as u32) << ((i % 4) * 8);
156        }
157        out
158    }
159
160    /// Finalise the module: patch the ID bound and return the word vector.
161    pub fn finalize(mut self) -> Vec<u32> {
162        self.words[3] = self.id_bound;
163        self.words
164    }
165
166    // ── Convenience emitters ─────────────────────────────────
167
168    pub(crate) fn emit_capability(&mut self, cap: u32) {
169        self.emit(OP_CAPABILITY, &[cap]);
170    }
171
172    pub(crate) fn emit_ext_inst_import(&mut self, id: u32, name: &str) {
173        let mut ops = vec![id];
174        ops.extend(Self::string_words(name));
175        self.emit(OP_EXT_INST_IMPORT, &ops);
176    }
177
178    pub(crate) fn emit_memory_model(&mut self, addressing: u32, memory: u32) {
179        self.emit(OP_MEMORY_MODEL, &[addressing, memory]);
180    }
181
182    pub(crate) fn emit_entry_point(
183        &mut self,
184        model: u32,
185        func_id: u32,
186        name: &str,
187        interfaces: &[u32],
188    ) {
189        let mut ops = vec![model, func_id];
190        ops.extend(Self::string_words(name));
191        ops.extend_from_slice(interfaces);
192        self.emit(OP_ENTRY_POINT, &ops);
193    }
194
195    pub(crate) fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
196        self.emit(
197            OP_EXECUTION_MODE,
198            &[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
199        );
200    }
201
202    pub(crate) fn emit_decorate(&mut self, target: u32, decoration: u32, operands: &[u32]) {
203        let mut ops = vec![target, decoration];
204        ops.extend_from_slice(operands);
205        self.emit(OP_DECORATE, &ops);
206    }
207
208    pub(crate) fn emit_type_void(&mut self, id: u32) {
209        self.emit(OP_TYPE_VOID, &[id]);
210    }
211
212    pub(crate) fn emit_type_bool(&mut self, id: u32) {
213        self.emit(OP_TYPE_BOOL, &[id]);
214    }
215
216    pub(crate) fn emit_type_int(&mut self, id: u32, width: u32, signedness: u32) {
217        self.emit(OP_TYPE_INT, &[id, width, signedness]);
218    }
219
220    pub(crate) fn emit_type_float(&mut self, id: u32, width: u32) {
221        self.emit(OP_TYPE_FLOAT, &[id, width]);
222    }
223
224    pub(crate) fn emit_type_vector(&mut self, id: u32, component: u32, count: u32) {
225        self.emit(OP_TYPE_VECTOR, &[id, component, count]);
226    }
227
228    pub(crate) fn emit_type_pointer(&mut self, id: u32, storage_class: u32, pointee: u32) {
229        self.emit(OP_TYPE_POINTER, &[id, storage_class, pointee]);
230    }
231
232    pub(crate) fn emit_type_function(&mut self, id: u32, return_type: u32, params: &[u32]) {
233        let mut ops = vec![id, return_type];
234        ops.extend_from_slice(params);
235        self.emit(OP_TYPE_FUNCTION, &ops);
236    }
237
238    pub(crate) fn emit_constant_u32(&mut self, ty: u32, id: u32, value: u32) {
239        self.emit(OP_CONSTANT, &[ty, id, value]);
240    }
241
242    pub(crate) fn emit_constant_f32(&mut self, ty: u32, id: u32, value: f32) {
243        self.emit(OP_CONSTANT, &[ty, id, value.to_bits()]);
244    }
245
246    pub(crate) fn emit_variable(&mut self, ty: u32, id: u32, storage_class: u32) {
247        self.emit(OP_VARIABLE, &[ty, id, storage_class]);
248    }
249
250    pub(crate) fn emit_load(&mut self, result_ty: u32, result: u32, pointer: u32) {
251        self.emit(OP_LOAD, &[result_ty, result, pointer]);
252    }
253
254    pub(crate) fn emit_store(&mut self, pointer: u32, value: u32) {
255        self.emit(OP_STORE, &[pointer, value]);
256    }
257
258    pub(crate) fn emit_in_bounds_ptr_access_chain(
259        &mut self,
260        result_ty: u32,
261        result: u32,
262        base: u32,
263        element: u32,
264    ) {
265        self.emit(
266            OP_IN_BOUNDS_PTR_ACCESS_CHAIN,
267            &[result_ty, result, base, element],
268        );
269    }
270
271    pub(crate) fn emit_function(&mut self, result_ty: u32, result: u32, control: u32, fn_ty: u32) {
272        self.emit(OP_FUNCTION, &[result_ty, result, control, fn_ty]);
273    }
274
275    pub(crate) fn emit_function_parameter(&mut self, result_ty: u32, result: u32) {
276        self.emit(OP_FUNCTION_PARAMETER, &[result_ty, result]);
277    }
278
279    pub(crate) fn emit_label(&mut self, id: u32) {
280        self.emit(OP_LABEL, &[id]);
281    }
282
283    pub(crate) fn emit_return(&mut self) {
284        self.emit(OP_RETURN, &[]);
285    }
286
287    pub(crate) fn emit_function_end(&mut self) {
288        self.emit(OP_FUNCTION_END, &[]);
289    }
290
291    pub(crate) fn emit_branch(&mut self, target: u32) {
292        self.emit(OP_BRANCH, &[target]);
293    }
294
295    pub(crate) fn emit_branch_conditional(&mut self, cond: u32, true_label: u32, false_label: u32) {
296        self.emit(OP_BRANCH_CONDITIONAL, &[cond, true_label, false_label]);
297    }
298
299    pub(crate) fn emit_selection_merge(&mut self, merge_label: u32) {
300        self.emit(OP_SELECTION_MERGE, &[merge_label, SELECTION_CONTROL_NONE]);
301    }
302
303    pub(crate) fn emit_loop_merge(&mut self, merge_label: u32, continue_label: u32) {
304        self.emit(
305            OP_LOOP_MERGE,
306            &[merge_label, continue_label, LOOP_CONTROL_NONE],
307        );
308    }
309
310    pub(crate) fn emit_opencl_ext(
311        &mut self,
312        ext_id: u32,
313        result_ty: u32,
314        result: u32,
315        inst: u32,
316        args: &[u32],
317    ) {
318        let mut ops = vec![result_ty, result, ext_id, inst];
319        ops.extend_from_slice(args);
320        self.emit(OP_EXT_INST, &ops);
321    }
322}
323
324impl Default for SpvModule {
325    fn default() -> Self {
326        Self::new()
327    }
328}
329
330// ─── Common preamble for OpenCL SPIR-V kernels ──────────────
331
/// IDs shared by all compute kernels.
pub(crate) struct BaseIds {
    // Core types.
    pub(crate) ty_void: u32,
    pub(crate) ty_bool: u32,
    pub(crate) ty_uint: u32,
    pub(crate) ty_float: u32,
    // uint3 vector type (for GlobalInvocationId).
    #[allow(dead_code)]
    pub(crate) ty_v3uint: u32,
    // void() function type; kernels with parameters build their own.
    #[allow(dead_code)]
    pub(crate) ty_fn_void: u32,
    // Pointer types: Input uint3, CrossWorkgroup float, Function float/uint.
    #[allow(dead_code)]
    pub(crate) ty_ptr_input_v3uint: u32,
    pub(crate) ty_ptr_cross_float: u32,
    pub(crate) ty_ptr_func_float: u32,
    pub(crate) ty_ptr_func_uint: u32,
    // Common constants: 0u, 1u, 0.0f, 1.0f.
    pub(crate) c_uint_0: u32,
    pub(crate) c_uint_1: u32,
    pub(crate) c_float_0: u32,
    pub(crate) c_float_1: u32,
    // GlobalInvocationId input variable.
    pub(crate) var_gid: u32,
    // OpenCL.std extended instruction set import.
    pub(crate) opencl_ext: u32,
}
354
/// Emit the preamble shared by all OpenCL-style compute kernels.
///
/// This emits capabilities, memory model, types, constants, and the
/// `GlobalInvocationId` Input variable.  The caller must separately emit
/// `OpEntryPoint`, `OpExecutionMode`, and the function body.
pub(crate) fn emit_preamble(m: &mut SpvModule) -> BaseIds {
    // Allocate all shared IDs up front so the BaseIds struct can be
    // populated regardless of emission order below.
    let ty_void = m.alloc_id();
    let ty_bool = m.alloc_id();
    let ty_uint = m.alloc_id();
    let ty_float = m.alloc_id();
    let ty_v3uint = m.alloc_id();
    let ty_fn_void = m.alloc_id();
    let ty_ptr_input_v3uint = m.alloc_id();
    let ty_ptr_cross_float = m.alloc_id();
    let ty_ptr_func_float = m.alloc_id();
    let ty_ptr_func_uint = m.alloc_id();
    let c_uint_0 = m.alloc_id();
    let c_uint_1 = m.alloc_id();
    let c_float_0 = m.alloc_id();
    let c_float_1 = m.alloc_id();
    let var_gid = m.alloc_id();
    let opencl_ext = m.alloc_id();

    // Capabilities: Kernel (OpenCL execution model) + Addresses (Physical64).
    m.emit_capability(CAPABILITY_KERNEL);
    m.emit_capability(CAPABILITY_ADDRESSES);

    // Extension import
    m.emit_ext_inst_import(opencl_ext, "OpenCL.std");

    // Memory model: Physical64 + OpenCL
    m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);

    // NOTE: OpEntryPoint and OpExecutionMode are emitted by the caller after
    // allocating the main function ID, so we skip them here.
    // NOTE(review): this means OpEntryPoint ends up *after* the decoration
    // and type/constant sections, which deviates from the SPIR-V logical
    // module layout (entry points precede decorations and types) — verify
    // that the target Level Zero driver accepts this ordering.

    // Decoration: GlobalInvocationId on var_gid
    m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);

    // Types
    m.emit_type_void(ty_void);
    m.emit_type_bool(ty_bool);
    m.emit_type_int(ty_uint, 32, 0);
    m.emit_type_float(ty_float, 32);
    m.emit_type_vector(ty_v3uint, ty_uint, 3);
    m.emit_type_function(ty_fn_void, ty_void, &[]);
    m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
    m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
    m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
    m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);

    // Constants
    m.emit_constant_u32(ty_uint, c_uint_0, 0);
    m.emit_constant_u32(ty_uint, c_uint_1, 1);
    m.emit_constant_f32(ty_float, c_float_0, 0.0);
    m.emit_constant_f32(ty_float, c_float_1, 1.0);

    // GlobalInvocationId input variable
    m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);

    BaseIds {
        ty_void,
        ty_bool,
        ty_uint,
        ty_float,
        ty_v3uint,
        ty_fn_void,
        ty_ptr_input_v3uint,
        ty_ptr_cross_float,
        ty_ptr_func_float,
        ty_ptr_func_uint,
        c_uint_0,
        c_uint_1,
        c_float_0,
        c_float_1,
        var_gid,
        opencl_ext,
    }
}
434
435/// Load `GlobalInvocationId.x` into a uint result.
436pub(crate) fn load_gid_x(m: &mut SpvModule, b: &BaseIds) -> u32 {
437    let gid_val = m.alloc_id();
438    m.emit_load(b.ty_v3uint, gid_val, b.var_gid);
439    let gid_x = m.alloc_id();
440    m.emit(OP_COMPOSITE_EXTRACT, &[b.ty_uint, gid_x, gid_val, 0]);
441    gid_x
442}
443
444// ─── Unary compute kernel ───────────────────────────────────
445
/// Generate an OpenCL SPIR-V compute kernel for an element-wise unary operation.
///
/// Kernel parameters: `(CrossWorkgroup float* input, CrossWorkgroup float* output, uint count)`.
pub fn unary_compute_shader(op: UnaryOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_input = m.alloc_id();
    let p_output = m.alloc_id();
    let p_count = m.alloc_id();

    // Function type: void(CrossWorkgroup float*, CrossWorkgroup float*, uint)
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[b.ty_ptr_cross_float, b.ty_ptr_cross_float, b.ty_uint],
    );

    // Entry point and execution mode
    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    // Labels
    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    // Function
    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
    m.emit_function_parameter(b.ty_uint, p_count);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // Bounds check: gid < count — skip the body for excess threads in the
    // last (partially filled) workgroup.
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // input_ptr = &input[gid]
    let inp_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, gid);
    let inp_val = m.alloc_id();
    m.emit_load(b.ty_float, inp_val, inp_ptr);

    // Apply the requested unary operation to the loaded element.
    let result = emit_unary_op(&mut m, &b, op, inp_val);

    // output_ptr = &output[gid]
    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
    m.emit_store(out_ptr, result);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
513
514/// Emit the SPIR-V instructions for a unary operation, returning the result ID.
515fn emit_unary_op(m: &mut SpvModule, b: &BaseIds, op: UnaryOp, x: u32) -> u32 {
516    let result = m.alloc_id();
517    match op {
518        UnaryOp::Relu => {
519            m.emit_opencl_ext(
520                b.opencl_ext,
521                b.ty_float,
522                result,
523                OPENCL_FMAX,
524                &[b.c_float_0, x],
525            );
526        }
527        UnaryOp::Sigmoid => {
528            let neg_x = m.alloc_id();
529            m.emit(OP_F_NEGATE, &[b.ty_float, neg_x, x]);
530            let exp_neg_x = m.alloc_id();
531            m.emit_opencl_ext(b.opencl_ext, b.ty_float, exp_neg_x, OPENCL_EXP, &[neg_x]);
532            let one_plus = m.alloc_id();
533            m.emit(OP_F_ADD, &[b.ty_float, one_plus, b.c_float_1, exp_neg_x]);
534            m.emit(OP_F_DIV, &[b.ty_float, result, b.c_float_1, one_plus]);
535        }
536        UnaryOp::Tanh => {
537            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_TANH, &[x]);
538        }
539        UnaryOp::Exp => {
540            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_EXP, &[x]);
541        }
542        UnaryOp::Log => {
543            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_LOG, &[x]);
544        }
545        UnaryOp::Sqrt => {
546            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_SQRT, &[x]);
547        }
548        UnaryOp::Abs => {
549            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FABS, &[x]);
550        }
551        UnaryOp::Neg => {
552            m.emit(OP_F_NEGATE, &[b.ty_float, result, x]);
553        }
554    }
555    result
556}
557
558// ─── Binary compute kernel ──────────────────────────────────
559
/// Generate an OpenCL SPIR-V compute kernel for an element-wise binary operation.
///
/// Kernel parameters: `(CrossWorkgroup float* a, CrossWorkgroup float* b,
///                      CrossWorkgroup float* output, uint count)`.
pub fn binary_compute_shader(op: BinaryOp) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_a = m.alloc_id();
    let p_b = m.alloc_id();
    let p_out = m.alloc_id();
    let p_count = m.alloc_id();

    // Function type: void(float*, float*, float*, uint)
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_uint,
        ],
    );

    // Entry point and execution mode
    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_out);
    m.emit_function_parameter(b.ty_uint, p_count);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // Bounds check: gid < count — skip the body for excess threads.
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, p_count]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // a_val = a[gid]
    let a_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, gid);
    let a_val = m.alloc_id();
    m.emit_load(b.ty_float, a_val, a_ptr);

    // b_val = b[gid]
    let b_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, gid);
    let b_val = m.alloc_id();
    m.emit_load(b.ty_float, b_val, b_ptr);

    // output[gid] = op(a_val, b_val)
    let result = emit_binary_op(&mut m, &b, op, a_val, b_val);

    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_out, gid);
    m.emit_store(out_ptr, result);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
634
635fn emit_binary_op(m: &mut SpvModule, b: &BaseIds, op: BinaryOp, lhs: u32, rhs: u32) -> u32 {
636    let result = m.alloc_id();
637    match op {
638        BinaryOp::Add => m.emit(OP_F_ADD, &[b.ty_float, result, lhs, rhs]),
639        BinaryOp::Sub => m.emit(OP_F_SUB, &[b.ty_float, result, lhs, rhs]),
640        BinaryOp::Mul => m.emit(OP_F_MUL, &[b.ty_float, result, lhs, rhs]),
641        BinaryOp::Div => m.emit(OP_F_DIV, &[b.ty_float, result, lhs, rhs]),
642        BinaryOp::Max => {
643            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMAX, &[lhs, rhs]);
644        }
645        BinaryOp::Min => {
646            m.emit_opencl_ext(b.opencl_ext, b.ty_float, result, OPENCL_FMIN, &[lhs, rhs]);
647        }
648    }
649    result
650}
651
652// ─── Reduce compute kernel ──────────────────────────────────
653
654/// Generate an OpenCL SPIR-V compute kernel for reduction along an axis.
655///
656/// Kernel parameters: `(CrossWorkgroup float* input, CrossWorkgroup float* output,
657///                      uint outer_size, uint reduce_size, uint inner_size)`.
658///
659/// Each thread computes one output element by iterating over the reduce dimension.
660pub fn reduce_compute_shader(op: ReduceOp) -> Vec<u32> {
661    let mut m = SpvModule::new();
662    let b = emit_preamble(&mut m);
663
664    let main_fn = m.alloc_id();
665    let fn_ty = m.alloc_id();
666    let p_input = m.alloc_id();
667    let p_output = m.alloc_id();
668    let p_outer = m.alloc_id();
669    let p_reduce = m.alloc_id();
670    let p_inner = m.alloc_id();
671
672    // Function type: void(float*, float*, uint, uint, uint)
673    m.emit_type_function(
674        fn_ty,
675        b.ty_void,
676        &[
677            b.ty_ptr_cross_float,
678            b.ty_ptr_cross_float,
679            b.ty_uint,
680            b.ty_uint,
681            b.ty_uint,
682        ],
683    );
684
685    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
686    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
687
688    let label_entry = m.alloc_id();
689    let label_bounds_body = m.alloc_id();
690    let label_bounds_merge = m.alloc_id();
691    let label_loop_header = m.alloc_id();
692    let label_loop_body = m.alloc_id();
693    let label_loop_continue = m.alloc_id();
694    let label_loop_merge = m.alloc_id();
695
696    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
697    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
698    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
699    m.emit_function_parameter(b.ty_uint, p_outer);
700    m.emit_function_parameter(b.ty_uint, p_reduce);
701    m.emit_function_parameter(b.ty_uint, p_inner);
702    m.emit_label(label_entry);
703
704    let gid = load_gid_x(&mut m, &b);
705
706    // total_output = outer_size * inner_size
707    let total_output = m.alloc_id();
708    m.emit(OP_I_MUL, &[b.ty_uint, total_output, p_outer, p_inner]);
709
710    // Bounds check: gid < total_output
711    let cond_bounds = m.alloc_id();
712    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond_bounds, gid, total_output]);
713    m.emit_selection_merge(label_bounds_merge);
714    m.emit_branch_conditional(cond_bounds, label_bounds_body, label_bounds_merge);
715
716    m.emit_label(label_bounds_body);
717
718    // outer_idx = gid / inner_size, inner_idx = gid % inner_size
719    let outer_idx = m.alloc_id();
720    m.emit(OP_U_DIV, &[b.ty_uint, outer_idx, gid, p_inner]);
721    let inner_idx = m.alloc_id();
722    m.emit(OP_U_MOD, &[b.ty_uint, inner_idx, gid, p_inner]);
723
724    // base = outer_idx * reduce_size * inner_size + inner_idx
725    let t1 = m.alloc_id();
726    m.emit(OP_I_MUL, &[b.ty_uint, t1, outer_idx, p_reduce]);
727    let t2 = m.alloc_id();
728    m.emit(OP_I_MUL, &[b.ty_uint, t2, t1, p_inner]);
729    let base_idx = m.alloc_id();
730    m.emit(OP_I_ADD, &[b.ty_uint, base_idx, t2, inner_idx]);
731
732    // Loop counter
733    let var_i = m.alloc_id();
734    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
735    m.emit_store(var_i, b.c_uint_0);
736
737    // Accumulator
738    let var_acc = m.alloc_id();
739    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
740    let init_val = match op {
741        ReduceOp::Sum | ReduceOp::Mean => b.c_float_0,
742        ReduceOp::Max => {
743            let neg_inf = m.alloc_id();
744            m.emit_constant_f32(b.ty_float, neg_inf, f32::NEG_INFINITY);
745            neg_inf
746        }
747        ReduceOp::Min => {
748            let pos_inf = m.alloc_id();
749            m.emit_constant_f32(b.ty_float, pos_inf, f32::INFINITY);
750            pos_inf
751        }
752    };
753    m.emit_store(var_acc, init_val);
754
755    m.emit_branch(label_loop_header);
756
757    // ── Loop header ──
758    m.emit_label(label_loop_header);
759    let i_val = m.alloc_id();
760    m.emit_load(b.ty_uint, i_val, var_i);
761    let loop_cond = m.alloc_id();
762    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_reduce]);
763    m.emit_loop_merge(label_loop_merge, label_loop_continue);
764    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
765
766    // ── Loop body ──
767    m.emit_label(label_loop_body);
768
769    // input_idx = base_idx + i * inner_size
770    let i_times_inner = m.alloc_id();
771    m.emit(OP_I_MUL, &[b.ty_uint, i_times_inner, i_val, p_inner]);
772    let input_idx = m.alloc_id();
773    m.emit(OP_I_ADD, &[b.ty_uint, input_idx, base_idx, i_times_inner]);
774
775    let inp_ptr = m.alloc_id();
776    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, input_idx);
777    let inp_val = m.alloc_id();
778    m.emit_load(b.ty_float, inp_val, inp_ptr);
779
780    let acc_val = m.alloc_id();
781    m.emit_load(b.ty_float, acc_val, var_acc);
782
783    let new_acc = m.alloc_id();
784    match op {
785        ReduceOp::Sum | ReduceOp::Mean => {
786            m.emit(OP_F_ADD, &[b.ty_float, new_acc, acc_val, inp_val]);
787        }
788        ReduceOp::Max => {
789            m.emit_opencl_ext(
790                b.opencl_ext,
791                b.ty_float,
792                new_acc,
793                OPENCL_FMAX,
794                &[acc_val, inp_val],
795            );
796        }
797        ReduceOp::Min => {
798            m.emit_opencl_ext(
799                b.opencl_ext,
800                b.ty_float,
801                new_acc,
802                OPENCL_FMIN,
803                &[acc_val, inp_val],
804            );
805        }
806    }
807    m.emit_store(var_acc, new_acc);
808
809    m.emit_branch(label_loop_continue);
810
811    // ── Loop continue ──
812    m.emit_label(label_loop_continue);
813    let i_inc = m.alloc_id();
814    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
815    m.emit_store(var_i, i_inc);
816    m.emit_branch(label_loop_header);
817
818    // ── Loop merge ──
819    m.emit_label(label_loop_merge);
820
821    let final_acc = m.alloc_id();
822    m.emit_load(b.ty_float, final_acc, var_acc);
823
824    let store_val = if op == ReduceOp::Mean {
825        let reduce_f = m.alloc_id();
826        m.emit(OP_CONVERT_U_TO_F, &[b.ty_float, reduce_f, p_reduce]);
827        let mean_val = m.alloc_id();
828        m.emit(OP_F_DIV, &[b.ty_float, mean_val, final_acc, reduce_f]);
829        mean_val
830    } else {
831        final_acc
832    };
833
834    let out_ptr = m.alloc_id();
835    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
836    m.emit_store(out_ptr, store_val);
837
838    m.emit_branch(label_bounds_merge);
839
840    m.emit_label(label_bounds_merge);
841    m.emit_return();
842    m.emit_function_end();
843
844    m.finalize()
845}
846
847// ─── GEMM compute kernel ────────────────────────────────────
848
849/// Generate an OpenCL SPIR-V compute kernel for GEMM: `C = alpha * A * B + beta * C`.
850///
851/// Naive reference (one thread per output element, row-major f32 layout).
852///
853/// Kernel parameters: `(CrossWorkgroup float* A, CrossWorkgroup float* B,
854///                      CrossWorkgroup float* C, uint m, uint n, uint k,
855///                      float alpha, float beta)`.
856pub fn gemm_compute_shader() -> Vec<u32> {
857    let mut m = SpvModule::new();
858    let b = emit_preamble(&mut m);
859
860    let main_fn = m.alloc_id();
861    let fn_ty = m.alloc_id();
862    let p_a = m.alloc_id();
863    let p_b = m.alloc_id();
864    let p_c = m.alloc_id();
865    let p_m = m.alloc_id();
866    let p_n = m.alloc_id();
867    let p_k = m.alloc_id();
868    let p_alpha = m.alloc_id();
869    let p_beta = m.alloc_id();
870
871    // Function type: void(float*, float*, float*, uint, uint, uint, float, float)
872    m.emit_type_function(
873        fn_ty,
874        b.ty_void,
875        &[
876            b.ty_ptr_cross_float,
877            b.ty_ptr_cross_float,
878            b.ty_ptr_cross_float,
879            b.ty_uint,
880            b.ty_uint,
881            b.ty_uint,
882            b.ty_float,
883            b.ty_float,
884        ],
885    );
886
887    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
888    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
889
890    let label_entry = m.alloc_id();
891    let label_bounds_body = m.alloc_id();
892    let label_bounds_merge = m.alloc_id();
893    let label_loop_header = m.alloc_id();
894    let label_loop_body = m.alloc_id();
895    let label_loop_continue = m.alloc_id();
896    let label_loop_merge = m.alloc_id();
897
898    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
899    m.emit_function_parameter(b.ty_ptr_cross_float, p_a);
900    m.emit_function_parameter(b.ty_ptr_cross_float, p_b);
901    m.emit_function_parameter(b.ty_ptr_cross_float, p_c);
902    m.emit_function_parameter(b.ty_uint, p_m);
903    m.emit_function_parameter(b.ty_uint, p_n);
904    m.emit_function_parameter(b.ty_uint, p_k);
905    m.emit_function_parameter(b.ty_float, p_alpha);
906    m.emit_function_parameter(b.ty_float, p_beta);
907    m.emit_label(label_entry);
908
909    let gid = load_gid_x(&mut m, &b);
910
911    // total = m * n
912    let total = m.alloc_id();
913    m.emit(OP_I_MUL, &[b.ty_uint, total, p_m, p_n]);
914
915    // Bounds check
916    let cond = m.alloc_id();
917    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
918    m.emit_selection_merge(label_bounds_merge);
919    m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
920
921    m.emit_label(label_bounds_body);
922
923    // row = gid / n, col = gid % n
924    let row = m.alloc_id();
925    m.emit(OP_U_DIV, &[b.ty_uint, row, gid, p_n]);
926    let col = m.alloc_id();
927    m.emit(OP_U_MOD, &[b.ty_uint, col, gid, p_n]);
928
929    // Loop counter + accumulator
930    let var_i = m.alloc_id();
931    m.emit_variable(b.ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
932    m.emit_store(var_i, b.c_uint_0);
933    let var_acc = m.alloc_id();
934    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
935    m.emit_store(var_acc, b.c_float_0);
936
937    m.emit_branch(label_loop_header);
938
939    // ── Loop header ──
940    m.emit_label(label_loop_header);
941    let i_val = m.alloc_id();
942    m.emit_load(b.ty_uint, i_val, var_i);
943    let loop_cond = m.alloc_id();
944    m.emit(OP_U_LESS_THAN, &[b.ty_bool, loop_cond, i_val, p_k]);
945    m.emit_loop_merge(label_loop_merge, label_loop_continue);
946    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
947
948    // ── Loop body ──
949    m.emit_label(label_loop_body);
950
951    // a_idx = row * k + i
952    let row_k = m.alloc_id();
953    m.emit(OP_I_MUL, &[b.ty_uint, row_k, row, p_k]);
954    let a_idx = m.alloc_id();
955    m.emit(OP_I_ADD, &[b.ty_uint, a_idx, row_k, i_val]);
956
957    // b_idx = i * n + col
958    let i_n = m.alloc_id();
959    m.emit(OP_I_MUL, &[b.ty_uint, i_n, i_val, p_n]);
960    let b_idx = m.alloc_id();
961    m.emit(OP_I_ADD, &[b.ty_uint, b_idx, i_n, col]);
962
963    let a_ptr = m.alloc_id();
964    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, a_ptr, p_a, a_idx);
965    let a_val = m.alloc_id();
966    m.emit_load(b.ty_float, a_val, a_ptr);
967
968    let b_ptr = m.alloc_id();
969    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, b_ptr, p_b, b_idx);
970    let b_val = m.alloc_id();
971    m.emit_load(b.ty_float, b_val, b_ptr);
972
973    let prod = m.alloc_id();
974    m.emit(OP_F_MUL, &[b.ty_float, prod, a_val, b_val]);
975    let old_acc = m.alloc_id();
976    m.emit_load(b.ty_float, old_acc, var_acc);
977    let new_acc = m.alloc_id();
978    m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
979    m.emit_store(var_acc, new_acc);
980
981    m.emit_branch(label_loop_continue);
982
983    // ── Loop continue ──
984    m.emit_label(label_loop_continue);
985    let i_inc = m.alloc_id();
986    m.emit(OP_I_ADD, &[b.ty_uint, i_inc, i_val, b.c_uint_1]);
987    m.emit_store(var_i, i_inc);
988    m.emit_branch(label_loop_header);
989
990    // ── Loop merge ──
991    m.emit_label(label_loop_merge);
992
993    // result = alpha * acc + beta * C[gid]
994    let final_acc = m.alloc_id();
995    m.emit_load(b.ty_float, final_acc, var_acc);
996    let alpha_acc = m.alloc_id();
997    m.emit(OP_F_MUL, &[b.ty_float, alpha_acc, p_alpha, final_acc]);
998
999    let c_ptr = m.alloc_id();
1000    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, c_ptr, p_c, gid);
1001    let c_old = m.alloc_id();
1002    m.emit_load(b.ty_float, c_old, c_ptr);
1003    let beta_c = m.alloc_id();
1004    m.emit(OP_F_MUL, &[b.ty_float, beta_c, p_beta, c_old]);
1005    let c_new = m.alloc_id();
1006    m.emit(OP_F_ADD, &[b.ty_float, c_new, alpha_acc, beta_c]);
1007    m.emit_store(c_ptr, c_new);
1008
1009    m.emit_branch(label_bounds_merge);
1010
1011    m.emit_label(label_bounds_merge);
1012    m.emit_return();
1013    m.emit_function_end();
1014
1015    m.finalize()
1016}
1017
1018// ─── Trivial placeholder ────────────────────────────────────
1019
1020/// Build a minimal valid Shader-style compute shader: `void main() {}`.
1021///
1022/// Uses `GLCompute` / Shader capability for basic Level Zero module validation.
1023pub fn trivial_compute_shader() -> Vec<u32> {
1024    let mut m = SpvModule::new();
1025
1026    let id_main_fn = m.alloc_id();
1027    let id_void = m.alloc_id();
1028    let id_void_fn = m.alloc_id();
1029    let id_label = m.alloc_id();
1030
1031    m.emit_capability(CAPABILITY_SHADER);
1032    m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
1033
1034    let mut entry_words = vec![EXECUTION_MODEL_GLCOMPUTE, id_main_fn];
1035    entry_words.extend(SpvModule::string_words("main"));
1036    m.emit(OP_ENTRY_POINT, &entry_words);
1037
1038    m.emit_execution_mode_local_size(id_main_fn, 1, 1, 1);
1039
1040    m.emit_type_void(id_void);
1041    m.emit_type_function(id_void_fn, id_void, &[]);
1042
1043    m.emit_function(id_void, id_main_fn, FUNCTION_CONTROL_NONE, id_void_fn);
1044    m.emit_label(id_label);
1045    m.emit_return();
1046    m.emit_function_end();
1047
1048    m.finalize()
1049}
1050
1051/// Return the trivial compute shader as a byte slice suitable for
1052/// passing to Level Zero module creation.
1053pub fn trivial_compute_shader_bytes() -> Vec<u8> {
1054    trivial_compute_shader()
1055        .iter()
1056        .flat_map(|w| w.to_ne_bytes())
1057        .collect()
1058}
1059
1060// ─── Tests ──────────────────────────────────────────────────
1061
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert the minimal structural invariants of a SPIR-V binary:
    /// header length, magic number, non-zero ID bound, zero schema.
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        assert_eq!(words[0], SPIRV_MAGIC, "bad magic");
        assert!(words[3] > 0, "ID bound must be > 0");
        assert_eq!(words[4], 0, "schema must be 0");
    }

    #[test]
    fn placeholder_spv_valid_magic() {
        check_valid_spirv(&trivial_compute_shader());
    }

    #[test]
    fn placeholder_spv_word_aligned() {
        // SPIR-V binaries are streams of 32-bit words; bytes must be a multiple of 4.
        assert_eq!(trivial_compute_shader_bytes().len() % 4, 0);
    }

    #[test]
    fn placeholder_spv_version_and_schema() {
        let header = trivial_compute_shader();
        assert!(header[1] >= 0x0001_0000);
        assert_eq!(header[4], 0);
    }

    #[test]
    fn placeholder_spv_nonzero_bound() {
        assert!(trivial_compute_shader()[3] > 0);
    }

    #[test]
    fn spv_module_id_allocation_is_monotonic() {
        let mut module = SpvModule::new();
        let first = module.alloc_id();
        let second = module.alloc_id();
        assert!(second > first);
    }

    #[test]
    fn string_words_null_terminated() {
        let encoded = SpvModule::string_words("abc");
        assert!(!encoded.is_empty());
        // SPIR-V strings are little-endian packed UTF-8 with a NUL terminator.
        let raw: Vec<u8> = encoded.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(raw[0], b'a');
        assert_eq!(raw[1], b'b');
        assert_eq!(raw[2], b'c');
        assert_eq!(raw[3], 0);
    }

    #[test]
    fn string_words_empty_string() {
        let encoded = SpvModule::string_words("");
        // Even "" occupies one word: the NUL terminator plus padding.
        assert!(!encoded.is_empty());
        let raw: Vec<u8> = encoded.iter().flat_map(|w| w.to_le_bytes()).collect();
        assert_eq!(raw[0], 0);
    }

    #[test]
    fn generator_magic_is_level_zero() {
        // 0x000D_0002 is this backend; 0x000D_0001 belongs to another generator.
        assert_eq!(SPIRV_GENERATOR, 0x000D_0002);
        assert_ne!(SPIRV_GENERATOR, 0x000D_0001);
    }

    // ── Compute kernel generation ────────────────────────────

    #[test]
    fn unary_shader_all_ops() {
        // Every unary op variant must produce a structurally valid binary.
        for op in [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ] {
            check_valid_spirv(&unary_compute_shader(op));
        }
    }

    #[test]
    fn binary_shader_all_ops() {
        for op in [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ] {
            check_valid_spirv(&binary_compute_shader(op));
        }
    }

    #[test]
    fn reduce_shader_all_ops() {
        for op in [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean] {
            check_valid_spirv(&reduce_compute_shader(op));
        }
    }

    #[test]
    fn gemm_shader_valid() {
        check_valid_spirv(&gemm_compute_shader());
    }

    #[test]
    fn all_kernel_shaders_word_aligned() {
        // Byte length after serialization must stay 4-aligned for every generator.
        fn byte_len(words: &[u32]) -> usize {
            words.iter().flat_map(|w| w.to_ne_bytes()).count()
        }
        assert_eq!(byte_len(&unary_compute_shader(UnaryOp::Relu)) % 4, 0);
        assert_eq!(byte_len(&binary_compute_shader(BinaryOp::Add)) % 4, 0);
        assert_eq!(byte_len(&reduce_compute_shader(ReduceOp::Sum)) % 4, 0);
        assert_eq!(byte_len(&gemm_compute_shader()) % 4, 0);
    }

    #[test]
    fn kernel_shaders_use_opencl_memory_model() {
        // Kernel shaders declare the Kernel capability (OpenCL flavor),
        // while the trivial placeholder declares Shader (Vulkan flavor).
        // The first instruction after the 5-word header is OpCapability,
        // whose instruction word is (word_count << 16) | opcode.
        let trivial = trivial_compute_shader();
        let unary = unary_compute_shader(UnaryOp::Relu);

        let cap_header = (2u32 << 16) | OP_CAPABILITY;
        assert_eq!(trivial[5], cap_header);
        assert_eq!(trivial[6], CAPABILITY_SHADER);
        assert_eq!(unary[5], cap_header);
        assert_eq!(unary[6], CAPABILITY_KERNEL);
    }
}
1212}