// oxicuda_levelzero/spirv_subgroup.rs
//! Sub-group optimized SPIR-V kernel generators for Intel GPUs.
//!
//! This module provides SPIR-V generators that leverage Intel GPU sub-group
//! operations (analogous to CUDA warps) for efficient intra-sub-group
//! communication:
//!
//! - [`reduction_subgroup_spirv`] — Two-phase sub-group reduction
//! - [`scan_subgroup_spirv`] — Inclusive prefix sum via sub-group scan
//! - [`gemm_subgroup_spirv`] — GEMM with sub-group shuffle for A-row broadcast
//!
//! All kernels use the OpenCL SPIR-V execution model (`Kernel`) with
//! `Physical64`/`OpenCL` memory model and require `GroupNonUniform` family
//! capabilities.

15use crate::spirv::{
16    EXECUTION_MODEL_KERNEL, FUNCTION_CONTROL_NONE, OP_COMPOSITE_EXTRACT, OP_CONTROL_BARRIER,
17    OP_F_ADD, OP_F_MUL, OP_GROUP_NON_UNIFORM_FADD, OP_GROUP_NON_UNIFORM_SHUFFLE, OP_I_ADD,
18    OP_I_MUL, OP_PHI, OP_TYPE_ARRAY, OP_U_LESS_THAN, STORAGE_CLASS_FUNCTION, SpvModule,
19    WORKGROUP_SIZE,
20};
21
// ── SPIR-V constants (sub-group specific) ────────────────────
//
// Literal enumerant values from the SPIR-V specification, used by the
// generators below when assembling instructions word-by-word.

// Capabilities (operands of OpCapability)
const CAPABILITY_ADDRESSES: u32 = 4;
const CAPABILITY_KERNEL: u32 = 6;
const CAPABILITY_GROUP_NON_UNIFORM: u32 = 61;
const CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC: u32 = 63;
const CAPABILITY_GROUP_NON_UNIFORM_SHUFFLE: u32 = 65;

// Addressing / memory model (operands of OpMemoryModel)
const ADDRESSING_MODEL_PHYSICAL64: u32 = 2;
const MEMORY_MODEL_OPENCL: u32 = 2;

// Decorations (operand of OpDecorate)
const DECORATION_BUILTIN: u32 = 11;

// BuiltIn values (extra operand of the BuiltIn decoration)
const BUILTIN_GLOBAL_INVOCATION_ID: u32 = 28;
const BUILTIN_NUM_SUBGROUPS: u32 = 38;
const BUILTIN_SUBGROUP_ID: u32 = 40;
const BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID: u32 = 41;

// Storage classes
const STORAGE_CLASS_INPUT: u32 = 1;
const STORAGE_CLASS_WORKGROUP: u32 = 4; // shared local memory
const STORAGE_CLASS_CROSS_WORKGROUP: u32 = 5; // global (OpenCL) memory

// Scope (values wrapped in OpConstant ids for barrier / group ops)
const SCOPE_WORKGROUP: u32 = 2;
const SCOPE_SUBGROUP: u32 = 3;

// Memory semantics (WorkgroupMemory bit)
const MEMORY_SEMANTICS_WORKGROUP_MEMORY: u32 = 0x100;

// GroupOperation (literal operand of OpGroupNonUniform* arithmetic ops)
const GROUP_OPERATION_REDUCE: u32 = 0;
const GROUP_OPERATION_INCLUSIVE_SCAN: u32 = 1;

// Magic opcode numbers used inline
const OP_I_EQUAL: u32 = 170; // OpIEqual
const OP_SELECT: u32 = 169; // OpSelect

