// oxicuda_levelzero/spirv_nn.rs
//! SPIR-V compute kernel generators for neural-network operations.
//!
//! This module extends the Level Zero SPIR-V generator with Conv2D and
//! scaled dot-product attention kernels.  Both use the OpenCL execution
//! model (Kernel + Physical64 + Addresses) so they can be consumed by
//! `zeModuleCreate`.
8use super::spirv::{
9    EXECUTION_MODEL_KERNEL, FUNCTION_CONTROL_NONE, OP_F_ADD, OP_F_DIV, OP_F_MUL, OP_F_SUB,
10    OP_I_ADD, OP_I_MUL, OP_U_DIV, OP_U_LESS_THAN, OP_U_MOD, OPENCL_EXP, OPENCL_FMAX,
11    STORAGE_CLASS_FUNCTION, SpvModule, WORKGROUP_SIZE, emit_preamble, load_gid_x,
12};
13
/// SPIR-V opcode for `OpISub` (integer subtract).
const OP_I_SUB: u32 = 130;
/// SPIR-V opcode for `OpFOrdGreaterThan` (ordered float >).
///
/// Per the SPIR-V specification this is opcode 186; the previous value 188
/// is `OpFOrdLessThanEqual`, which inverted the `sum_exp > 0` guard in the
/// attention normalization step.
const OP_F_ORD_GT: u32 = 186;
18
19// ─── Conv2D compute kernel ──────────────────────────────────
20
21/// Generate an OpenCL SPIR-V compute kernel for 2-D convolution (NCHW layout).
22///
23/// Each work-item computes one output element.
24///
25/// Kernel parameters (all passed via `zeKernelSetArgumentValue`):
26///
27/// ```text
28/// (CrossWorkgroup float* input,
29///  CrossWorkgroup float* filter,
30///  CrossWorkgroup float* output)
31/// ```
32///
33/// All dimension constants are baked in as `OpConstant`.
34#[allow(clippy::too_many_arguments)]
35pub fn conv2d_spirv(
36    n: u32,
37    c_in: u32,
38    h_in: u32,
39    w_in: u32,
40    k_out: u32,
41    fh: u32,
42    fw: u32,
43    oh: u32,
44    ow: u32,
45    stride_h: u32,
46    stride_w: u32,
47    pad_h: u32,
48    pad_w: u32,
49) -> Vec<u32> {
50    let mut m = SpvModule::new();
51    let b = emit_preamble(&mut m);
52
53    let main_fn = m.alloc_id();
54    let fn_ty = m.alloc_id();
55    let p_input = m.alloc_id();
56    let p_filter = m.alloc_id();
57    let p_output = m.alloc_id();
58
59    // Function type: void(float*, float*, float*)
60    m.emit_type_function(
61        fn_ty,
62        b.ty_void,
63        &[
64            b.ty_ptr_cross_float,
65            b.ty_ptr_cross_float,
66            b.ty_ptr_cross_float,
67        ],
68    );
69
70    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
71    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);
72
73    // Emit constants for dimensions
74    let c_n = m.alloc_id();
75    m.emit_constant_u32(b.ty_uint, c_n, n);
76    let c_c_in = m.alloc_id();
77    m.emit_constant_u32(b.ty_uint, c_c_in, c_in);
78    let c_h_in = m.alloc_id();
79    m.emit_constant_u32(b.ty_uint, c_h_in, h_in);
80    let c_w_in = m.alloc_id();
81    m.emit_constant_u32(b.ty_uint, c_w_in, w_in);
82    let c_k_out = m.alloc_id();
83    m.emit_constant_u32(b.ty_uint, c_k_out, k_out);
84    let c_fh = m.alloc_id();
85    m.emit_constant_u32(b.ty_uint, c_fh, fh);
86    let c_fw = m.alloc_id();
87    m.emit_constant_u32(b.ty_uint, c_fw, fw);
88    let c_oh = m.alloc_id();
89    m.emit_constant_u32(b.ty_uint, c_oh, oh);
90    let c_ow = m.alloc_id();
91    m.emit_constant_u32(b.ty_uint, c_ow, ow);
92    let c_stride_h = m.alloc_id();
93    m.emit_constant_u32(b.ty_uint, c_stride_h, stride_h);
94    let c_stride_w = m.alloc_id();
95    m.emit_constant_u32(b.ty_uint, c_stride_w, stride_w);
96    let c_pad_h = m.alloc_id();
97    m.emit_constant_u32(b.ty_uint, c_pad_h, pad_h);
98    let c_pad_w = m.alloc_id();
99    m.emit_constant_u32(b.ty_uint, c_pad_w, pad_w);
100
101    // Labels
102    let label_entry = m.alloc_id();
103    let label_body = m.alloc_id();
104    let label_merge = m.alloc_id();
105
106    // Function
107    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
108    m.emit_function_parameter(b.ty_ptr_cross_float, p_input);
109    m.emit_function_parameter(b.ty_ptr_cross_float, p_filter);
110    m.emit_function_parameter(b.ty_ptr_cross_float, p_output);
111    m.emit_label(label_entry);
112
113    let gid = load_gid_x(&mut m, &b);
114
115    // total = n * k_out * oh * ow
116    let t1 = m.alloc_id();
117    m.emit(OP_I_MUL, &[b.ty_uint, t1, c_n, c_k_out]);
118    let t2 = m.alloc_id();
119    m.emit(OP_I_MUL, &[b.ty_uint, t2, t1, c_oh]);
120    let total = m.alloc_id();
121    m.emit(OP_I_MUL, &[b.ty_uint, total, t2, c_ow]);
122
123    let cond = m.alloc_id();
124    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
125    m.emit_selection_merge(label_merge);
126    m.emit_branch_conditional(cond, label_body, label_merge);
127
128    m.emit_label(label_body);
129
130    // Decompose gid -> (b_idx, kf, oy, ox)
131    let ox = m.alloc_id();
132    m.emit(OP_U_MOD, &[b.ty_uint, ox, gid, c_ow]);
133    let tmp1 = m.alloc_id();
134    m.emit(OP_U_DIV, &[b.ty_uint, tmp1, gid, c_ow]);
135    let oy = m.alloc_id();
136    m.emit(OP_U_MOD, &[b.ty_uint, oy, tmp1, c_oh]);
137    let tmp2 = m.alloc_id();
138    m.emit(OP_U_DIV, &[b.ty_uint, tmp2, tmp1, c_oh]);
139    let kf = m.alloc_id();
140    m.emit(OP_U_MOD, &[b.ty_uint, kf, tmp2, c_k_out]);
141    let b_idx = m.alloc_id();
142    m.emit(OP_U_DIV, &[b.ty_uint, b_idx, tmp2, c_k_out]);
143
144    // Accumulator variable
145    let var_acc = m.alloc_id();
146    m.emit_variable(b.ty_ptr_func_float, var_acc, STORAGE_CLASS_FUNCTION);
147    m.emit_store(var_acc, b.c_float_0);
148
149    // Flatten ci * fh * fw
150    let flat_total_id = m.alloc_id();
151    let flat_t1 = m.alloc_id();
152    m.emit(OP_I_MUL, &[b.ty_uint, flat_t1, c_c_in, c_fh]);
153    m.emit(OP_I_MUL, &[b.ty_uint, flat_total_id, flat_t1, c_fw]);
154
155    let var_flat = m.alloc_id();
156    m.emit_variable(b.ty_ptr_func_uint, var_flat, STORAGE_CLASS_FUNCTION);
157    m.emit_store(var_flat, b.c_uint_0);
158
159    let lbl_loop_hdr = m.alloc_id();
160    let lbl_loop_body = m.alloc_id();
161    let lbl_loop_cont = m.alloc_id();
162    let lbl_loop_merge = m.alloc_id();
163
164    m.emit_branch(lbl_loop_hdr);
165
166    // ── Loop header ──
167    m.emit_label(lbl_loop_hdr);
168    let flat_val = m.alloc_id();
169    m.emit_load(b.ty_uint, flat_val, var_flat);
170    let loop_cond = m.alloc_id();
171    m.emit(
172        OP_U_LESS_THAN,
173        &[b.ty_bool, loop_cond, flat_val, flat_total_id],
174    );
175    m.emit_loop_merge(lbl_loop_merge, lbl_loop_cont);
176    m.emit_branch_conditional(loop_cond, lbl_loop_body, lbl_loop_merge);
177
178    // ── Loop body ──
179    m.emit_label(lbl_loop_body);
180
181    // Decompose flat_val -> (ci, fy, fx)
182    let fx = m.alloc_id();
183    m.emit(OP_U_MOD, &[b.ty_uint, fx, flat_val, c_fw]);
184    let ftmp1 = m.alloc_id();
185    m.emit(OP_U_DIV, &[b.ty_uint, ftmp1, flat_val, c_fw]);
186    let fy = m.alloc_id();
187    m.emit(OP_U_MOD, &[b.ty_uint, fy, ftmp1, c_fh]);
188    let ci = m.alloc_id();
189    m.emit(OP_U_DIV, &[b.ty_uint, ci, ftmp1, c_fh]);
190
191    // iy_raw = oy * stride_h + fy, ix_raw = ox * stride_w + fx
192    let oy_sh = m.alloc_id();
193    m.emit(OP_I_MUL, &[b.ty_uint, oy_sh, oy, c_stride_h]);
194    let iy_raw = m.alloc_id();
195    m.emit(OP_I_ADD, &[b.ty_uint, iy_raw, oy_sh, fy]);
196    let ox_sw = m.alloc_id();
197    m.emit(OP_I_MUL, &[b.ty_uint, ox_sw, ox, c_stride_w]);
198    let ix_raw = m.alloc_id();
199    m.emit(OP_I_ADD, &[b.ty_uint, ix_raw, ox_sw, fx]);
200
201    // Bounds check: iy_raw >= pad_h  &&  (iy_raw - pad_h) < h_in
202    //           &&  ix_raw >= pad_w  &&  (ix_raw - pad_w) < w_in
203    let lbl_skip = m.alloc_id();
204
205    let iy_lt_pad = m.alloc_id();
206    m.emit(OP_U_LESS_THAN, &[b.ty_bool, iy_lt_pad, iy_raw, c_pad_h]);
207    let lbl_iy_ok = m.alloc_id();
208    m.emit_selection_merge(lbl_skip);
209    m.emit_branch_conditional(iy_lt_pad, lbl_skip, lbl_iy_ok);
210
211    m.emit_label(lbl_iy_ok);
212    let iy_real = m.alloc_id();
213    m.emit(OP_I_SUB, &[b.ty_uint, iy_real, iy_raw, c_pad_h]);
214    let iy_in_bounds = m.alloc_id();
215    m.emit(OP_U_LESS_THAN, &[b.ty_bool, iy_in_bounds, iy_real, c_h_in]);
216    let lbl_ix_check = m.alloc_id();
217    m.emit_selection_merge(lbl_skip);
218    m.emit_branch_conditional(iy_in_bounds, lbl_ix_check, lbl_skip);
219
220    m.emit_label(lbl_ix_check);
221    let ix_lt_pad = m.alloc_id();
222    m.emit(OP_U_LESS_THAN, &[b.ty_bool, ix_lt_pad, ix_raw, c_pad_w]);
223    let lbl_ix_ok = m.alloc_id();
224    m.emit_selection_merge(lbl_skip);
225    m.emit_branch_conditional(ix_lt_pad, lbl_skip, lbl_ix_ok);
226
227    m.emit_label(lbl_ix_ok);
228    let ix_real = m.alloc_id();
229    m.emit(OP_I_SUB, &[b.ty_uint, ix_real, ix_raw, c_pad_w]);
230    let ix_in_bounds = m.alloc_id();
231    m.emit(OP_U_LESS_THAN, &[b.ty_bool, ix_in_bounds, ix_real, c_w_in]);
232    let lbl_accum = m.alloc_id();
233    m.emit_selection_merge(lbl_skip);
234    m.emit_branch_conditional(ix_in_bounds, lbl_accum, lbl_skip);
235
236    m.emit_label(lbl_accum);
237
238    // input_idx = ((b_idx * c_in + ci) * h_in + iy_real) * w_in + ix_real
239    let in1 = m.alloc_id();
240    m.emit(OP_I_MUL, &[b.ty_uint, in1, b_idx, c_c_in]);
241    let in2 = m.alloc_id();
242    m.emit(OP_I_ADD, &[b.ty_uint, in2, in1, ci]);
243    let in3 = m.alloc_id();
244    m.emit(OP_I_MUL, &[b.ty_uint, in3, in2, c_h_in]);
245    let in4 = m.alloc_id();
246    m.emit(OP_I_ADD, &[b.ty_uint, in4, in3, iy_real]);
247    let in5 = m.alloc_id();
248    m.emit(OP_I_MUL, &[b.ty_uint, in5, in4, c_w_in]);
249    let in_idx = m.alloc_id();
250    m.emit(OP_I_ADD, &[b.ty_uint, in_idx, in5, ix_real]);
251
252    let inp_ptr = m.alloc_id();
253    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, inp_ptr, p_input, in_idx);
254    let inp_val = m.alloc_id();
255    m.emit_load(b.ty_float, inp_val, inp_ptr);
256
257    // filter_idx = ((kf * c_in + ci) * fh + fy) * fw + fx
258    let f1 = m.alloc_id();
259    m.emit(OP_I_MUL, &[b.ty_uint, f1, kf, c_c_in]);
260    let f2 = m.alloc_id();
261    m.emit(OP_I_ADD, &[b.ty_uint, f2, f1, ci]);
262    let f3 = m.alloc_id();
263    m.emit(OP_I_MUL, &[b.ty_uint, f3, f2, c_fh]);
264    let f4 = m.alloc_id();
265    m.emit(OP_I_ADD, &[b.ty_uint, f4, f3, fy]);
266    let f5 = m.alloc_id();
267    m.emit(OP_I_MUL, &[b.ty_uint, f5, f4, c_fw]);
268    let flt_idx = m.alloc_id();
269    m.emit(OP_I_ADD, &[b.ty_uint, flt_idx, f5, fx]);
270
271    let flt_ptr = m.alloc_id();
272    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, flt_ptr, p_filter, flt_idx);
273    let flt_val = m.alloc_id();
274    m.emit_load(b.ty_float, flt_val, flt_ptr);
275
276    // acc += inp * flt
277    let prod = m.alloc_id();
278    m.emit(OP_F_MUL, &[b.ty_float, prod, inp_val, flt_val]);
279    let old_acc = m.alloc_id();
280    m.emit_load(b.ty_float, old_acc, var_acc);
281    let new_acc = m.alloc_id();
282    m.emit(OP_F_ADD, &[b.ty_float, new_acc, old_acc, prod]);
283    m.emit_store(var_acc, new_acc);
284
285    m.emit_branch(lbl_skip);
286
287    m.emit_label(lbl_skip);
288    m.emit_branch(lbl_loop_cont);
289
290    // ── Loop continue ──
291    m.emit_label(lbl_loop_cont);
292    let flat_inc = m.alloc_id();
293    m.emit(OP_I_ADD, &[b.ty_uint, flat_inc, flat_val, b.c_uint_1]);
294    m.emit_store(var_flat, flat_inc);
295    m.emit_branch(lbl_loop_hdr);
296
297    // ── Loop merge: store result ──
298    m.emit_label(lbl_loop_merge);
299
300    let final_acc = m.alloc_id();
301    m.emit_load(b.ty_float, final_acc, var_acc);
302
303    let out_ptr = m.alloc_id();
304    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, out_ptr, p_output, gid);
305    m.emit_store(out_ptr, final_acc);
306
307    m.emit_branch(label_merge);
308
309    m.emit_label(label_merge);
310    m.emit_return();
311    m.emit_function_end();
312
313    m.finalize()
314}
315
316// ─── Attention compute kernel ───────────────────────────────
317
/// Generate an OpenCL SPIR-V compute kernel for scaled dot-product attention.
///
/// Each work-item handles one (batch_head, query_position) pair and runs a
/// numerically stable two-pass softmax over the key/value sequence:
///
/// 1. pass 1 scans all scores to find the row maximum,
/// 2. pass 2 re-computes each score, accumulates `exp(score - max)` into a
///    running sum, and adds the exp-weighted V rows into O,
/// 3. a final loop divides the accumulated O row by the sum (skipped when
///    the sum is not strictly positive, which also guards the division
///    against NaN via the ordered compare).
///
/// Kernel parameters:
///
/// ```text
/// (CrossWorkgroup float* Q,
///  CrossWorkgroup float* K,
///  CrossWorkgroup float* V,
///  CrossWorkgroup float* O)
/// ```
///
/// Dimension constants are baked in as `OpConstant`.
///
/// NOTE(review): pass 2 read-modify-writes the O buffer, so this kernel
/// assumes the host zero-initializes O before launch — confirm at the call
/// site.
#[allow(clippy::too_many_arguments)]
pub fn attention_spirv(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> Vec<u32> {
    let mut m = SpvModule::new();
    let b = emit_preamble(&mut m);

    let main_fn = m.alloc_id();
    let fn_ty = m.alloc_id();
    let p_q = m.alloc_id();
    let p_k = m.alloc_id();
    let p_v = m.alloc_id();
    let p_o = m.alloc_id();

    // Function type: void(float*, float*, float*, float*)
    m.emit_type_function(
        fn_ty,
        b.ty_void,
        &[
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
            b.ty_ptr_cross_float,
        ],
    );

    m.emit_entry_point(EXECUTION_MODEL_KERNEL, main_fn, "main", &[b.var_gid]);
    m.emit_execution_mode_local_size(main_fn, WORKGROUP_SIZE, 1, 1);

    // Constants
    let c_batch_heads = m.alloc_id();
    m.emit_constant_u32(b.ty_uint, c_batch_heads, batch_heads);
    let c_seq_q = m.alloc_id();
    m.emit_constant_u32(b.ty_uint, c_seq_q, seq_q);
    let c_seq_kv = m.alloc_id();
    m.emit_constant_u32(b.ty_uint, c_seq_kv, seq_kv);
    let c_head_dim = m.alloc_id();
    m.emit_constant_u32(b.ty_uint, c_head_dim, head_dim);
    let c_scale = m.alloc_id();
    m.emit_constant_f32(b.ty_float, c_scale, scale);
    // Seed for the running-max pass: any real score beats -inf.
    let c_neg_inf = m.alloc_id();
    m.emit_constant_f32(b.ty_float, c_neg_inf, f32::NEG_INFINITY);
    // Stride: seq_kv * head_dim (one batch_head's worth of K/V elements)
    let c_skv_hd = m.alloc_id();
    m.emit_constant_u32(b.ty_uint, c_skv_hd, seq_kv * head_dim);

    // Labels
    let label_entry = m.alloc_id();
    let label_body = m.alloc_id();
    let label_merge = m.alloc_id();

    m.emit_function(b.ty_void, main_fn, FUNCTION_CONTROL_NONE, fn_ty);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_q);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_k);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_v);
    m.emit_function_parameter(b.ty_ptr_cross_float, p_o);
    m.emit_label(label_entry);

    let gid = load_gid_x(&mut m, &b);

    // total = batch_heads * seq_q — one work-item per query row.
    let total = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, total, c_batch_heads, c_seq_q]);

    // Guard: excess work-items exit via the merge block.
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, gid, total]);
    m.emit_selection_merge(label_merge);
    m.emit_branch_conditional(cond, label_body, label_merge);

    m.emit_label(label_body);

    // bh = gid / seq_q, sq = gid % seq_q
    let bh = m.alloc_id();
    m.emit(OP_U_DIV, &[b.ty_uint, bh, gid, c_seq_q]);
    let sq = m.alloc_id();
    m.emit(OP_U_MOD, &[b.ty_uint, sq, gid, c_seq_q]);

    // q_base = gid * head_dim (also the O-row base: Q and O share layout)
    let q_base = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, q_base, gid, c_head_dim]);
    // kv_base = bh * seq_kv * head_dim
    let kv_base = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, kv_base, bh, c_skv_hd]);

    // var_max_score — running maximum for the stable softmax.
    // NOTE(review): Function-storage OpVariable emitted in a non-entry
    // block; valid only if `emit_variable` hoists it to the entry block
    // during finalize — confirm in `SpvModule`.
    let var_max = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_float, var_max, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_max, c_neg_inf);

    // ── Pass 1: find max score ──
    emit_score_pass(
        &mut m, &b, causal, sq, c_seq_kv, c_head_dim, c_scale, q_base, kv_base, p_q, p_k, var_max,
        true, None, None, p_v, p_o,
    );

    let final_max = m.alloc_id();
    m.emit_load(b.ty_float, final_max, var_max);

    // ── Pass 2: accumulate exp-weighted V ──
    let var_sum_exp = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_float, var_sum_exp, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_sum_exp, b.c_float_0);

    emit_score_pass(
        &mut m,
        &b,
        causal,
        sq,
        c_seq_kv,
        c_head_dim,
        c_scale,
        q_base,
        kv_base,
        p_q,
        p_k,
        var_sum_exp,
        false,
        Some(final_max),
        Some(p_o),
        p_v,
        p_o,
    );

    // Normalize: O[q_base+d] /= sum_exp if sum_exp > 0
    let sum_final = m.alloc_id();
    m.emit_load(b.ty_float, sum_final, var_sum_exp);

    let sum_gt_zero = m.alloc_id();
    m.emit(
        OP_F_ORD_GT,
        &[b.ty_bool, sum_gt_zero, sum_final, b.c_float_0],
    );

    let lbl_norm = m.alloc_id();
    let lbl_norm_merge = m.alloc_id();
    m.emit_selection_merge(lbl_norm_merge);
    m.emit_branch_conditional(sum_gt_zero, lbl_norm, lbl_norm_merge);

    m.emit_label(lbl_norm);

    // Normalize loop: divide each of the head_dim accumulated O elements.
    let var_d4 = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_uint, var_d4, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_d4, b.c_uint_0);

    let lbl_d4_hdr = m.alloc_id();
    let lbl_d4_body = m.alloc_id();
    let lbl_d4_cont = m.alloc_id();
    let lbl_d4_merge = m.alloc_id();

    m.emit_branch(lbl_d4_hdr);

    m.emit_label(lbl_d4_hdr);
    let d4_val = m.alloc_id();
    m.emit_load(b.ty_uint, d4_val, var_d4);
    let d4_cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, d4_cond, d4_val, c_head_dim]);
    m.emit_loop_merge(lbl_d4_merge, lbl_d4_cont);
    m.emit_branch_conditional(d4_cond, lbl_d4_body, lbl_d4_merge);

    m.emit_label(lbl_d4_body);
    let o4_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, o4_idx, q_base, d4_val]);
    let o4_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, o4_ptr, p_o, o4_idx);
    let o4_val = m.alloc_id();
    m.emit_load(b.ty_float, o4_val, o4_ptr);
    let o4_normed = m.alloc_id();
    m.emit(OP_F_DIV, &[b.ty_float, o4_normed, o4_val, sum_final]);
    m.emit_store(o4_ptr, o4_normed);

    m.emit_branch(lbl_d4_cont);
    m.emit_label(lbl_d4_cont);
    let d4_inc = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, d4_inc, d4_val, b.c_uint_1]);
    m.emit_store(var_d4, d4_inc);
    m.emit_branch(lbl_d4_hdr);

    m.emit_label(lbl_d4_merge);

    m.emit_branch(lbl_norm_merge);
    m.emit_label(lbl_norm_merge);

    m.emit_branch(label_merge);

    m.emit_label(label_merge);
    m.emit_return();
    m.emit_function_end();

    m.finalize()
}
528
/// Emit one score-computation pass over the key/value sequence (shared by
/// both passes of `attention_spirv`).
///
/// For every key position `sk` in `0..seq_kv` (skipping positions with
/// `sk > sq` when `causal`), computes `score = scale * dot(Q[row], K[sk])`,
/// then:
///
/// * when `is_max_pass` is true: updates `accum_var` with
///   `fmax(accum_var, score)`;
/// * when false: computes `w = exp(score - max_val)`, adds `w` to
///   `accum_var`, and — when `o_buf` is `Some` — accumulates `w * V[sk]`
///   element-wise into the O row at `q_base`.
#[allow(clippy::too_many_arguments)]
fn emit_score_pass(
    m: &mut SpvModule,
    b: &super::spirv::BaseIds,
    causal: bool,
    sq: u32,
    c_seq_kv: u32,
    c_head_dim: u32,
    c_scale: u32,
    q_base: u32,
    kv_base: u32,
    p_q: u32,
    p_k: u32,
    accum_var: u32,
    is_max_pass: bool,
    max_val: Option<u32>,
    o_buf: Option<u32>,
    p_v: u32,
    // Retained for call-site symmetry; the output pointer actually used is
    // the one passed through `o_buf`.
    _p_o_unused: u32,
) {
    // Outer loop counter over key positions.
    // NOTE(review): the Function-storage OpVariables in this helper (here
    // and in the inner loops) are emitted mid-function; this relies on
    // `emit_variable` hoisting them to the entry block — confirm in
    // `SpvModule`.
    let var_sk = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_uint, var_sk, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_sk, b.c_uint_0);

    let lbl_hdr = m.alloc_id();
    let lbl_body = m.alloc_id();
    let lbl_cont = m.alloc_id();
    let lbl_merge = m.alloc_id();

    m.emit_branch(lbl_hdr);

    m.emit_label(lbl_hdr);
    let sk_val = m.alloc_id();
    m.emit_load(b.ty_uint, sk_val, var_sk);
    let cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, cond, sk_val, c_seq_kv]);
    m.emit_loop_merge(lbl_merge, lbl_cont);
    m.emit_branch_conditional(cond, lbl_body, lbl_merge);

    m.emit_label(lbl_body);

    // Causal mask: skip future keys (sk > sq, expressed as sq < sk).
    let lbl_compute = m.alloc_id();
    let lbl_skip = m.alloc_id();
    if causal {
        let sk_gt_sq = m.alloc_id();
        m.emit(OP_U_LESS_THAN, &[b.ty_bool, sk_gt_sq, sq, sk_val]);
        m.emit_selection_merge(lbl_skip);
        m.emit_branch_conditional(sk_gt_sq, lbl_skip, lbl_compute);
    } else {
        m.emit_branch(lbl_compute);
    }

    m.emit_label(lbl_compute);

    // k_off = kv_base + sk * head_dim (base of K row sk)
    let sk_hd = m.alloc_id();
    m.emit(OP_I_MUL, &[b.ty_uint, sk_hd, sk_val, c_head_dim]);
    let k_off = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, k_off, kv_base, sk_hd]);

    // Inner dot product loop: dot = sum_d Q[q_base+d] * K[k_off+d]
    let var_d = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_uint, var_d, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_d, b.c_uint_0);
    let var_dot = m.alloc_id();
    m.emit_variable(b.ty_ptr_func_float, var_dot, STORAGE_CLASS_FUNCTION);
    m.emit_store(var_dot, b.c_float_0);

    let lbl_d_hdr = m.alloc_id();
    let lbl_d_body = m.alloc_id();
    let lbl_d_cont = m.alloc_id();
    let lbl_d_merge = m.alloc_id();

    m.emit_branch(lbl_d_hdr);

    m.emit_label(lbl_d_hdr);
    let d_val = m.alloc_id();
    m.emit_load(b.ty_uint, d_val, var_d);
    let d_cond = m.alloc_id();
    m.emit(OP_U_LESS_THAN, &[b.ty_bool, d_cond, d_val, c_head_dim]);
    m.emit_loop_merge(lbl_d_merge, lbl_d_cont);
    m.emit_branch_conditional(d_cond, lbl_d_body, lbl_d_merge);

    m.emit_label(lbl_d_body);
    let q_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, q_idx, q_base, d_val]);
    let q_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, q_ptr, p_q, q_idx);
    let q_val = m.alloc_id();
    m.emit_load(b.ty_float, q_val, q_ptr);
    let k_idx = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, k_idx, k_off, d_val]);
    let k_ptr = m.alloc_id();
    m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, k_ptr, p_k, k_idx);
    let k_val = m.alloc_id();
    m.emit_load(b.ty_float, k_val, k_ptr);

    let qk_prod = m.alloc_id();
    m.emit(OP_F_MUL, &[b.ty_float, qk_prod, q_val, k_val]);
    let old_dot = m.alloc_id();
    m.emit_load(b.ty_float, old_dot, var_dot);
    let new_dot = m.alloc_id();
    m.emit(OP_F_ADD, &[b.ty_float, new_dot, old_dot, qk_prod]);
    m.emit_store(var_dot, new_dot);

    m.emit_branch(lbl_d_cont);
    m.emit_label(lbl_d_cont);
    let d_inc = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, d_inc, d_val, b.c_uint_1]);
    m.emit_store(var_d, d_inc);
    m.emit_branch(lbl_d_hdr);

    m.emit_label(lbl_d_merge);

    // score = dot * scale
    let dot_final = m.alloc_id();
    m.emit_load(b.ty_float, dot_final, var_dot);
    let score = m.alloc_id();
    m.emit(OP_F_MUL, &[b.ty_float, score, dot_final, c_scale]);

    if is_max_pass {
        // accum = fmax(accum, score)
        let old_acc = m.alloc_id();
        m.emit_load(b.ty_float, old_acc, accum_var);
        let new_acc = m.alloc_id();
        m.emit_opencl_ext(
            b.opencl_ext,
            b.ty_float,
            new_acc,
            OPENCL_FMAX,
            &[old_acc, score],
        );
        m.emit_store(accum_var, new_acc);
    } else {
        // w = exp(score - max_score); falls back to max 0 if no max was
        // supplied (callers always pass Some on the accumulation pass).
        let max_id = max_val.unwrap_or(b.c_float_0);
        let score_shifted = m.alloc_id();
        m.emit(OP_F_SUB, &[b.ty_float, score_shifted, score, max_id]);
        let w = m.alloc_id();
        m.emit_opencl_ext(b.opencl_ext, b.ty_float, w, OPENCL_EXP, &[score_shifted]);

        // sum_exp += w
        let old_sum = m.alloc_id();
        m.emit_load(b.ty_float, old_sum, accum_var);
        let new_sum = m.alloc_id();
        m.emit(OP_F_ADD, &[b.ty_float, new_sum, old_sum, w]);
        m.emit_store(accum_var, new_sum);

        // Accumulate weighted V: O[q_base+d] += w * V[v_off+d]
        // (V shares K's layout, so v_off reuses kv_base + sk*head_dim).
        if let Some(o_buf_id) = o_buf {
            let v_off = m.alloc_id();
            m.emit(OP_I_ADD, &[b.ty_uint, v_off, kv_base, sk_hd]);

            let var_d3 = m.alloc_id();
            m.emit_variable(b.ty_ptr_func_uint, var_d3, STORAGE_CLASS_FUNCTION);
            m.emit_store(var_d3, b.c_uint_0);

            let lbl_d3_hdr = m.alloc_id();
            let lbl_d3_body = m.alloc_id();
            let lbl_d3_cont = m.alloc_id();
            let lbl_d3_merge = m.alloc_id();

            m.emit_branch(lbl_d3_hdr);

            m.emit_label(lbl_d3_hdr);
            let d3_val = m.alloc_id();
            m.emit_load(b.ty_uint, d3_val, var_d3);
            let d3_cond = m.alloc_id();
            m.emit(OP_U_LESS_THAN, &[b.ty_bool, d3_cond, d3_val, c_head_dim]);
            m.emit_loop_merge(lbl_d3_merge, lbl_d3_cont);
            m.emit_branch_conditional(d3_cond, lbl_d3_body, lbl_d3_merge);

            m.emit_label(lbl_d3_body);
            let v_idx = m.alloc_id();
            m.emit(OP_I_ADD, &[b.ty_uint, v_idx, v_off, d3_val]);
            let v_ptr = m.alloc_id();
            m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, v_ptr, p_v, v_idx);
            let v_val = m.alloc_id();
            m.emit_load(b.ty_float, v_val, v_ptr);
            let wv = m.alloc_id();
            m.emit(OP_F_MUL, &[b.ty_float, wv, w, v_val]);

            let o_idx = m.alloc_id();
            m.emit(OP_I_ADD, &[b.ty_uint, o_idx, q_base, d3_val]);
            let o_ptr = m.alloc_id();
            m.emit_in_bounds_ptr_access_chain(b.ty_ptr_cross_float, o_ptr, o_buf_id, o_idx);
            let o_old = m.alloc_id();
            m.emit_load(b.ty_float, o_old, o_ptr);
            let o_new = m.alloc_id();
            m.emit(OP_F_ADD, &[b.ty_float, o_new, o_old, wv]);
            m.emit_store(o_ptr, o_new);

            m.emit_branch(lbl_d3_cont);
            m.emit_label(lbl_d3_cont);
            let d3_inc = m.alloc_id();
            m.emit(OP_I_ADD, &[b.ty_uint, d3_inc, d3_val, b.c_uint_1]);
            m.emit_store(var_d3, d3_inc);
            m.emit_branch(lbl_d3_hdr);

            m.emit_label(lbl_d3_merge);
        }
    }

    m.emit_branch(lbl_skip);
    m.emit_label(lbl_skip);
    m.emit_branch(lbl_cont);

    m.emit_label(lbl_cont);
    let sk_inc = m.alloc_id();
    m.emit(OP_I_ADD, &[b.ty_uint, sk_inc, sk_val, b.c_uint_1]);
    m.emit_store(var_sk, sk_inc);
    m.emit_branch(lbl_hdr);

    m.emit_label(lbl_merge);
}
749
750// ─── Tests ──────────────────────────────────────────────────
751
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spirv::SPIRV_MAGIC;

    /// Sanity-check the mandatory five-word SPIR-V header:
    /// word 0 = magic, word 3 = ID bound, word 4 = reserved schema.
    fn check_valid_spirv(words: &[u32]) {
        assert!(words.len() >= 5, "too short for SPIR-V header");
        let (magic, bound, schema) = (words[0], words[3], words[4]);
        assert_eq!(magic, SPIRV_MAGIC, "bad magic");
        assert!(bound > 0, "ID bound must be > 0");
        assert_eq!(schema, 0, "schema must be 0");
    }

    #[test]
    fn conv2d_spirv_valid() {
        check_valid_spirv(&conv2d_spirv(1, 3, 8, 8, 16, 3, 3, 6, 6, 1, 1, 0, 0));
    }

    #[test]
    fn conv2d_spirv_with_padding() {
        check_valid_spirv(&conv2d_spirv(2, 1, 5, 5, 4, 3, 3, 5, 5, 1, 1, 1, 1));
    }

    #[test]
    fn conv2d_spirv_1x1() {
        check_valid_spirv(&conv2d_spirv(1, 3, 4, 4, 8, 1, 1, 4, 4, 1, 1, 0, 0));
    }

    #[test]
    fn attention_spirv_valid() {
        check_valid_spirv(&attention_spirv(2, 4, 4, 8, 0.125, false));
    }

    #[test]
    fn attention_spirv_causal() {
        check_valid_spirv(&attention_spirv(1, 8, 8, 16, 0.25, true));
    }

    #[test]
    fn attention_spirv_magic_number() {
        assert_eq!(attention_spirv(1, 4, 4, 8, 0.125, false)[0], 0x0723_0203);
    }

    #[test]
    fn conv2d_spirv_magic_number() {
        assert_eq!(conv2d_spirv(1, 1, 4, 4, 1, 1, 1, 4, 4, 1, 1, 0, 0)[0], 0x0723_0203);
    }
}
805}