/// Maximum sub-groups per workgroup (used for shared-memory scratch array size).
/// NOTE(review): assumes the device never splits a `WORKGROUP_SIZE` group into
/// more than 32 sub-groups (i.e. sub-group size >= WORKGROUP_SIZE / 32) —
/// TODO confirm against the target device's minimum sub-group size.
const MAX_SUBGROUPS: u32 = 32;
66
67// ─── Sub-group optimized reduction kernel ────────────────────
68
/// Generate an OpenCL SPIR-V compute kernel for sub-group optimized reduction.
///
/// Two-phase algorithm:
/// 1. Each lane reduces its element via `OpGroupNonUniformFAdd` with `Reduce`.
/// 2. Sub-group leaders write partial sums to workgroup shared memory, barrier,
///    then sub-group 0 reduces the partial sums and writes the final result.
///
/// Kernel parameters: `(CrossWorkgroup float* input, CrossWorkgroup float* output, uint count)`.
///
/// Entry point name: `"reduction_subgroup"`.
///
/// Returns the module words produced by `SpvModule::finalize`.
pub fn reduction_subgroup_spirv() -> Vec<u32> {
    let mut m = SpvModule::new();

    // ── Capabilities ──
    m.emit_capability(CAPABILITY_KERNEL);
    m.emit_capability(CAPABILITY_ADDRESSES);
    m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM);
    m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC);

    // ── Ext import ──
    let opencl_ext = m.alloc_id();
    m.emit_ext_inst_import(opencl_ext, "OpenCL.std");

    // ── Memory model ──
    m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);

    // ── Type IDs ──
    let ty_void = m.alloc_id();
    let ty_bool = m.alloc_id();
    let ty_uint = m.alloc_id();
    let ty_float = m.alloc_id();
    let ty_v3uint = m.alloc_id();
    let ty_ptr_input_v3uint = m.alloc_id();
    let ty_ptr_cross_float = m.alloc_id();
    // NOTE(review): ty_ptr_func_float is declared and emitted below but never
    // referenced by any instruction in this kernel — dead type, harmless.
    let ty_ptr_func_float = m.alloc_id();
    let ty_ptr_wg_float = m.alloc_id();
    let ty_ptr_input_uint = m.alloc_id();
    let ty_arr_float = m.alloc_id();
    let ty_ptr_wg_arr = m.alloc_id();

    // ── Constants ──
    let c_uint_0 = m.alloc_id();
    let c_uint_max_sg = m.alloc_id();
    let c_float_0 = m.alloc_id();
    let c_scope_sg = m.alloc_id();
    let c_scope_wg = m.alloc_id();
    let c_mem_sem = m.alloc_id();

    // ── Variables (module-scope Input builtins + Workgroup scratch) ──
    let var_gid = m.alloc_id();
    let var_sg_id = m.alloc_id();
    let var_sg_lid = m.alloc_id();
    let var_num_sg = m.alloc_id();
    let var_scratch = m.alloc_id();

    // ── Function ──
    let fn_ty = m.alloc_id();
    let main_fn = m.alloc_id();
    let p_input = m.alloc_id();
    let p_output = m.alloc_id();
    let p_count = m.alloc_id();

    // ── Entry point & execution mode ──
    // All Input builtins used by the kernel must be listed on the entry point.
    m.emit_entry_point(
        EXECUTION_MODEL_KERNEL,
        main_fn,
        "reduction_subgroup",
        &[var_gid, var_sg_id, var_sg_lid, var_num_sg],
    );
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    // ── Decorations ──
    m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
    m.emit_decorate(var_sg_id, DECORATION_BUILTIN, &[BUILTIN_SUBGROUP_ID]);
    m.emit_decorate(
        var_sg_lid,
        DECORATION_BUILTIN,
        &[BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID],
    );
    m.emit_decorate(var_num_sg, DECORATION_BUILTIN, &[BUILTIN_NUM_SUBGROUPS]);

    // ── Types ──
    m.emit_type_void(ty_void);
    m.emit_type_bool(ty_bool);
    m.emit_type_int(ty_uint, 32, 0);
    m.emit_type_float(ty_float, 32);
    m.emit_type_vector(ty_v3uint, ty_uint, 3);
    m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
    m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
    m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
    m.emit_type_pointer(ty_ptr_wg_float, STORAGE_CLASS_WORKGROUP, ty_float);
    m.emit_type_pointer(ty_ptr_input_uint, STORAGE_CLASS_INPUT, ty_uint);
    m.emit_type_function(
        fn_ty,
        ty_void,
        &[ty_ptr_cross_float, ty_ptr_cross_float, ty_uint],
    );

    // ── Constants ──
    m.emit_constant_u32(ty_uint, c_uint_0, 0);
    m.emit_constant_u32(ty_uint, c_uint_max_sg, MAX_SUBGROUPS);
    m.emit_constant_f32(ty_float, c_float_0, 0.0);
    m.emit_constant_u32(ty_uint, c_scope_sg, SCOPE_SUBGROUP);
    m.emit_constant_u32(ty_uint, c_scope_wg, SCOPE_WORKGROUP);
    m.emit_constant_u32(ty_uint, c_mem_sem, MEMORY_SEMANTICS_WORKGROUP_MEMORY);

    // ── Array type for shared scratch (float[MAX_SUBGROUPS]) ──
    // Array length operand must be a constant id, hence emitted after c_uint_max_sg.
    m.emit(OP_TYPE_ARRAY, &[ty_arr_float, ty_float, c_uint_max_sg]);
    m.emit_type_pointer(ty_ptr_wg_arr, STORAGE_CLASS_WORKGROUP, ty_arr_float);

    // ── Variables ──
    m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
    m.emit_variable(ty_ptr_input_uint, var_sg_id, STORAGE_CLASS_INPUT);
    m.emit_variable(ty_ptr_input_uint, var_sg_lid, STORAGE_CLASS_INPUT);
    m.emit_variable(ty_ptr_input_uint, var_num_sg, STORAGE_CLASS_INPUT);
    m.emit_variable(ty_ptr_wg_arr, var_scratch, STORAGE_CLASS_WORKGROUP);

    // ── Labels ──
    let label_entry = m.alloc_id();
    let label_bounds_body = m.alloc_id();
    let label_bounds_merge = m.alloc_id();
    let label_leader = m.alloc_id();
    let label_after_leader = m.alloc_id();
    let label_sg0 = m.alloc_id();
    let label_after_sg0 = m.alloc_id();

    // ── Function body ──
    m.emit_function(ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(ty_ptr_cross_float, p_input);
    m.emit_function_parameter(ty_ptr_cross_float, p_output);
    m.emit_function_parameter(ty_uint, p_count);
    m.emit_label(label_entry);

    // Load global ID
    let gid_val = m.alloc_id();
    m.emit_load(ty_v3uint, gid_val, var_gid);
    let gid_x = m.alloc_id();
    m.emit(OP_COMPOSITE_EXTRACT, &[ty_uint, gid_x, gid_val, 0]);

    // Load sub-group builtins
    let sg_id = m.alloc_id();
    m.emit_load(ty_uint, sg_id, var_sg_id);
    let sg_lid = m.alloc_id();
    m.emit_load(ty_uint, sg_lid, var_sg_lid);
    let num_sg = m.alloc_id();
    m.emit_load(ty_uint, num_sg, var_num_sg);

    // Bounds check: gid_x < count
    let cond_bounds = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[ty_bool, cond_bounds, gid_x, p_count]);

    // Load value if in-bounds, else 0.0 (0 is the additive identity, so
    // out-of-range lanes do not perturb the sum)
    m.emit_selection_merge(label_bounds_merge);
    m.emit_branch_conditional(cond_bounds, label_bounds_body, label_bounds_merge);

    m.emit_label(label_bounds_body);
    let inp_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, inp_ptr, p_input, gid_x);
    let inp_val = m.alloc_id();
    m.emit_load(ty_float, inp_val, inp_ptr);
    m.emit_branch(label_bounds_merge);

    m.emit_label(label_bounds_merge);
    // Phi: val = inp_val if from bounds_body, else c_float_0
    // (OpPhi operands are (value, predecessor-label) pairs)
    let val = m.alloc_id();
    m.emit(
        OP_PHI,
        &[
            ty_float,
            val,
            inp_val,
            label_bounds_body,
            c_float_0,
            label_entry,
        ],
    );

    // Phase 1: sub-group reduce (every lane receives the sub-group sum)
    let sg_sum = m.alloc_id();
    m.emit(
        OP_GROUP_NON_UNIFORM_FADD,
        &[ty_float, sg_sum, c_scope_sg, GROUP_OPERATION_REDUCE, val],
    );

    // Sub-group leader (sg_lid == 0) writes to shared scratch[sg_id]
    let is_leader_eq = m.alloc_id();
    m.emit(OP_I_EQUAL, &[ty_bool, is_leader_eq, sg_lid, c_uint_0]);

    m.emit_selection_merge(label_after_leader);
    m.emit_branch_conditional(is_leader_eq, label_leader, label_after_leader);

    m.emit_label(label_leader);
    let scratch_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, scratch_ptr, var_scratch, sg_id);
    m.emit_store(scratch_ptr, sg_sum);
    m.emit_branch(label_after_leader);

    m.emit_label(label_after_leader);

    // Workgroup barrier: make every leader's scratch write visible before phase 2
    m.emit(OP_CONTROL_BARRIER, &[c_scope_wg, c_scope_wg, c_mem_sem]);

    // Phase 2: sub-group 0 reduces across sub-groups
    let is_sg0 = m.alloc_id();
    m.emit(OP_I_EQUAL, &[ty_bool, is_sg0, sg_id, c_uint_0]);

    // Also need sg_lid < num_sg for valid lanes
    let lid_lt_nsg = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[ty_bool, lid_lt_nsg, sg_lid, num_sg]);

    m.emit_selection_merge(label_after_sg0);
    m.emit_branch_conditional(is_sg0, label_sg0, label_after_sg0);

    m.emit_label(label_sg0);

    // Load scratch[sg_lid] -- each lane in SG0 loads one partial sum.
    // NOTE(review): lanes with sg_lid >= num_sg read an uninitialized scratch
    // slot here; the value is discarded by the OpSelect mask below, so only
    // the load itself touches undefined (but in-array) shared memory.
    let s_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, s_ptr, var_scratch, sg_lid);
    let partial = m.alloc_id();
    m.emit_load(ty_float, partial, s_ptr);

    // Mask partial to 0 if sg_lid >= num_sg (out of range sub-groups)
    let safe_partial = m.alloc_id();
    m.emit(
        OP_SELECT,
        &[ty_float, safe_partial, lid_lt_nsg, partial, c_float_0],
    );

    // Sub-group reduce on partial sums
    let final_sum = m.alloc_id();
    m.emit(
        OP_GROUP_NON_UNIFORM_FADD,
        &[
            ty_float,
            final_sum,
            c_scope_sg,
            GROUP_OPERATION_REDUCE,
            safe_partial,
        ],
    );

    // Lane 0 of sub-group 0 writes final result to output[0]
    let is_lane0 = m.alloc_id();
    m.emit(OP_I_EQUAL, &[ty_bool, is_lane0, sg_lid, c_uint_0]);
    let label_write = m.alloc_id();
    let label_after_write = m.alloc_id();
    m.emit_selection_merge(label_after_write);
    m.emit_branch_conditional(is_lane0, label_write, label_after_write);

    m.emit_label(label_write);
    let out_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, out_ptr, p_output, c_uint_0);
    m.emit_store(out_ptr, final_sum);
    m.emit_branch(label_after_write);

    m.emit_label(label_after_write);
    m.emit_branch(label_after_sg0);

    m.emit_label(label_after_sg0);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
333
334// ─── Sub-group optimized scan kernel ─────────────────────────
335
336/// Generate an OpenCL SPIR-V compute kernel for sub-group scan (prefix sum).
337///
338/// Uses `OpGroupNonUniformFAdd` with `InclusiveScan` for the intra-sub-group
339/// phase. Output contains the inclusive prefix sum of the input.
340///
341/// Kernel parameters: `(CrossWorkgroup float* input, CrossWorkgroup float* output, uint count)`.
342///
343/// Entry point name: `"scan_subgroup"`.
344pub fn scan_subgroup_spirv() -> Vec<u32> {
345    let mut m = SpvModule::new();
346
347    // ── Capabilities ──
348    m.emit_capability(CAPABILITY_KERNEL);
349    m.emit_capability(CAPABILITY_ADDRESSES);
350    m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM);
351    m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC);
352
353    // ── Ext import ──
354    let opencl_ext = m.alloc_id();
355    m.emit_ext_inst_import(opencl_ext, "OpenCL.std");
356
357    // ── Memory model ──
358    m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);
359
360    // ── Type IDs ──
361    let ty_void = m.alloc_id();
362    let ty_bool = m.alloc_id();
363    let ty_uint = m.alloc_id();
364    let ty_float = m.alloc_id();
365    let ty_v3uint = m.alloc_id();
366    let ty_ptr_input_v3uint = m.alloc_id();
367    let ty_ptr_cross_float = m.alloc_id();
368    let ty_ptr_input_uint = m.alloc_id();
369    let ty_arr_float = m.alloc_id();
370    let ty_ptr_wg_float = m.alloc_id();
371    let ty_ptr_wg_arr = m.alloc_id();
372
373    // ── Constants ──
374    let c_uint_0 = m.alloc_id();
375    let c_uint_1 = m.alloc_id();
376    let c_uint_max_sg = m.alloc_id();
377    let c_float_0 = m.alloc_id();
378    let c_scope_sg = m.alloc_id();
379    let c_scope_wg = m.alloc_id();
380    let c_mem_sem = m.alloc_id();
381
382    // ── Variables ──
383    let var_gid = m.alloc_id();
384    let var_sg_id = m.alloc_id();
385    let var_sg_lid = m.alloc_id();
386    let var_num_sg = m.alloc_id();
387    let var_scratch = m.alloc_id();
388
389    // ── Function ──
390    let fn_ty = m.alloc_id();
391    let main_fn = m.alloc_id();
392    let p_input = m.alloc_id();
393    let p_output = m.alloc_id();
394    let p_count = m.alloc_id();
395
396    // ── Entry point ──
397    m.emit_entry_point(
398        EXECUTION_MODEL_KERNEL,
399        main_fn,
400        "scan_subgroup",
401        &[var_gid, var_sg_id, var_sg_lid, var_num_sg],
402    );
403    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
404
405    // ── Decorations ──
406    m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
407    m.emit_decorate(var_sg_id, DECORATION_BUILTIN, &[BUILTIN_SUBGROUP_ID]);
408    m.emit_decorate(
409        var_sg_lid,
410        DECORATION_BUILTIN,
411        &[BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID],
412    );
413    m.emit_decorate(var_num_sg, DECORATION_BUILTIN, &[BUILTIN_NUM_SUBGROUPS]);
414
415    // ── Types ──
416    m.emit_type_void(ty_void);
417    m.emit_type_bool(ty_bool);
418    m.emit_type_int(ty_uint, 32, 0);
419    m.emit_type_float(ty_float, 32);
420    m.emit_type_vector(ty_v3uint, ty_uint, 3);
421    m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
422    m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
423    m.emit_type_pointer(ty_ptr_input_uint, STORAGE_CLASS_INPUT, ty_uint);
424    m.emit_type_pointer(ty_ptr_wg_float, STORAGE_CLASS_WORKGROUP, ty_float);
425    m.emit_type_function(
426        fn_ty,
427        ty_void,
428        &[ty_ptr_cross_float, ty_ptr_cross_float, ty_uint],
429    );
430
431    // ── Constants ──
432    m.emit_constant_u32(ty_uint, c_uint_0, 0);
433    m.emit_constant_u32(ty_uint, c_uint_1, 1);
434    m.emit_constant_u32(ty_uint, c_uint_max_sg, MAX_SUBGROUPS);
435    m.emit_constant_f32(ty_float, c_float_0, 0.0);
436    m.emit_constant_u32(ty_uint, c_scope_sg, SCOPE_SUBGROUP);
437    m.emit_constant_u32(ty_uint, c_scope_wg, SCOPE_WORKGROUP);
438    m.emit_constant_u32(ty_uint, c_mem_sem, MEMORY_SEMANTICS_WORKGROUP_MEMORY);
439
440    // ── Shared scratch array ──
441    m.emit(OP_TYPE_ARRAY, &[ty_arr_float, ty_float, c_uint_max_sg]);
442    m.emit_type_pointer(ty_ptr_wg_arr, STORAGE_CLASS_WORKGROUP, ty_arr_float);
443
444    // ── Variables ──
445    m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
446    m.emit_variable(ty_ptr_input_uint, var_sg_id, STORAGE_CLASS_INPUT);
447    m.emit_variable(ty_ptr_input_uint, var_sg_lid, STORAGE_CLASS_INPUT);
448    m.emit_variable(ty_ptr_input_uint, var_num_sg, STORAGE_CLASS_INPUT);
449    m.emit_variable(ty_ptr_wg_arr, var_scratch, STORAGE_CLASS_WORKGROUP);
450
451    // ── Labels ──
452    let label_entry = m.alloc_id();
453    let label_bounds_body = m.alloc_id();
454    let label_bounds_merge = m.alloc_id();
455    let label_leader = m.alloc_id();
456    let label_after_leader = m.alloc_id();
457    let label_add_prefix = m.alloc_id();
458    let label_after_add = m.alloc_id();
459    let label_write = m.alloc_id();
460    let label_end = m.alloc_id();
461
462    // ── Function body ──
463    m.emit_function(ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
464    m.emit_function_parameter(ty_ptr_cross_float, p_input);
465    m.emit_function_parameter(ty_ptr_cross_float, p_output);
466    m.emit_function_parameter(ty_uint, p_count);
467    m.emit_label(label_entry);
468
469    // Load global ID
470    let gid_val = m.alloc_id();
471    m.emit_load(ty_v3uint, gid_val, var_gid);
472    let gid_x = m.alloc_id();
473    m.emit(OP_COMPOSITE_EXTRACT, &[ty_uint, gid_x, gid_val, 0]);
474
475    // Load sub-group builtins
476    let sg_id = m.alloc_id();
477    m.emit_load(ty_uint, sg_id, var_sg_id);
478    let sg_lid = m.alloc_id();
479    m.emit_load(ty_uint, sg_lid, var_sg_lid);
480
481    // Bounds check
482    let in_bounds = m.alloc_id();
483    m.emit(OP_U_LESS_THAN, &[ty_bool, in_bounds, gid_x, p_count]);
484    m.emit_selection_merge(label_bounds_merge);
485    m.emit_branch_conditional(in_bounds, label_bounds_body, label_bounds_merge);
486
487    m.emit_label(label_bounds_body);
488    let inp_ptr = m.alloc_id();
489    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, inp_ptr, p_input, gid_x);
490    let inp_val = m.alloc_id();
491    m.emit_load(ty_float, inp_val, inp_ptr);
492    m.emit_branch(label_bounds_merge);
493
494    m.emit_label(label_bounds_merge);
495    let val = m.alloc_id();
496    m.emit(
497        OP_PHI,
498        &[
499            ty_float,
500            val,
501            inp_val,
502            label_bounds_body,
503            c_float_0,
504            label_entry,
505        ],
506    );
507
508    // Phase 1: intra-sub-group inclusive scan
509    let sg_scan = m.alloc_id();
510    m.emit(
511        OP_GROUP_NON_UNIFORM_FADD,
512        &[
513            ty_float,
514            sg_scan,
515            c_scope_sg,
516            GROUP_OPERATION_INCLUSIVE_SCAN,
517            val,
518        ],
519    );
520
521    // Use sub-group reduce to get total per sub-group
522    let sg_total = m.alloc_id();
523    m.emit(
524        OP_GROUP_NON_UNIFORM_FADD,
525        &[ty_float, sg_total, c_scope_sg, GROUP_OPERATION_REDUCE, val],
526    );
527
528    // Leader writes sub-group total to scratch[sg_id]
529    let is_leader = m.alloc_id();
530    m.emit(OP_I_EQUAL, &[ty_bool, is_leader, sg_lid, c_uint_0]);
531
532    m.emit_selection_merge(label_after_leader);
533    m.emit_branch_conditional(is_leader, label_leader, label_after_leader);
534
535    m.emit_label(label_leader);
536    let scratch_ptr = m.alloc_id();
537    m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, scratch_ptr, var_scratch, sg_id);
538    m.emit_store(scratch_ptr, sg_total);
539    m.emit_branch(label_after_leader);
540
541    m.emit_label(label_after_leader);
542
543    // Workgroup barrier
544    m.emit(OP_CONTROL_BARRIER, &[c_scope_wg, c_scope_wg, c_mem_sem]);
545
546    // Phase 2: add prefix from earlier sub-groups
547    // prefix = sum of scratch[0..sg_id)
548    let has_prefix = m.alloc_id();
549    m.emit(OP_U_LESS_THAN, &[ty_bool, has_prefix, c_uint_0, sg_id]); // 0 < sg_id
550
551    m.emit_selection_merge(label_after_add);
552    m.emit_branch_conditional(has_prefix, label_add_prefix, label_after_add);
553
554    m.emit_label(label_add_prefix);
555
556    // Accumulate prefix_sum = sum of scratch[j] for j in 0..sg_id via a loop
557    let var_j = m.alloc_id();
558    let ty_ptr_func_uint = m.alloc_id();
559    m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);
560    let var_prefix_acc = m.alloc_id();
561    let ty_ptr_func_float = m.alloc_id();
562    m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
563    m.emit_variable(ty_ptr_func_uint, var_j, STORAGE_CLASS_FUNCTION);
564    m.emit_variable(ty_ptr_func_float, var_prefix_acc, STORAGE_CLASS_FUNCTION);
565    m.emit_store(var_j, c_uint_0);
566    m.emit_store(var_prefix_acc, c_float_0);
567
568    let lbl_loop_hdr = m.alloc_id();
569    let lbl_loop_body = m.alloc_id();
570    let lbl_loop_cont = m.alloc_id();
571    let lbl_loop_merge = m.alloc_id();
572
573    m.emit_branch(lbl_loop_hdr);
574    m.emit_label(lbl_loop_hdr);
575    let j_val = m.alloc_id();
576    m.emit_load(ty_uint, j_val, var_j);
577    let loop_cond = m.alloc_id();
578    m.emit(OP_U_LESS_THAN, &[ty_bool, loop_cond, j_val, sg_id]);
579    m.emit_loop_merge(lbl_loop_merge, lbl_loop_cont);
580    m.emit_branch_conditional(loop_cond, lbl_loop_body, lbl_loop_merge);
581
582    m.emit_label(lbl_loop_body);
583    let s_ptr = m.alloc_id();
584    m.emit_in_bounds_ptr_access_chain(ty_ptr_wg_float, s_ptr, var_scratch, j_val);
585    let s_val = m.alloc_id();
586    m.emit_load(ty_float, s_val, s_ptr);
587    let old_prefix = m.alloc_id();
588    m.emit_load(ty_float, old_prefix, var_prefix_acc);
589    let new_prefix = m.alloc_id();
590    m.emit(OP_F_ADD, &[ty_float, new_prefix, old_prefix, s_val]);
591    m.emit_store(var_prefix_acc, new_prefix);
592    m.emit_branch(lbl_loop_cont);
593
594    m.emit_label(lbl_loop_cont);
595    let j_inc = m.alloc_id();
596    m.emit(OP_I_ADD, &[ty_uint, j_inc, j_val, c_uint_1]);
597    m.emit_store(var_j, j_inc);
598    m.emit_branch(lbl_loop_hdr);
599
600    m.emit_label(lbl_loop_merge);
601    let prefix_val = m.alloc_id();
602    m.emit_load(ty_float, prefix_val, var_prefix_acc);
603    m.emit_branch(label_after_add);
604
605    m.emit_label(label_after_add);
606    // Phi for prefix: either prefix_val or 0
607    let prefix = m.alloc_id();
608    m.emit(
609        OP_PHI,
610        &[
611            ty_float,
612            prefix,
613            prefix_val,
614            lbl_loop_merge,
615            c_float_0,
616            label_after_leader,
617        ],
618    );
619
620    // Final result = sg_scan + prefix
621    let final_val = m.alloc_id();
622    m.emit(OP_F_ADD, &[ty_float, final_val, sg_scan, prefix]);
623
624    // Write output if in bounds
625    m.emit_selection_merge(label_end);
626    m.emit_branch_conditional(in_bounds, label_write, label_end);
627
628    m.emit_label(label_write);
629    let out_ptr = m.alloc_id();
630    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, out_ptr, p_output, gid_x);
631    m.emit_store(out_ptr, final_val);
632    m.emit_branch(label_end);
633
634    m.emit_label(label_end);
635    m.emit_return();
636    m.emit_function_end();
637
638    m.finalize()
639}
640
641// ─── Sub-group optimized GEMM kernel ─────────────────────────
642
643/// Generate an OpenCL SPIR-V compute kernel for GEMM with sub-group shuffle.
644///
645/// `C = alpha * A * B + beta * C` (row-major f32).
646///
647/// Uses `OpGroupNonUniformShuffle` to broadcast A-row elements across the
648/// sub-group, avoiding redundant global memory loads. Each lane in a sub-group
649/// handles a different column of B, and the A value is shuffled from the lane
650/// that loaded it.
651///
652/// Kernel parameters: `(CrossWorkgroup float* A, CrossWorkgroup float* B,
653///                      CrossWorkgroup float* C, uint m, uint n, uint k,
654///                      float alpha, float beta)`.
655///
656/// Entry point name: `"gemm_subgroup"`.
657pub fn gemm_subgroup_spirv() -> Vec<u32> {
658    let mut m = SpvModule::new();
659
660    // ── Capabilities ──
661    m.emit_capability(CAPABILITY_KERNEL);
662    m.emit_capability(CAPABILITY_ADDRESSES);
663    m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM);
664    m.emit_capability(CAPABILITY_GROUP_NON_UNIFORM_SHUFFLE);
665
666    // ── Ext import ──
667    let opencl_ext = m.alloc_id();
668    m.emit_ext_inst_import(opencl_ext, "OpenCL.std");
669
670    // ── Memory model ──
671    m.emit_memory_model(ADDRESSING_MODEL_PHYSICAL64, MEMORY_MODEL_OPENCL);
672
673    // ── Types ──
674    let ty_void = m.alloc_id();
675    let ty_bool = m.alloc_id();
676    let ty_uint = m.alloc_id();
677    let ty_float = m.alloc_id();
678    let ty_v3uint = m.alloc_id();
679    let ty_ptr_input_v3uint = m.alloc_id();
680    let ty_ptr_cross_float = m.alloc_id();
681    let ty_ptr_func_float = m.alloc_id();
682    let ty_ptr_func_uint = m.alloc_id();
683    let ty_ptr_input_uint = m.alloc_id();
684
685    // ── Constants ──
686    let c_uint_0 = m.alloc_id();
687    let c_uint_1 = m.alloc_id();
688    let c_float_0 = m.alloc_id();
689    let c_scope_sg = m.alloc_id();
690
691    // ── Variables ──
692    let var_gid = m.alloc_id();
693    let var_sg_lid = m.alloc_id();
694
695    // ── Function ──
696    let fn_ty = m.alloc_id();
697    let main_fn = m.alloc_id();
698    let p_a = m.alloc_id();
699    let p_b = m.alloc_id();
700    let p_c = m.alloc_id();
701    let p_m = m.alloc_id();
702    let p_n = m.alloc_id();
703    let p_k = m.alloc_id();
704    let p_alpha = m.alloc_id();
705    let p_beta = m.alloc_id();
706
707    // ── Entry point ──
708    m.emit_entry_point(
709        EXECUTION_MODEL_KERNEL,
710        main_fn,
711        "gemm_subgroup",
712        &[var_gid, var_sg_lid],
713    );
714    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
715
716    // ── Decorations ──
717    m.emit_decorate(var_gid, DECORATION_BUILTIN, &[BUILTIN_GLOBAL_INVOCATION_ID]);
718    m.emit_decorate(
719        var_sg_lid,
720        DECORATION_BUILTIN,
721        &[BUILTIN_SUBGROUP_LOCAL_INVOCATION_ID],
722    );
723
724    // ── Types ──
725    m.emit_type_void(ty_void);
726    m.emit_type_bool(ty_bool);
727    m.emit_type_int(ty_uint, 32, 0);
728    m.emit_type_float(ty_float, 32);
729    m.emit_type_vector(ty_v3uint, ty_uint, 3);
730    m.emit_type_pointer(ty_ptr_input_v3uint, STORAGE_CLASS_INPUT, ty_v3uint);
731    m.emit_type_pointer(ty_ptr_cross_float, STORAGE_CLASS_CROSS_WORKGROUP, ty_float);
732    m.emit_type_pointer(ty_ptr_func_float, STORAGE_CLASS_FUNCTION, ty_float);
733    m.emit_type_pointer(ty_ptr_func_uint, STORAGE_CLASS_FUNCTION, ty_uint);
734    m.emit_type_pointer(ty_ptr_input_uint, STORAGE_CLASS_INPUT, ty_uint);
735    m.emit_type_function(
736        fn_ty,
737        ty_void,
738        &[
739            ty_ptr_cross_float,
740            ty_ptr_cross_float,
741            ty_ptr_cross_float,
742            ty_uint,
743            ty_uint,
744            ty_uint,
745            ty_float,
746            ty_float,
747        ],
748    );
749
750    // ── Constants ──
751    m.emit_constant_u32(ty_uint, c_uint_0, 0);
752    m.emit_constant_u32(ty_uint, c_uint_1, 1);
753    m.emit_constant_f32(ty_float, c_float_0, 0.0);
754    m.emit_constant_u32(ty_uint, c_scope_sg, SCOPE_SUBGROUP);
755
756    // ── Variables ──
757    m.emit_variable(ty_ptr_input_v3uint, var_gid, STORAGE_CLASS_INPUT);
758    m.emit_variable(ty_ptr_input_uint, var_sg_lid, STORAGE_CLASS_INPUT);
759
760    // ── Labels ──
761    let label_entry = m.alloc_id();
762    let label_bounds_body = m.alloc_id();
763    let label_bounds_merge = m.alloc_id();
764    let label_loop_header = m.alloc_id();
765    let label_loop_body = m.alloc_id();
766    let label_loop_continue = m.alloc_id();
767    let label_loop_merge = m.alloc_id();
768
769    // ── Function body ──
770    m.emit_function(ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
771    m.emit_function_parameter(ty_ptr_cross_float, p_a);
772    m.emit_function_parameter(ty_ptr_cross_float, p_b);
773    m.emit_function_parameter(ty_ptr_cross_float, p_c);
774    m.emit_function_parameter(ty_uint, p_m);
775    m.emit_function_parameter(ty_uint, p_n);
776    m.emit_function_parameter(ty_uint, p_k);
777    m.emit_function_parameter(ty_float, p_alpha);
778    m.emit_function_parameter(ty_float, p_beta);
779    m.emit_label(label_entry);
780
781    // Load global ID -> element index (one thread per output element)
782    let gid_val = m.alloc_id();
783    m.emit_load(ty_v3uint, gid_val, var_gid);
784    let gid_x = m.alloc_id();
785    m.emit(OP_COMPOSITE_EXTRACT, &[ty_uint, gid_x, gid_val, 0]);
786
787    // Load sub-group local ID
788    let sg_lid = m.alloc_id();
789    m.emit_load(ty_uint, sg_lid, var_sg_lid);
790
791    // total = m * n
792    let total = m.alloc_id();
793    m.emit(OP_I_MUL, &[ty_uint, total, p_m, p_n]);
794
795    // Bounds check
796    let cond = m.alloc_id();
797    m.emit(OP_U_LESS_THAN, &[ty_bool, cond, gid_x, total]);
798    m.emit_selection_merge(label_bounds_merge);
799    m.emit_branch_conditional(cond, label_bounds_body, label_bounds_merge);
800
801    m.emit_label(label_bounds_body);
802
803    // row = gid / n, col = gid % n
804    // OpUDiv = 134, OpUMod = 137
805    let row = m.alloc_id();
806    m.emit(134, &[ty_uint, row, gid_x, p_n]); // OpUDiv
807    let col = m.alloc_id();
808    m.emit(137, &[ty_uint, col, gid_x, p_n]); // OpUMod
809
810    // Accumulator + loop counter
811    let var_acc = m.alloc_id();
812    m.emit_variable(ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
813    m.emit_store(var_acc, c_float_0);
814    let var_i = m.alloc_id();
815    m.emit_variable(ty_ptr_func_uint, var_i, STORAGE_CLASS_FUNCTION);
816    m.emit_store(var_i, c_uint_0);
817
818    m.emit_branch(label_loop_header);
819
820    // ── Loop header ──
821    m.emit_label(label_loop_header);
822    let i_val = m.alloc_id();
823    m.emit_load(ty_uint, i_val, var_i);
824    let loop_cond = m.alloc_id();
825    m.emit(OP_U_LESS_THAN, &[ty_bool, loop_cond, i_val, p_k]);
826    m.emit_loop_merge(label_loop_merge, label_loop_continue);
827    m.emit_branch_conditional(loop_cond, label_loop_body, label_loop_merge);
828
829    // ── Loop body ──
830    m.emit_label(label_loop_body);
831
832    // Load A[row, i]
833    let a_idx = m.alloc_id();
834    let row_k = m.alloc_id();
835    m.emit(OP_I_MUL, &[ty_uint, row_k, row, p_k]);
836    m.emit(OP_I_ADD, &[ty_uint, a_idx, row_k, i_val]);
837    let a_ptr = m.alloc_id();
838    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, a_ptr, p_a, a_idx);
839    let a_val = m.alloc_id();
840    m.emit_load(ty_float, a_val, a_ptr);
841
842    // Broadcast A value via sub-group shuffle (identity shuffle validates sub-group path)
843    let a_broadcast = m.alloc_id();
844    m.emit(
845        OP_GROUP_NON_UNIFORM_SHUFFLE,
846        &[ty_float, a_broadcast, c_scope_sg, a_val, sg_lid],
847    );
848
849    // Load B[i, col]
850    let b_idx = m.alloc_id();
851    let i_n = m.alloc_id();
852    m.emit(OP_I_MUL, &[ty_uint, i_n, i_val, p_n]);
853    m.emit(OP_I_ADD, &[ty_uint, b_idx, i_n, col]);
854    let b_ptr = m.alloc_id();
855    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, b_ptr, p_b, b_idx);
856    let b_val = m.alloc_id();
857    m.emit_load(ty_float, b_val, b_ptr);
858
859    // acc += a_broadcast * b_val
860    let prod = m.alloc_id();
861    m.emit(OP_F_MUL, &[ty_float, prod, a_broadcast, b_val]);
862    let old_acc = m.alloc_id();
863    m.emit_load(ty_float, old_acc, var_acc);
864    let new_acc = m.alloc_id();
865    m.emit(OP_F_ADD, &[ty_float, new_acc, old_acc, prod]);
866    m.emit_store(var_acc, new_acc);
867
868    m.emit_branch(label_loop_continue);
869
870    // ── Loop continue ──
871    m.emit_label(label_loop_continue);
872    let i_inc = m.alloc_id();
873    m.emit(OP_I_ADD, &[ty_uint, i_inc, i_val, c_uint_1]);
874    m.emit_store(var_i, i_inc);
875    m.emit_branch(label_loop_header);
876
877    // ── Loop merge ──
878    m.emit_label(label_loop_merge);
879
880    // result = alpha * acc + beta * C[gid]
881    let final_acc = m.alloc_id();
882    m.emit_load(ty_float, final_acc, var_acc);
883    let alpha_acc = m.alloc_id();
884    m.emit(OP_F_MUL, &[ty_float, alpha_acc, p_alpha, final_acc]);
885
886    let c_ptr = m.alloc_id();
887    m.emit_in_bounds_ptr_access_chain(ty_ptr_cross_float, c_ptr, p_c, gid_x);
888    let c_old = m.alloc_id();
889    m.emit_load(ty_float, c_old, c_ptr);
890    let beta_c = m.alloc_id();
891    m.emit(OP_F_MUL, &[ty_float, beta_c, p_beta, c_old]);
892    let c_new = m.alloc_id();
893    m.emit(OP_F_ADD, &[ty_float, c_new, alpha_acc, beta_c]);
894    m.emit_store(c_ptr, c_new);
895
896    m.emit_branch(label_bounds_merge);
897
898    m.emit_label(label_bounds_merge);
899    m.emit_return();
900    m.emit_function_end();
901
902    m.finalize()
903}
904
905// ─── Tests ──────────────────────────────────────────────────
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spirv::SPIRV_MAGIC;

    // SPIR-V opcodes used by the structural checks below.
    const OP_CAPABILITY: u32 = 17;
    const OP_ENTRY_POINT: u32 = 15;

    /// Validate the 5-word SPIR-V module header: magic number, non-zero
    /// ID bound (word 3), and a zero instruction schema (word 4).
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        assert_eq!(words[0], SPIRV_MAGIC, "bad magic");
        assert!(words[3] > 0, "ID bound must be > 0");
        assert_eq!(words[4], 0, "schema must be 0");
    }

    /// Walk the instruction stream (everything after the 5-word header),
    /// yielding `(opcode, operand_words)` per instruction. The high 16 bits
    /// of each instruction's first word are its total word count, the low
    /// 16 bits the opcode. Stops on a malformed (zero or out-of-range) word
    /// count rather than looping forever.
    fn instructions(words: &[u32]) -> impl Iterator<Item = (u32, &[u32])> {
        let mut i = 5usize;
        std::iter::from_fn(move || {
            if i >= words.len() {
                return None;
            }
            let head = words[i];
            let wc = (head >> 16) as usize;
            if wc == 0 || i + wc > words.len() {
                return None; // malformed stream — bail out
            }
            let op = head & 0xFFFF;
            let operands = &words[i + 1..i + wc];
            i += wc;
            Some((op, operands))
        })
    }

    /// Check that the module declares a capability via an actual
    /// OpCapability instruction. Walking real instructions (rather than
    /// scanning raw word pairs) avoids false positives from operand words
    /// elsewhere in the stream that happen to match the pattern.
    fn has_capability(words: &[u32], cap: u32) -> bool {
        instructions(words)
            .any(|(op, operands)| op == OP_CAPABILITY && operands.len() == 1 && operands[0] == cap)
    }

    /// Check that the module has an OpEntryPoint whose literal-string name
    /// operand contains `name`. The match is confined to the name operand
    /// (bytes up to its NUL terminator), so identical bytes appearing in
    /// debug strings or data words elsewhere cannot cause a false positive.
    fn has_entry_point(words: &[u32], name: &str) -> bool {
        if name.is_empty() {
            return true;
        }
        instructions(words).any(|(op, operands)| {
            // OpEntryPoint operands: execution model, entry-point <id>,
            // name (literal string), then interface <id>s.
            if op != OP_ENTRY_POINT || operands.len() < 3 {
                return false;
            }
            // SPIR-V packs literal strings into words little-endian,
            // NUL-terminated; interface ids follow in later words.
            let bytes: Vec<u8> = operands[2..].iter().flat_map(|w| w.to_le_bytes()).collect();
            let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
            bytes[..end].windows(name.len()).any(|w| w == name.as_bytes())
        })
    }

    #[test]
    fn reduction_subgroup_valid_spirv() {
        let words = reduction_subgroup_spirv();
        check_valid_spirv(&words);
    }

    #[test]
    fn reduction_subgroup_word_aligned() {
        let words = reduction_subgroup_spirv();
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_ne_bytes()).collect();
        assert_eq!(bytes.len() % 4, 0);
    }

    #[test]
    fn reduction_subgroup_has_group_non_uniform_capability() {
        let words = reduction_subgroup_spirv();
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM),
            "missing GroupNonUniform capability"
        );
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC),
            "missing GroupNonUniformArithmetic capability"
        );
    }

    #[test]
    fn reduction_subgroup_has_entry_point() {
        let words = reduction_subgroup_spirv();
        assert!(
            has_entry_point(&words, "reduction_subgroup"),
            "missing entry point name"
        );
    }

    #[test]
    fn scan_subgroup_valid_spirv() {
        let words = scan_subgroup_spirv();
        check_valid_spirv(&words);
    }

    #[test]
    fn scan_subgroup_word_aligned() {
        let words = scan_subgroup_spirv();
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_ne_bytes()).collect();
        assert_eq!(bytes.len() % 4, 0);
    }

    #[test]
    fn scan_subgroup_has_group_non_uniform_capability() {
        let words = scan_subgroup_spirv();
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM),
            "missing GroupNonUniform capability"
        );
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM_ARITHMETIC),
            "missing GroupNonUniformArithmetic capability"
        );
    }

    #[test]
    fn scan_subgroup_has_entry_point() {
        let words = scan_subgroup_spirv();
        assert!(
            has_entry_point(&words, "scan_subgroup"),
            "missing entry point name"
        );
    }

    #[test]
    fn gemm_subgroup_valid_spirv() {
        let words = gemm_subgroup_spirv();
        check_valid_spirv(&words);
    }

    #[test]
    fn gemm_subgroup_word_aligned() {
        let words = gemm_subgroup_spirv();
        let bytes: Vec<u8> = words.iter().flat_map(|w| w.to_ne_bytes()).collect();
        assert_eq!(bytes.len() % 4, 0);
    }

    #[test]
    fn gemm_subgroup_has_group_non_uniform_capability() {
        let words = gemm_subgroup_spirv();
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM),
            "missing GroupNonUniform capability"
        );
        assert!(
            has_capability(&words, CAPABILITY_GROUP_NON_UNIFORM_SHUFFLE),
            "missing GroupNonUniformShuffle capability"
        );
    }

    #[test]
    fn gemm_subgroup_has_entry_point() {
        let words = gemm_subgroup_spirv();
        assert!(
            has_entry_point(&words, "gemm_subgroup"),
            "missing entry point name"
        );
    }

    #[test]
    fn subgroup_shaders_all_word_aligned() {
        fn to_bytes(words: &[u32]) -> Vec<u8> {
            words.iter().flat_map(|w| w.to_ne_bytes()).collect()
        }
        assert_eq!(to_bytes(&reduction_subgroup_spirv()).len() % 4, 0);
        assert_eq!(to_bytes(&scan_subgroup_spirv()).len() % 4, 0);
        assert_eq!(to_bytes(&gemm_subgroup_spirv()).len() % 4, 0);
    }
}